1 //===----------------------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/MemRef/IR/MemRef.h" 10 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" 11 #include "mlir/Dialect/StandardOps/IR/Ops.h" 12 #include "mlir/Dialect/StandardOps/Utils/Utils.h" 13 #include "mlir/Dialect/Tensor/IR/Tensor.h" 14 #include "mlir/IR/AffineMap.h" 15 #include "mlir/IR/Builders.h" 16 #include "mlir/IR/BuiltinTypes.h" 17 #include "mlir/IR/Matchers.h" 18 #include "mlir/IR/PatternMatch.h" 19 #include "mlir/IR/TypeUtilities.h" 20 #include "mlir/Interfaces/InferTypeOpInterface.h" 21 #include "llvm/ADT/STLExtras.h" 22 23 using namespace mlir; 24 using namespace mlir::memref; 25 26 /// Materialize a single constant operation from a given attribute value with 27 /// the desired resultant type. 28 Operation *MemRefDialect::materializeConstant(OpBuilder &builder, 29 Attribute value, Type type, 30 Location loc) { 31 return builder.create<mlir::ConstantOp>(loc, type, value); 32 } 33 34 /// Extract int64_t values from the assumed ArrayAttr of IntegerAttr. 35 static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) { 36 return llvm::to_vector<4>( 37 llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t { 38 return a.cast<IntegerAttr>().getInt(); 39 })); 40 } 41 42 /// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if 43 /// it is a Value or into `staticVec` if it is an IntegerAttr. 44 /// In the case of a Value, a copy of the `sentinel` value is also pushed to 45 /// `staticVec`. This is useful to extract mixed static and dynamic entries that 46 /// come from an AttrSizedOperandSegments trait. 47 static void dispatchIndexOpFoldResult(OpFoldResult ofr, 48 SmallVectorImpl<Value> &dynamicVec, 49 SmallVectorImpl<int64_t> &staticVec, 50 int64_t sentinel) { 51 if (auto v = ofr.dyn_cast<Value>()) { 52 dynamicVec.push_back(v); 53 staticVec.push_back(sentinel); 54 return; 55 } 56 APInt apInt = ofr.dyn_cast<Attribute>().cast<IntegerAttr>().getValue(); 57 staticVec.push_back(apInt.getSExtValue()); 58 } 59 60 static void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs, 61 SmallVectorImpl<Value> &dynamicVec, 62 SmallVectorImpl<int64_t> &staticVec, 63 int64_t sentinel) { 64 for (auto ofr : ofrs) 65 dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec, sentinel); 66 } 67 68 //===----------------------------------------------------------------------===// 69 // Common canonicalization pattern support logic 70 //===----------------------------------------------------------------------===// 71 72 /// This is a common class used for patterns of the form 73 /// "someop(memrefcast) -> someop". It folds the source of any memref.cast 74 /// into the root operation directly. 
75 static LogicalResult foldMemRefCast(Operation *op) { 76 bool folded = false; 77 for (OpOperand &operand : op->getOpOperands()) { 78 auto cast = operand.get().getDefiningOp<CastOp>(); 79 if (cast && !cast.getOperand().getType().isa<UnrankedMemRefType>()) { 80 operand.set(cast.getOperand()); 81 folded = true; 82 } 83 } 84 return success(folded); 85 } 86 87 //===----------------------------------------------------------------------===// 88 // Helpers for GlobalOp 89 //===----------------------------------------------------------------------===// 90 91 static Type getTensorTypeFromMemRefType(Type type) { 92 if (auto memref = type.dyn_cast<MemRefType>()) 93 return RankedTensorType::get(memref.getShape(), memref.getElementType()); 94 if (auto memref = type.dyn_cast<UnrankedMemRefType>()) 95 return UnrankedTensorType::get(memref.getElementType()); 96 return NoneType::get(type.getContext()); 97 } 98 99 //===----------------------------------------------------------------------===// 100 // AllocOp / AllocaOp 101 //===----------------------------------------------------------------------===// 102 103 template <typename AllocLikeOp> 104 static LogicalResult verifyAllocLikeOp(AllocLikeOp op) { 105 static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value, 106 "applies to only alloc or alloca"); 107 auto memRefType = op.getResult().getType().template dyn_cast<MemRefType>(); 108 if (!memRefType) 109 return op.emitOpError("result must be a memref"); 110 111 if (static_cast<int64_t>(op.dynamicSizes().size()) != 112 memRefType.getNumDynamicDims()) 113 return op.emitOpError("dimension operand count does not equal memref " 114 "dynamic dimension count"); 115 116 unsigned numSymbols = 0; 117 if (!memRefType.getAffineMaps().empty()) 118 numSymbols = memRefType.getAffineMaps().front().getNumSymbols(); 119 if (op.symbolOperands().size() != numSymbols) 120 return op.emitOpError( 121 "symbol operand count does not equal memref symbol count"); 122 123 return success(); 124 } 125 126 static LogicalResult verify(AllocOp op) { return verifyAllocLikeOp(op); } 127 128 static LogicalResult verify(AllocaOp op) { 129 // An alloca op needs to have an ancestor with an allocation scope trait. 130 if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>()) 131 return op.emitOpError( 132 "requires an ancestor op with AutomaticAllocationScope trait"); 133 134 return verifyAllocLikeOp(op); 135 } 136 137 namespace { 138 /// Fold constant dimensions into an alloc like operation. 139 template <typename AllocLikeOp> 140 struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> { 141 using OpRewritePattern<AllocLikeOp>::OpRewritePattern; 142 143 LogicalResult matchAndRewrite(AllocLikeOp alloc, 144 PatternRewriter &rewriter) const override { 145 // Check to see if any dimensions operands are constants. If so, we can 146 // substitute and drop them. 147 if (llvm::none_of(alloc.getOperands(), [](Value operand) { 148 return matchPattern(operand, matchConstantIndex()); 149 })) 150 return failure(); 151 152 auto memrefType = alloc.getType(); 153 154 // Ok, we have one or more constant operands. Collect the non-constant ones 155 // and keep track of the resultant memref type to build. 
156 SmallVector<int64_t, 4> newShapeConstants; 157 newShapeConstants.reserve(memrefType.getRank()); 158 SmallVector<Value, 4> newOperands; 159 160 unsigned dynamicDimPos = 0; 161 for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) { 162 int64_t dimSize = memrefType.getDimSize(dim); 163 // If this is already static dimension, keep it. 164 if (dimSize != -1) { 165 newShapeConstants.push_back(dimSize); 166 continue; 167 } 168 auto *defOp = alloc.getOperand(dynamicDimPos).getDefiningOp(); 169 if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) { 170 // Dynamic shape dimension will be folded. 171 newShapeConstants.push_back(constantIndexOp.getValue()); 172 } else { 173 // Dynamic shape dimension not folded; copy operand from old memref. 174 newShapeConstants.push_back(-1); 175 newOperands.push_back(alloc.getOperand(dynamicDimPos)); 176 } 177 dynamicDimPos++; 178 } 179 180 // Create new memref type (which will have fewer dynamic dimensions). 181 MemRefType newMemRefType = 182 MemRefType::Builder(memrefType).setShape(newShapeConstants); 183 assert(static_cast<int64_t>(newOperands.size()) == 184 newMemRefType.getNumDynamicDims()); 185 186 // Create and insert the alloc op for the new memref. 187 auto newAlloc = rewriter.create<AllocLikeOp>( 188 alloc.getLoc(), newMemRefType, newOperands, alloc.alignmentAttr()); 189 // Insert a cast so we have the same type as the old alloc. 190 auto resultCast = 191 rewriter.create<CastOp>(alloc.getLoc(), newAlloc, alloc.getType()); 192 193 rewriter.replaceOp(alloc, {resultCast}); 194 return success(); 195 } 196 }; 197 198 /// Fold alloc operations with no users or only store and dealloc uses. 199 template <typename T> 200 struct SimplifyDeadAlloc : public OpRewritePattern<T> { 201 using OpRewritePattern<T>::OpRewritePattern; 202 203 LogicalResult matchAndRewrite(T alloc, 204 PatternRewriter &rewriter) const override { 205 if (llvm::any_of(alloc->getUsers(), [](Operation *op) { 206 return !isa<StoreOp, DeallocOp>(op); 207 })) 208 return failure(); 209 210 for (Operation *user : llvm::make_early_inc_range(alloc->getUsers())) 211 rewriter.eraseOp(user); 212 213 rewriter.eraseOp(alloc); 214 return success(); 215 } 216 }; 217 } // end anonymous namespace. 
218 219 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results, 220 MLIRContext *context) { 221 results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context); 222 } 223 224 void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results, 225 MLIRContext *context) { 226 results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>( 227 context); 228 } 229 230 //===----------------------------------------------------------------------===// 231 // AssumeAlignmentOp 232 //===----------------------------------------------------------------------===// 233 234 static LogicalResult verify(AssumeAlignmentOp op) { 235 unsigned alignment = op.alignment(); 236 if (!llvm::isPowerOf2_32(alignment)) 237 return op.emitOpError("alignment must be power of 2"); 238 return success(); 239 } 240 241 //===----------------------------------------------------------------------===// 242 // BufferCastOp 243 //===----------------------------------------------------------------------===// 244 245 OpFoldResult BufferCastOp::fold(ArrayRef<Attribute>) { 246 if (auto tensorLoad = tensor().getDefiningOp<TensorLoadOp>()) 247 if (tensorLoad.memref().getType() == getType()) 248 return tensorLoad.memref(); 249 return {}; 250 } 251 252 namespace { 253 /// Replace tensor_cast + buffer_cast by buffer_cast + memref_cast. 254 struct BufferCast : public OpRewritePattern<BufferCastOp> { 255 using OpRewritePattern<BufferCastOp>::OpRewritePattern; 256 257 LogicalResult matchAndRewrite(BufferCastOp bufferCast, 258 PatternRewriter &rewriter) const final { 259 auto tensorCastOperand = 260 bufferCast.getOperand().getDefiningOp<tensor::CastOp>(); 261 if (!tensorCastOperand) 262 return failure(); 263 auto srcTensorType = 264 tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>(); 265 if (!srcTensorType) 266 return failure(); 267 auto memrefType = MemRefType::get(srcTensorType.getShape(), 268 srcTensorType.getElementType()); 269 Value memref = rewriter.create<BufferCastOp>( 270 bufferCast.getLoc(), memrefType, tensorCastOperand.getOperand()); 271 rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(), 272 memref); 273 return success(); 274 } 275 }; 276 277 /// Canonicalize memref.tensor_load + memref.buffer_cast to memref.cast when 278 /// type mismatches prevent `BufferCastOp::fold` to kick in. 279 struct TensorLoadToMemRef : public OpRewritePattern<BufferCastOp> { 280 using OpRewritePattern<BufferCastOp>::OpRewritePattern; 281 282 LogicalResult matchAndRewrite(BufferCastOp bufferCast, 283 PatternRewriter &rewriter) const final { 284 auto tensorLoad = bufferCast.tensor().getDefiningOp<TensorLoadOp>(); 285 // Bail unless we have a tensor_load + memref.buffer_cast with different 286 // types. `BufferCastOp::fold` handles the same type case. 287 if (!tensorLoad || tensorLoad.memref().getType() == bufferCast.getType()) 288 return failure(); 289 // If types are not cast-compatible, bail. 
290 if (!CastOp::areCastCompatible(tensorLoad.memref().getType(), 291 bufferCast.getType())) 292 return failure(); 293 rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(), 294 tensorLoad.memref()); 295 return success(); 296 } 297 }; 298 299 } // namespace 300 301 void BufferCastOp::getCanonicalizationPatterns(RewritePatternSet &results, 302 MLIRContext *context) { 303 results.add<BufferCast, TensorLoadToMemRef>(context); 304 } 305 306 //===----------------------------------------------------------------------===// 307 // CastOp 308 //===----------------------------------------------------------------------===// 309 310 /// Determines whether MemRef_CastOp casts to a more dynamic version of the 311 /// source memref. This is useful to to fold a memref.cast into a consuming op 312 /// and implement canonicalization patterns for ops in different dialects that 313 /// may consume the results of memref.cast operations. Such foldable memref.cast 314 /// operations are typically inserted as `view` and `subview` ops are 315 /// canonicalized, to preserve the type compatibility of their uses. 316 /// 317 /// Returns true when all conditions are met: 318 /// 1. source and result are ranked memrefs with strided semantics and same 319 /// element type and rank. 320 /// 2. each of the source's size, offset or stride has more static information 321 /// than the corresponding result's size, offset or stride. 322 /// 323 /// Example 1: 324 /// ```mlir 325 /// %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32> 326 /// %2 = consumer %1 ... : memref<?x?xf32> ... 327 /// ``` 328 /// 329 /// may fold into: 330 /// 331 /// ```mlir 332 /// %2 = consumer %0 ... : memref<8x16xf32> ... 333 /// ``` 334 /// 335 /// Example 2: 336 /// ``` 337 /// %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>> 338 /// to memref<?x?xf32> 339 /// consumer %1 : memref<?x?xf32> ... 340 /// ``` 341 /// 342 /// may fold into: 343 /// 344 /// ``` 345 /// consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>> 346 /// ``` 347 bool CastOp::canFoldIntoConsumerOp(CastOp castOp) { 348 MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>(); 349 MemRefType resultType = castOp.getType().dyn_cast<MemRefType>(); 350 351 // Requires ranked MemRefType. 352 if (!sourceType || !resultType) 353 return false; 354 355 // Requires same elemental type. 356 if (sourceType.getElementType() != resultType.getElementType()) 357 return false; 358 359 // Requires same rank. 360 if (sourceType.getRank() != resultType.getRank()) 361 return false; 362 363 // Only fold casts between strided memref forms. 364 int64_t sourceOffset, resultOffset; 365 SmallVector<int64_t, 4> sourceStrides, resultStrides; 366 if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) || 367 failed(getStridesAndOffset(resultType, resultStrides, resultOffset))) 368 return false; 369 370 // If cast is towards more static sizes along any dimension, don't fold. 371 for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) { 372 auto ss = std::get<0>(it), st = std::get<1>(it); 373 if (ss != st) 374 if (MemRefType::isDynamic(ss) && !MemRefType::isDynamic(st)) 375 return false; 376 } 377 378 // If cast is towards more static offset along any dimension, don't fold. 
379 if (sourceOffset != resultOffset) 380 if (MemRefType::isDynamicStrideOrOffset(sourceOffset) && 381 !MemRefType::isDynamicStrideOrOffset(resultOffset)) 382 return false; 383 384 // If cast is towards more static strides along any dimension, don't fold. 385 for (auto it : llvm::zip(sourceStrides, resultStrides)) { 386 auto ss = std::get<0>(it), st = std::get<1>(it); 387 if (ss != st) 388 if (MemRefType::isDynamicStrideOrOffset(ss) && 389 !MemRefType::isDynamicStrideOrOffset(st)) 390 return false; 391 } 392 393 return true; 394 } 395 396 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 397 if (inputs.size() != 1 || outputs.size() != 1) 398 return false; 399 Type a = inputs.front(), b = outputs.front(); 400 auto aT = a.dyn_cast<MemRefType>(); 401 auto bT = b.dyn_cast<MemRefType>(); 402 403 auto uaT = a.dyn_cast<UnrankedMemRefType>(); 404 auto ubT = b.dyn_cast<UnrankedMemRefType>(); 405 406 if (aT && bT) { 407 if (aT.getElementType() != bT.getElementType()) 408 return false; 409 if (aT.getAffineMaps() != bT.getAffineMaps()) { 410 int64_t aOffset, bOffset; 411 SmallVector<int64_t, 4> aStrides, bStrides; 412 if (failed(getStridesAndOffset(aT, aStrides, aOffset)) || 413 failed(getStridesAndOffset(bT, bStrides, bOffset)) || 414 aStrides.size() != bStrides.size()) 415 return false; 416 417 // Strides along a dimension/offset are compatible if the value in the 418 // source memref is static and the value in the target memref is the 419 // same. They are also compatible if either one is dynamic (see 420 // description of MemRefCastOp for details). 421 auto checkCompatible = [](int64_t a, int64_t b) { 422 return (a == MemRefType::getDynamicStrideOrOffset() || 423 b == MemRefType::getDynamicStrideOrOffset() || a == b); 424 }; 425 if (!checkCompatible(aOffset, bOffset)) 426 return false; 427 for (auto aStride : enumerate(aStrides)) 428 if (!checkCompatible(aStride.value(), bStrides[aStride.index()])) 429 return false; 430 } 431 if (aT.getMemorySpace() != bT.getMemorySpace()) 432 return false; 433 434 // They must have the same rank, and any specified dimensions must match. 435 if (aT.getRank() != bT.getRank()) 436 return false; 437 438 for (unsigned i = 0, e = aT.getRank(); i != e; ++i) { 439 int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i); 440 if (aDim != -1 && bDim != -1 && aDim != bDim) 441 return false; 442 } 443 return true; 444 } else { 445 if (!aT && !uaT) 446 return false; 447 if (!bT && !ubT) 448 return false; 449 // Unranked to unranked casting is unsupported 450 if (uaT && ubT) 451 return false; 452 453 auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType(); 454 auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType(); 455 if (aEltType != bEltType) 456 return false; 457 458 auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace(); 459 auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace(); 460 if (aMemSpace != bMemSpace) 461 return false; 462 463 return true; 464 } 465 466 return false; 467 } 468 469 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) { 470 return succeeded(foldMemRefCast(*this)) ? 
getResult() : Value(); 471 } 472 473 //===----------------------------------------------------------------------===// 474 // CloneOp 475 //===----------------------------------------------------------------------===// 476 477 static LogicalResult verify(CloneOp op) { return success(); } 478 479 void CloneOp::getEffects( 480 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> 481 &effects) { 482 effects.emplace_back(MemoryEffects::Read::get(), input(), 483 SideEffects::DefaultResource::get()); 484 effects.emplace_back(MemoryEffects::Write::get(), output(), 485 SideEffects::DefaultResource::get()); 486 } 487 488 namespace { 489 /// Fold Dealloc operations that are deallocating an AllocOp that is only used 490 /// by other Dealloc operations. 491 struct SimplifyClones : public OpRewritePattern<CloneOp> { 492 using OpRewritePattern<CloneOp>::OpRewritePattern; 493 494 LogicalResult matchAndRewrite(CloneOp cloneOp, 495 PatternRewriter &rewriter) const override { 496 if (cloneOp.use_empty()) { 497 rewriter.eraseOp(cloneOp); 498 return success(); 499 } 500 501 Value source = cloneOp.input(); 502 503 // Removes the clone operation and the corresponding dealloc and alloc 504 // operation (if any). 505 auto tryRemoveClone = [&](Operation *sourceOp, Operation *dealloc, 506 Operation *alloc) { 507 if (!sourceOp || !dealloc || !alloc || 508 alloc->getBlock() != dealloc->getBlock()) 509 return false; 510 rewriter.replaceOp(cloneOp, source); 511 rewriter.eraseOp(dealloc); 512 return true; 513 }; 514 515 // Removes unnecessary clones that are derived from the result of the clone 516 // op. 517 Operation *deallocOp = findDealloc(cloneOp.output()); 518 Operation *sourceOp = source.getDefiningOp(); 519 if (tryRemoveClone(sourceOp, deallocOp, sourceOp)) 520 return success(); 521 522 // Removes unnecessary clones that are derived from the source of the clone 523 // op. 524 deallocOp = findDealloc(source); 525 if (tryRemoveClone(sourceOp, deallocOp, cloneOp)) 526 return success(); 527 528 return failure(); 529 } 530 }; 531 532 } // end anonymous namespace. 533 534 void CloneOp::getCanonicalizationPatterns(OwningRewritePatternList &results, 535 MLIRContext *context) { 536 results.insert<SimplifyClones>(context); 537 } 538 539 OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) { 540 return succeeded(foldMemRefCast(*this)) ? 
getResult() : Value(); 541 } 542 543 //===----------------------------------------------------------------------===// 544 // DeallocOp 545 //===----------------------------------------------------------------------===// 546 547 static LogicalResult verify(DeallocOp op) { 548 if (!op.memref().getType().isa<MemRefType>()) 549 return op.emitOpError("operand must be a memref"); 550 return success(); 551 } 552 553 LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands, 554 SmallVectorImpl<OpFoldResult> &results) { 555 /// dealloc(memrefcast) -> dealloc 556 return foldMemRefCast(*this); 557 } 558 559 //===----------------------------------------------------------------------===// 560 // DimOp 561 //===----------------------------------------------------------------------===// 562 563 void DimOp::build(OpBuilder &builder, OperationState &result, Value memref, 564 int64_t index) { 565 auto loc = result.location; 566 Value indexValue = builder.create<ConstantIndexOp>(loc, index); 567 build(builder, result, memref, indexValue); 568 } 569 570 void DimOp::build(OpBuilder &builder, OperationState &result, Value memref, 571 Value index) { 572 auto indexTy = builder.getIndexType(); 573 build(builder, result, indexTy, memref, index); 574 } 575 576 Optional<int64_t> DimOp::getConstantIndex() { 577 if (auto constantOp = index().getDefiningOp<ConstantOp>()) 578 return constantOp.getValue().cast<IntegerAttr>().getInt(); 579 return {}; 580 } 581 582 static LogicalResult verify(DimOp op) { 583 // Assume unknown index to be in range. 584 Optional<int64_t> index = op.getConstantIndex(); 585 if (!index.hasValue()) 586 return success(); 587 588 // Check that constant index is not knowingly out of range. 589 auto type = op.memrefOrTensor().getType(); 590 if (auto memrefType = type.dyn_cast<MemRefType>()) { 591 if (index.getValue() >= memrefType.getRank()) 592 return op.emitOpError("index is out of range"); 593 } else if (auto tensorType = type.dyn_cast<RankedTensorType>()) { 594 if (index.getValue() >= tensorType.getRank()) 595 return op.emitOpError("index is out of range"); 596 } else if (type.isa<UnrankedMemRefType>() || type.isa<UnrankedTensorType>()) { 597 // Assume index to be in range. 598 } else { 599 llvm_unreachable("expected operand with memref type"); 600 } 601 return success(); 602 } 603 604 OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) { 605 auto index = operands[1].dyn_cast_or_null<IntegerAttr>(); 606 607 // All forms of folding require a known index. 608 if (!index) 609 return {}; 610 611 auto argTy = memrefOrTensor().getType(); 612 // Fold if the shape extent along the given index is known. 613 if (auto shapedTy = argTy.dyn_cast<ShapedType>()) { 614 // Folding for unranked types (UnrankedMemRefType) is not supported. 615 if (!shapedTy.hasRank()) 616 return {}; 617 if (!shapedTy.isDynamicDim(index.getInt())) { 618 Builder builder(getContext()); 619 return builder.getIndexAttr(shapedTy.getShape()[index.getInt()]); 620 } 621 } 622 623 Operation *definingOp = memrefOrTensor().getDefiningOp(); 624 625 // dim(memref.tensor_load(memref)) -> dim(memref) 626 if (auto tensorLoadOp = dyn_cast_or_null<TensorLoadOp>(definingOp)) { 627 setOperand(0, tensorLoadOp.memref()); 628 return getResult(); 629 } 630 631 // Fold dim to the operand of tensor.generate. 
632 if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) { 633 auto resultType = 634 fromElements.getResult().getType().cast<RankedTensorType>(); 635 // The case where the type encodes the size of the dimension is handled 636 // above. 637 assert(resultType.getShape()[index.getInt()] == 638 RankedTensorType::kDynamicSize); 639 640 // Find the operand of the fromElements that corresponds to this index. 641 auto dynExtents = fromElements.dynamicExtents().begin(); 642 for (auto dim : resultType.getShape().take_front(index.getInt())) 643 if (dim == RankedTensorType::kDynamicSize) 644 dynExtents++; 645 646 return Value{*dynExtents}; 647 } 648 649 // The size at the given index is now known to be a dynamic size. 650 unsigned unsignedIndex = index.getValue().getZExtValue(); 651 652 if (auto subtensor = dyn_cast_or_null<mlir::SubTensorOp>(definingOp)) { 653 assert(subtensor.isDynamicSize(unsignedIndex) && 654 "Expected dynamic subtensor size"); 655 return subtensor.getDynamicSize(unsignedIndex); 656 } 657 658 // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`. 659 auto memrefType = argTy.dyn_cast<MemRefType>(); 660 if (!memrefType) 661 return {}; 662 663 if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp)) 664 return *(alloc.getDynamicSizes().begin() + 665 memrefType.getDynamicDimIndex(unsignedIndex)); 666 667 if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp)) 668 return *(alloca.getDynamicSizes().begin() + 669 memrefType.getDynamicDimIndex(unsignedIndex)); 670 671 if (auto view = dyn_cast_or_null<ViewOp>(definingOp)) 672 return *(view.getDynamicSizes().begin() + 673 memrefType.getDynamicDimIndex(unsignedIndex)); 674 675 if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) { 676 assert(subview.isDynamicSize(unsignedIndex) && 677 "Expected dynamic subview size"); 678 return subview.getDynamicSize(unsignedIndex); 679 } 680 681 // dim(memrefcast) -> dim 682 if (succeeded(foldMemRefCast(*this))) 683 return getResult(); 684 685 return {}; 686 } 687 688 namespace { 689 /// Fold dim of a memref reshape operation to a load into the reshape's shape 690 /// operand. 691 struct DimOfMemRefReshape : public OpRewritePattern<DimOp> { 692 using OpRewritePattern<DimOp>::OpRewritePattern; 693 694 LogicalResult matchAndRewrite(DimOp dim, 695 PatternRewriter &rewriter) const override { 696 auto reshape = dim.memrefOrTensor().getDefiningOp<ReshapeOp>(); 697 698 if (!reshape) 699 return failure(); 700 701 // Place the load directly after the reshape to ensure that the shape memref 702 // was not mutated. 703 rewriter.setInsertionPointAfter(reshape); 704 rewriter.replaceOpWithNewOp<LoadOp>(dim, reshape.shape(), 705 llvm::makeArrayRef({dim.index()})); 706 return success(); 707 } 708 }; 709 710 /// Fold dim of a dim of a cast into the dim of the source of the tensor cast. 711 template <typename CastOpTy> 712 struct DimOfCastOp : public OpRewritePattern<DimOp> { 713 using OpRewritePattern<DimOp>::OpRewritePattern; 714 715 LogicalResult matchAndRewrite(DimOp dimOp, 716 PatternRewriter &rewriter) const override { 717 auto castOp = dimOp.memrefOrTensor().getDefiningOp<CastOpTy>(); 718 if (!castOp) 719 return failure(); 720 Value newSource = castOp.getOperand(); 721 rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.index()); 722 return success(); 723 } 724 }; 725 726 /// Helper method to get the `Value` that is the shape of the `resultIdx`-th 727 /// result at dimension `dimIndex` from the `ShapedTypeOpInterface`. 
728 /// TODO(ravishankarm): This is better put as a interface utility method 729 /// somewhere, but that would imply the interface will depend on the `tensor` 730 /// dialect. Ideally maybe a utility method in the `tensor` dialect. 731 static Value getResultDimFromShapeInterface(OpBuilder &builder, OpResult result, 732 int64_t dimIndex) { 733 unsigned resultNumber = result.getResultNumber(); 734 auto shapedTypeOp = dyn_cast<InferShapedTypeOpInterface>(result.getOwner()); 735 Location loc = result.getOwner()->getLoc(); 736 if (!shapedTypeOp) 737 return nullptr; 738 739 // The interface exposes two methods, one that returns the shape of all the 740 // results as `Value` and other that returns the shape as a list of 741 // `SmallVector<Value>`. The former takes precedence over the latter. So first 742 // check if the op implements the first interface method or the second, and 743 // get the value to use appropriately. 744 SmallVector<Value> reifiedResultShapes; 745 if (succeeded( 746 shapedTypeOp.reifyReturnTypeShapes(builder, reifiedResultShapes))) { 747 if (reifiedResultShapes.size() <= resultNumber) 748 return nullptr; 749 Value resultShape = reifiedResultShapes[resultNumber]; 750 auto resultShapeType = resultShape.getType().dyn_cast<RankedTensorType>(); 751 if (!resultShapeType || !resultShapeType.getElementType().isa<IndexType>()) 752 return nullptr; 753 return builder.create<tensor::ExtractOp>( 754 loc, resultShape, builder.createOrFold<ConstantIndexOp>(loc, dimIndex)); 755 } 756 757 SmallVector<SmallVector<Value>> reifiedResultShapesPerDim; 758 if (failed(shapedTypeOp.reifyReturnTypeShapesPerResultDim( 759 builder, reifiedResultShapesPerDim))) 760 return nullptr; 761 if (reifiedResultShapesPerDim.size() <= resultNumber || 762 reifiedResultShapesPerDim[resultNumber].size() != 763 static_cast<size_t>(result.getType().cast<ShapedType>().getRank())) 764 return nullptr; 765 OpFoldResult valueOrAttr = reifiedResultShapesPerDim[resultNumber][dimIndex]; 766 if (auto attr = valueOrAttr.dyn_cast<Attribute>()) 767 return builder.createOrFold<ConstantIndexOp>( 768 loc, attr.cast<IntegerAttr>().getInt()); 769 return valueOrAttr.get<Value>(); 770 } 771 772 /// Fold dim of an operation that implements the InferShapedTypeOpInterface 773 struct DimOfShapedTypeOpInterface : public OpRewritePattern<DimOp> { 774 using OpRewritePattern<DimOp>::OpRewritePattern; 775 776 LogicalResult matchAndRewrite(DimOp dimOp, 777 PatternRewriter &rewriter) const override { 778 OpResult dimValue = dimOp.memrefOrTensor().dyn_cast<OpResult>(); 779 if (!dimValue) 780 return failure(); 781 auto shapedTypeOp = 782 dyn_cast<InferShapedTypeOpInterface>(dimValue.getOwner()); 783 if (!shapedTypeOp) 784 return failure(); 785 786 Optional<int64_t> dimIndex = dimOp.getConstantIndex(); 787 if (!dimIndex) 788 return failure(); 789 Value replacement = 790 getResultDimFromShapeInterface(rewriter, dimValue, *dimIndex); 791 if (!replacement) 792 return failure(); 793 rewriter.replaceOp(dimOp, replacement); 794 return success(); 795 } 796 }; 797 } // end anonymous namespace. 
798 799 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results, 800 MLIRContext *context) { 801 results.add<DimOfMemRefReshape, DimOfCastOp<BufferCastOp>, 802 DimOfCastOp<tensor::CastOp>, DimOfShapedTypeOpInterface>(context); 803 } 804 805 // --------------------------------------------------------------------------- 806 // DmaStartOp 807 // --------------------------------------------------------------------------- 808 809 void DmaStartOp::build(OpBuilder &builder, OperationState &result, 810 Value srcMemRef, ValueRange srcIndices, Value destMemRef, 811 ValueRange destIndices, Value numElements, 812 Value tagMemRef, ValueRange tagIndices, Value stride, 813 Value elementsPerStride) { 814 result.addOperands(srcMemRef); 815 result.addOperands(srcIndices); 816 result.addOperands(destMemRef); 817 result.addOperands(destIndices); 818 result.addOperands({numElements, tagMemRef}); 819 result.addOperands(tagIndices); 820 if (stride) 821 result.addOperands({stride, elementsPerStride}); 822 } 823 824 void DmaStartOp::print(OpAsmPrinter &p) { 825 p << getOperationName() << " " << getSrcMemRef() << '[' << getSrcIndices() 826 << "], " << getDstMemRef() << '[' << getDstIndices() << "], " 827 << getNumElements() << ", " << getTagMemRef() << '[' << getTagIndices() 828 << ']'; 829 if (isStrided()) 830 p << ", " << getStride() << ", " << getNumElementsPerStride(); 831 832 p.printOptionalAttrDict((*this)->getAttrs()); 833 p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType() 834 << ", " << getTagMemRef().getType(); 835 } 836 837 // Parse DmaStartOp. 838 // Ex: 839 // %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size, 840 // %tag[%index], %stride, %num_elt_per_stride : 841 // : memref<3076 x f32, 0>, 842 // memref<1024 x f32, 2>, 843 // memref<1 x i32> 844 // 845 ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) { 846 OpAsmParser::OperandType srcMemRefInfo; 847 SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos; 848 OpAsmParser::OperandType dstMemRefInfo; 849 SmallVector<OpAsmParser::OperandType, 4> dstIndexInfos; 850 OpAsmParser::OperandType numElementsInfo; 851 OpAsmParser::OperandType tagMemrefInfo; 852 SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos; 853 SmallVector<OpAsmParser::OperandType, 2> strideInfo; 854 855 SmallVector<Type, 3> types; 856 auto indexType = parser.getBuilder().getIndexType(); 857 858 // Parse and resolve the following list of operands: 859 // *) source memref followed by its indices (in square brackets). 860 // *) destination memref followed by its indices (in square brackets). 861 // *) dma size in KiB. 862 if (parser.parseOperand(srcMemRefInfo) || 863 parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) || 864 parser.parseComma() || parser.parseOperand(dstMemRefInfo) || 865 parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) || 866 parser.parseComma() || parser.parseOperand(numElementsInfo) || 867 parser.parseComma() || parser.parseOperand(tagMemrefInfo) || 868 parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square)) 869 return failure(); 870 871 // Parse optional stride and elements per stride. 
872 if (parser.parseTrailingOperandList(strideInfo)) 873 return failure(); 874 875 bool isStrided = strideInfo.size() == 2; 876 if (!strideInfo.empty() && !isStrided) { 877 return parser.emitError(parser.getNameLoc(), 878 "expected two stride related operands"); 879 } 880 881 if (parser.parseColonTypeList(types)) 882 return failure(); 883 if (types.size() != 3) 884 return parser.emitError(parser.getNameLoc(), "fewer/more types expected"); 885 886 if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) || 887 parser.resolveOperands(srcIndexInfos, indexType, result.operands) || 888 parser.resolveOperand(dstMemRefInfo, types[1], result.operands) || 889 parser.resolveOperands(dstIndexInfos, indexType, result.operands) || 890 // size should be an index. 891 parser.resolveOperand(numElementsInfo, indexType, result.operands) || 892 parser.resolveOperand(tagMemrefInfo, types[2], result.operands) || 893 // tag indices should be index. 894 parser.resolveOperands(tagIndexInfos, indexType, result.operands)) 895 return failure(); 896 897 if (isStrided) { 898 if (parser.resolveOperands(strideInfo, indexType, result.operands)) 899 return failure(); 900 } 901 902 return success(); 903 } 904 905 LogicalResult DmaStartOp::verify() { 906 unsigned numOperands = getNumOperands(); 907 908 // Mandatory non-variadic operands are: src memref, dst memref, tag memref and 909 // the number of elements. 910 if (numOperands < 4) 911 return emitOpError("expected at least 4 operands"); 912 913 // Check types of operands. The order of these calls is important: the later 914 // calls rely on some type properties to compute the operand position. 915 // 1. Source memref. 916 if (!getSrcMemRef().getType().isa<MemRefType>()) 917 return emitOpError("expected source to be of memref type"); 918 if (numOperands < getSrcMemRefRank() + 4) 919 return emitOpError() << "expected at least " << getSrcMemRefRank() + 4 920 << " operands"; 921 if (!getSrcIndices().empty() && 922 !llvm::all_of(getSrcIndices().getTypes(), 923 [](Type t) { return t.isIndex(); })) 924 return emitOpError("expected source indices to be of index type"); 925 926 // 2. Destination memref. 927 if (!getDstMemRef().getType().isa<MemRefType>()) 928 return emitOpError("expected destination to be of memref type"); 929 unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4; 930 if (numOperands < numExpectedOperands) 931 return emitOpError() << "expected at least " << numExpectedOperands 932 << " operands"; 933 if (!getDstIndices().empty() && 934 !llvm::all_of(getDstIndices().getTypes(), 935 [](Type t) { return t.isIndex(); })) 936 return emitOpError("expected destination indices to be of index type"); 937 938 // 3. Number of elements. 939 if (!getNumElements().getType().isIndex()) 940 return emitOpError("expected num elements to be of index type"); 941 942 // 4. Tag memref. 943 if (!getTagMemRef().getType().isa<MemRefType>()) 944 return emitOpError("expected tag to be of memref type"); 945 numExpectedOperands += getTagMemRefRank(); 946 if (numOperands < numExpectedOperands) 947 return emitOpError() << "expected at least " << numExpectedOperands 948 << " operands"; 949 if (!getTagIndices().empty() && 950 !llvm::all_of(getTagIndices().getTypes(), 951 [](Type t) { return t.isIndex(); })) 952 return emitOpError("expected tag indices to be of index type"); 953 954 // DMAs from different memory spaces supported. 
955 if (getSrcMemorySpace() == getDstMemorySpace()) 956 return emitOpError("DMA should be between different memory spaces"); 957 958 // Optional stride-related operands must be either both present or both 959 // absent. 960 if (numOperands != numExpectedOperands && 961 numOperands != numExpectedOperands + 2) 962 return emitOpError("incorrect number of operands"); 963 964 // 5. Strides. 965 if (isStrided()) { 966 if (!getStride().getType().isIndex() || 967 !getNumElementsPerStride().getType().isIndex()) 968 return emitOpError( 969 "expected stride and num elements per stride to be of type index"); 970 } 971 972 return success(); 973 } 974 975 LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands, 976 SmallVectorImpl<OpFoldResult> &results) { 977 /// dma_start(memrefcast) -> dma_start 978 return foldMemRefCast(*this); 979 } 980 981 // --------------------------------------------------------------------------- 982 // DmaWaitOp 983 // --------------------------------------------------------------------------- 984 985 void DmaWaitOp::build(OpBuilder &builder, OperationState &result, 986 Value tagMemRef, ValueRange tagIndices, 987 Value numElements) { 988 result.addOperands(tagMemRef); 989 result.addOperands(tagIndices); 990 result.addOperands(numElements); 991 } 992 993 void DmaWaitOp::print(OpAsmPrinter &p) { 994 p << getOperationName() << " " << getTagMemRef() << '[' << getTagIndices() 995 << "], " << getNumElements(); 996 p.printOptionalAttrDict((*this)->getAttrs()); 997 p << " : " << getTagMemRef().getType(); 998 } 999 1000 // Parse DmaWaitOp. 1001 // Eg: 1002 // dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 4> 1003 // 1004 ParseResult DmaWaitOp::parse(OpAsmParser &parser, OperationState &result) { 1005 OpAsmParser::OperandType tagMemrefInfo; 1006 SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos; 1007 Type type; 1008 auto indexType = parser.getBuilder().getIndexType(); 1009 OpAsmParser::OperandType numElementsInfo; 1010 1011 // Parse tag memref, its indices, and dma size. 1012 if (parser.parseOperand(tagMemrefInfo) || 1013 parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square) || 1014 parser.parseComma() || parser.parseOperand(numElementsInfo) || 1015 parser.parseColonType(type) || 1016 parser.resolveOperand(tagMemrefInfo, type, result.operands) || 1017 parser.resolveOperands(tagIndexInfos, indexType, result.operands) || 1018 parser.resolveOperand(numElementsInfo, indexType, result.operands)) 1019 return failure(); 1020 1021 return success(); 1022 } 1023 1024 LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands, 1025 SmallVectorImpl<OpFoldResult> &results) { 1026 /// dma_wait(memrefcast) -> dma_wait 1027 return foldMemRefCast(*this); 1028 } 1029 1030 LogicalResult DmaWaitOp::verify() { 1031 // Mandatory non-variadic operands are tag and the number of elements. 1032 if (getNumOperands() < 2) 1033 return emitOpError() << "expected at least 2 operands"; 1034 1035 // Check types of operands. The order of these calls is important: the later 1036 // calls rely on some type properties to compute the operand position. 
1037 if (!getTagMemRef().getType().isa<MemRefType>()) 1038 return emitOpError() << "expected tag to be of memref type"; 1039 1040 if (getNumOperands() != 2 + getTagMemRefRank()) 1041 return emitOpError() << "expected " << 2 + getTagMemRefRank() 1042 << " operands"; 1043 1044 if (!getTagIndices().empty() && 1045 !llvm::all_of(getTagIndices().getTypes(), 1046 [](Type t) { return t.isIndex(); })) 1047 return emitOpError() << "expected tag indices to be of index type"; 1048 1049 if (!getNumElements().getType().isIndex()) 1050 return emitOpError() 1051 << "expected the number of elements to be of index type"; 1052 1053 return success(); 1054 } 1055 1056 //===----------------------------------------------------------------------===// 1057 // GlobalOp 1058 //===----------------------------------------------------------------------===// 1059 1060 static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op, 1061 TypeAttr type, 1062 Attribute initialValue) { 1063 p << type; 1064 if (!op.isExternal()) { 1065 p << " = "; 1066 if (op.isUninitialized()) 1067 p << "uninitialized"; 1068 else 1069 p.printAttributeWithoutType(initialValue); 1070 } 1071 } 1072 1073 static ParseResult 1074 parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, 1075 Attribute &initialValue) { 1076 Type type; 1077 if (parser.parseType(type)) 1078 return failure(); 1079 1080 auto memrefType = type.dyn_cast<MemRefType>(); 1081 if (!memrefType || !memrefType.hasStaticShape()) 1082 return parser.emitError(parser.getNameLoc()) 1083 << "type should be static shaped memref, but got " << type; 1084 typeAttr = TypeAttr::get(type); 1085 1086 if (parser.parseOptionalEqual()) 1087 return success(); 1088 1089 if (succeeded(parser.parseOptionalKeyword("uninitialized"))) { 1090 initialValue = UnitAttr::get(parser.getBuilder().getContext()); 1091 return success(); 1092 } 1093 1094 Type tensorType = getTensorTypeFromMemRefType(memrefType); 1095 if (parser.parseAttribute(initialValue, tensorType)) 1096 return failure(); 1097 if (!initialValue.isa<ElementsAttr>()) 1098 return parser.emitError(parser.getNameLoc()) 1099 << "initial value should be a unit or elements attribute"; 1100 return success(); 1101 } 1102 1103 static LogicalResult verify(GlobalOp op) { 1104 auto memrefType = op.type().dyn_cast<MemRefType>(); 1105 if (!memrefType || !memrefType.hasStaticShape()) 1106 return op.emitOpError("type should be static shaped memref, but got ") 1107 << op.type(); 1108 1109 // Verify that the initial value, if present, is either a unit attribute or 1110 // an elements attribute. 1111 if (op.initial_value().hasValue()) { 1112 Attribute initValue = op.initial_value().getValue(); 1113 if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>()) 1114 return op.emitOpError("initial value should be a unit or elements " 1115 "attribute, but got ") 1116 << initValue; 1117 1118 // Check that the type of the initial value is compatible with the type of 1119 // the global variable. 1120 if (initValue.isa<ElementsAttr>()) { 1121 Type initType = initValue.getType(); 1122 Type tensorType = getTensorTypeFromMemRefType(memrefType); 1123 if (initType != tensorType) 1124 return op.emitOpError("initial value expected to be of type ") 1125 << tensorType << ", but was of type " << initType; 1126 } 1127 } 1128 1129 // TODO: verify visibility for declarations. 
1130 return success(); 1131 } 1132 1133 //===----------------------------------------------------------------------===// 1134 // GetGlobalOp 1135 //===----------------------------------------------------------------------===// 1136 1137 LogicalResult 1138 GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 1139 // Verify that the result type is same as the type of the referenced 1140 // memref.global op. 1141 auto global = 1142 symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, nameAttr()); 1143 if (!global) 1144 return emitOpError("'") 1145 << name() << "' does not reference a valid global memref"; 1146 1147 Type resultType = result().getType(); 1148 if (global.type() != resultType) 1149 return emitOpError("result type ") 1150 << resultType << " does not match type " << global.type() 1151 << " of the global memref @" << name(); 1152 return success(); 1153 } 1154 1155 //===----------------------------------------------------------------------===// 1156 // LoadOp 1157 //===----------------------------------------------------------------------===// 1158 1159 static LogicalResult verify(LoadOp op) { 1160 if (op.getNumOperands() != 1 + op.getMemRefType().getRank()) 1161 return op.emitOpError("incorrect number of indices for load"); 1162 return success(); 1163 } 1164 1165 OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) { 1166 /// load(memrefcast) -> load 1167 if (succeeded(foldMemRefCast(*this))) 1168 return getResult(); 1169 return OpFoldResult(); 1170 } 1171 1172 namespace { 1173 /// Fold a load on a buffer_cast operation into an tensor.extract on the 1174 /// corresponding tensor. 1175 struct LoadOfBufferCast : public OpRewritePattern<LoadOp> { 1176 using OpRewritePattern<LoadOp>::OpRewritePattern; 1177 1178 LogicalResult matchAndRewrite(LoadOp load, 1179 PatternRewriter &rewriter) const override { 1180 auto buffercast = load.memref().getDefiningOp<BufferCastOp>(); 1181 if (!buffercast) 1182 return failure(); 1183 1184 rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, buffercast.tensor(), 1185 load.indices()); 1186 return success(); 1187 } 1188 }; 1189 } // end anonymous namespace. 1190 1191 void LoadOp::getCanonicalizationPatterns(RewritePatternSet &results, 1192 MLIRContext *context) { 1193 results.add<LoadOfBufferCast>(context); 1194 } 1195 1196 //===----------------------------------------------------------------------===// 1197 // PrefetchOp 1198 //===----------------------------------------------------------------------===// 1199 1200 static void print(OpAsmPrinter &p, PrefetchOp op) { 1201 p << PrefetchOp::getOperationName() << " " << op.memref() << '['; 1202 p.printOperands(op.indices()); 1203 p << ']' << ", " << (op.isWrite() ? "write" : "read"); 1204 p << ", locality<" << op.localityHint(); 1205 p << ">, " << (op.isDataCache() ? 
"data" : "instr"); 1206 p.printOptionalAttrDict( 1207 op->getAttrs(), 1208 /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"}); 1209 p << " : " << op.getMemRefType(); 1210 } 1211 1212 static ParseResult parsePrefetchOp(OpAsmParser &parser, 1213 OperationState &result) { 1214 OpAsmParser::OperandType memrefInfo; 1215 SmallVector<OpAsmParser::OperandType, 4> indexInfo; 1216 IntegerAttr localityHint; 1217 MemRefType type; 1218 StringRef readOrWrite, cacheType; 1219 1220 auto indexTy = parser.getBuilder().getIndexType(); 1221 auto i32Type = parser.getBuilder().getIntegerType(32); 1222 if (parser.parseOperand(memrefInfo) || 1223 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || 1224 parser.parseComma() || parser.parseKeyword(&readOrWrite) || 1225 parser.parseComma() || parser.parseKeyword("locality") || 1226 parser.parseLess() || 1227 parser.parseAttribute(localityHint, i32Type, "localityHint", 1228 result.attributes) || 1229 parser.parseGreater() || parser.parseComma() || 1230 parser.parseKeyword(&cacheType) || parser.parseColonType(type) || 1231 parser.resolveOperand(memrefInfo, type, result.operands) || 1232 parser.resolveOperands(indexInfo, indexTy, result.operands)) 1233 return failure(); 1234 1235 if (!readOrWrite.equals("read") && !readOrWrite.equals("write")) 1236 return parser.emitError(parser.getNameLoc(), 1237 "rw specifier has to be 'read' or 'write'"); 1238 result.addAttribute( 1239 PrefetchOp::getIsWriteAttrName(), 1240 parser.getBuilder().getBoolAttr(readOrWrite.equals("write"))); 1241 1242 if (!cacheType.equals("data") && !cacheType.equals("instr")) 1243 return parser.emitError(parser.getNameLoc(), 1244 "cache type has to be 'data' or 'instr'"); 1245 1246 result.addAttribute( 1247 PrefetchOp::getIsDataCacheAttrName(), 1248 parser.getBuilder().getBoolAttr(cacheType.equals("data"))); 1249 1250 return success(); 1251 } 1252 1253 static LogicalResult verify(PrefetchOp op) { 1254 if (op.getNumOperands() != 1 + op.getMemRefType().getRank()) 1255 return op.emitOpError("too few indices"); 1256 1257 return success(); 1258 } 1259 1260 LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands, 1261 SmallVectorImpl<OpFoldResult> &results) { 1262 // prefetch(memrefcast) -> prefetch 1263 return foldMemRefCast(*this); 1264 } 1265 1266 //===----------------------------------------------------------------------===// 1267 // ReinterpretCastOp 1268 //===----------------------------------------------------------------------===// 1269 1270 /// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`, 1271 /// `staticSizes` and `staticStrides` are automatically filled with 1272 /// source-memref-rank sentinel values that encode dynamic entries. 
1273 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, 1274 MemRefType resultType, Value source, 1275 OpFoldResult offset, ArrayRef<OpFoldResult> sizes, 1276 ArrayRef<OpFoldResult> strides, 1277 ArrayRef<NamedAttribute> attrs) { 1278 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 1279 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 1280 dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets, 1281 ShapedType::kDynamicStrideOrOffset); 1282 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, 1283 ShapedType::kDynamicSize); 1284 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, 1285 ShapedType::kDynamicStrideOrOffset); 1286 build(b, result, resultType, source, dynamicOffsets, dynamicSizes, 1287 dynamicStrides, b.getI64ArrayAttr(staticOffsets), 1288 b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides)); 1289 result.addAttributes(attrs); 1290 } 1291 1292 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, 1293 MemRefType resultType, Value source, 1294 int64_t offset, ArrayRef<int64_t> sizes, 1295 ArrayRef<int64_t> strides, 1296 ArrayRef<NamedAttribute> attrs) { 1297 SmallVector<OpFoldResult> sizeValues = 1298 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult { 1299 return b.getI64IntegerAttr(v); 1300 })); 1301 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1302 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult { 1303 return b.getI64IntegerAttr(v); 1304 })); 1305 build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues, 1306 strideValues, attrs); 1307 } 1308 1309 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, 1310 MemRefType resultType, Value source, Value offset, 1311 ValueRange sizes, ValueRange strides, 1312 ArrayRef<NamedAttribute> attrs) { 1313 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>( 1314 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; })); 1315 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1316 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; })); 1317 build(b, result, resultType, source, offset, sizeValues, strideValues, attrs); 1318 } 1319 1320 // TODO: ponder whether we want to allow missing trailing sizes/strides that are 1321 // completed automatically, like we have for subview and subtensor. 1322 static LogicalResult verify(ReinterpretCastOp op) { 1323 // The source and result memrefs should be in the same memory space. 1324 auto srcType = op.source().getType().cast<BaseMemRefType>(); 1325 auto resultType = op.getType().cast<MemRefType>(); 1326 if (srcType.getMemorySpace() != resultType.getMemorySpace()) 1327 return op.emitError("different memory spaces specified for source type ") 1328 << srcType << " and result memref type " << resultType; 1329 if (srcType.getElementType() != resultType.getElementType()) 1330 return op.emitError("different element types specified for source type ") 1331 << srcType << " and result memref type " << resultType; 1332 1333 // Match sizes in result memref type and in static_sizes attribute. 
1334 for (auto &en : 1335 llvm::enumerate(llvm::zip(resultType.getShape(), 1336 extractFromI64ArrayAttr(op.static_sizes())))) { 1337 int64_t resultSize = std::get<0>(en.value()); 1338 int64_t expectedSize = std::get<1>(en.value()); 1339 if (resultSize != expectedSize) 1340 return op.emitError("expected result type with size = ") 1341 << expectedSize << " instead of " << resultSize 1342 << " in dim = " << en.index(); 1343 } 1344 1345 // Match offset and strides in static_offset and static_strides attributes if 1346 // result memref type has an affine map specified. 1347 if (!resultType.getAffineMaps().empty()) { 1348 int64_t resultOffset; 1349 SmallVector<int64_t, 4> resultStrides; 1350 if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset))) 1351 return failure(); 1352 1353 // Match offset in result memref type and in static_offsets attribute. 1354 int64_t expectedOffset = 1355 extractFromI64ArrayAttr(op.static_offsets()).front(); 1356 if (resultOffset != expectedOffset) 1357 return op.emitError("expected result type with offset = ") 1358 << resultOffset << " instead of " << expectedOffset; 1359 1360 // Match strides in result memref type and in static_strides attribute. 1361 for (auto &en : llvm::enumerate(llvm::zip( 1362 resultStrides, extractFromI64ArrayAttr(op.static_strides())))) { 1363 int64_t resultStride = std::get<0>(en.value()); 1364 int64_t expectedStride = std::get<1>(en.value()); 1365 if (resultStride != expectedStride) 1366 return op.emitError("expected result type with stride = ") 1367 << expectedStride << " instead of " << resultStride 1368 << " in dim = " << en.index(); 1369 } 1370 } 1371 return success(); 1372 } 1373 1374 //===----------------------------------------------------------------------===// 1375 // ReshapeOp 1376 //===----------------------------------------------------------------------===// 1377 1378 static LogicalResult verify(ReshapeOp op) { 1379 Type operandType = op.source().getType(); 1380 Type resultType = op.result().getType(); 1381 1382 Type operandElementType = operandType.cast<ShapedType>().getElementType(); 1383 Type resultElementType = resultType.cast<ShapedType>().getElementType(); 1384 if (operandElementType != resultElementType) 1385 return op.emitOpError("element types of source and destination memref " 1386 "types should be the same"); 1387 1388 if (auto operandMemRefType = operandType.dyn_cast<MemRefType>()) 1389 if (!operandMemRefType.getAffineMaps().empty()) 1390 return op.emitOpError( 1391 "source memref type should have identity affine map"); 1392 1393 int64_t shapeSize = op.shape().getType().cast<MemRefType>().getDimSize(0); 1394 auto resultMemRefType = resultType.dyn_cast<MemRefType>(); 1395 if (resultMemRefType) { 1396 if (!resultMemRefType.getAffineMaps().empty()) 1397 return op.emitOpError( 1398 "result memref type should have identity affine map"); 1399 if (shapeSize == ShapedType::kDynamicSize) 1400 return op.emitOpError("cannot use shape operand with dynamic length to " 1401 "reshape to statically-ranked memref type"); 1402 if (shapeSize != resultMemRefType.getRank()) 1403 return op.emitOpError( 1404 "length of shape operand differs from the result's memref rank"); 1405 } 1406 return success(); 1407 } 1408 1409 //===----------------------------------------------------------------------===// 1410 // StoreOp 1411 //===----------------------------------------------------------------------===// 1412 1413 static LogicalResult verify(StoreOp op) { 1414 if (op.getNumOperands() != 2 + op.getMemRefType().getRank()) 1415 
return op.emitOpError("store index operand count not equal to memref rank"); 1416 1417 return success(); 1418 } 1419 1420 LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands, 1421 SmallVectorImpl<OpFoldResult> &results) { 1422 /// store(memrefcast) -> store 1423 return foldMemRefCast(*this); 1424 } 1425 1426 //===----------------------------------------------------------------------===// 1427 // SubViewOp 1428 //===----------------------------------------------------------------------===// 1429 1430 namespace { 1431 /// Helpers to write more idiomatic operations. 1432 namespace saturated_arith { 1433 struct Wrapper { 1434 explicit Wrapper(int64_t v) : v(v) {} 1435 operator int64_t() { return v; } 1436 int64_t v; 1437 }; 1438 Wrapper operator+(Wrapper a, int64_t b) { 1439 if (ShapedType::isDynamicStrideOrOffset(a) || 1440 ShapedType::isDynamicStrideOrOffset(b)) 1441 return Wrapper(ShapedType::kDynamicStrideOrOffset); 1442 return Wrapper(a.v + b); 1443 } 1444 Wrapper operator*(Wrapper a, int64_t b) { 1445 if (ShapedType::isDynamicStrideOrOffset(a) || 1446 ShapedType::isDynamicStrideOrOffset(b)) 1447 return Wrapper(ShapedType::kDynamicStrideOrOffset); 1448 return Wrapper(a.v * b); 1449 } 1450 } // end namespace saturated_arith 1451 } // end namespace 1452 1453 /// A subview result type can be fully inferred from the source type and the 1454 /// static representation of offsets, sizes and strides. Special sentinels 1455 /// encode the dynamic case. 1456 Type SubViewOp::inferResultType(MemRefType sourceMemRefType, 1457 ArrayRef<int64_t> leadingStaticOffsets, 1458 ArrayRef<int64_t> leadingStaticSizes, 1459 ArrayRef<int64_t> leadingStaticStrides) { 1460 // A subview may specify only a leading subset of offset/sizes/strides in 1461 // which case we complete with offset=0, sizes from memref type and strides=1. 1462 unsigned rank = sourceMemRefType.getRank(); 1463 assert(leadingStaticOffsets.size() <= rank && 1464 "unexpected leadingStaticOffsets overflow"); 1465 assert(leadingStaticSizes.size() <= rank && 1466 "unexpected leadingStaticSizes overflow"); 1467 assert(leadingStaticStrides.size() <= rank && 1468 "unexpected leadingStaticStrides overflow"); 1469 auto staticOffsets = llvm::to_vector<4>(leadingStaticOffsets); 1470 auto staticSizes = llvm::to_vector<4>(leadingStaticSizes); 1471 auto staticStrides = llvm::to_vector<4>(leadingStaticStrides); 1472 unsigned numTrailingOffsets = rank - staticOffsets.size(); 1473 unsigned numTrailingSizes = rank - staticSizes.size(); 1474 unsigned numTrailingStrides = rank - staticStrides.size(); 1475 staticOffsets.append(numTrailingOffsets, 0); 1476 llvm::append_range(staticSizes, 1477 sourceMemRefType.getShape().take_back(numTrailingSizes)); 1478 staticStrides.append(numTrailingStrides, 1); 1479 1480 // Extract source offset and strides. 1481 int64_t sourceOffset; 1482 SmallVector<int64_t, 4> sourceStrides; 1483 auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset); 1484 assert(succeeded(res) && "SubViewOp expected strided memref type"); 1485 (void)res; 1486 1487 // Compute target offset whose value is: 1488 // `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`. 
  int64_t targetOffset = sourceOffset;
  for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
    auto staticOffset = std::get<0>(it), sourceStride = std::get<1>(it);
    using namespace saturated_arith;
    targetOffset = Wrapper(targetOffset) + Wrapper(staticOffset) * sourceStride;
  }

  // Compute target stride whose value is:
  //   `sourceStrides_i * staticStrides_i`.
  SmallVector<int64_t, 4> targetStrides;
  targetStrides.reserve(staticOffsets.size());
  for (auto it : llvm::zip(sourceStrides, staticStrides)) {
    auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
    using namespace saturated_arith;
    targetStrides.push_back(Wrapper(sourceStride) * staticStride);
  }

  // The type is now known.
  return MemRefType::get(
      staticSizes, sourceMemRefType.getElementType(),
      makeStridedLinearLayoutMap(targetStrides, targetOffset,
                                 sourceMemRefType.getContext()),
      sourceMemRefType.getMemorySpace());
}

Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
                                ArrayRef<OpFoldResult> leadingStaticOffsets,
                                ArrayRef<OpFoldResult> leadingStaticSizes,
                                ArrayRef<OpFoldResult> leadingStaticStrides) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
                             staticOffsets, ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
                             staticStrides, ShapedType::kDynamicStrideOrOffset);
  return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
                                    staticSizes, staticStrides)
      .cast<MemRefType>();
}

Type SubViewOp::inferRankReducedResultType(
    unsigned resultRank, MemRefType sourceRankedTensorType,
    ArrayRef<int64_t> leadingStaticOffsets,
    ArrayRef<int64_t> leadingStaticSizes,
    ArrayRef<int64_t> leadingStaticStrides) {
  auto inferredType =
      inferResultType(sourceRankedTensorType, leadingStaticOffsets,
                      leadingStaticSizes, leadingStaticStrides)
          .cast<MemRefType>();
  assert(inferredType.getRank() >= resultRank &&
         "expected inferred rank to be greater than or equal to result rank");
  int rankDiff = inferredType.getRank() - resultRank;
  if (rankDiff > 0) {
    auto shape = inferredType.getShape();
    llvm::SmallDenseSet<unsigned> dimsToProject;
    mlir::getPositionsOfShapeOne(rankDiff, shape, dimsToProject);
    SmallVector<int64_t> projectedShape;
    for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
      if (!dimsToProject.contains(pos))
        projectedShape.push_back(shape[pos]);

    AffineMap map;
    auto maps = inferredType.getAffineMaps();
    if (!maps.empty() && maps.front())
      map = getProjectedMap(maps.front(), dimsToProject);
    inferredType =
        MemRefType::get(projectedShape, inferredType.getElementType(), map,
                        inferredType.getMemorySpace());
  }
  return inferredType;
}

Type SubViewOp::inferRankReducedResultType(
    unsigned resultRank, MemRefType sourceRankedTensorType,
    ArrayRef<OpFoldResult> leadingStaticOffsets,
    ArrayRef<OpFoldResult> leadingStaticSizes,
    ArrayRef<OpFoldResult> leadingStaticStrides) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
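  // Split each OpFoldResult into its static and dynamic parts. Illustrative
  // (assumed) example: offsets {2, %d, 4} are dispatched into
  // staticOffsets = {2, kDynamicStrideOrOffset, 4} and dynamicOffsets = {%d}.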
  dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
                             staticOffsets, ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
                             staticStrides, ShapedType::kDynamicStrideOrOffset);
  return SubViewOp::inferRankReducedResultType(
      resultRank, sourceRankedTensorType, staticOffsets, staticSizes,
      staticStrides);
}

// Build a SubViewOp with mixed static and dynamic entries and custom result
// type. If the type passed is nullptr, it is inferred.
void SubViewOp::build(OpBuilder &b, OperationState &result,
                      MemRefType resultType, Value source,
                      ArrayRef<OpFoldResult> offsets,
                      ArrayRef<OpFoldResult> sizes,
                      ArrayRef<OpFoldResult> strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
                             ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
                             ShapedType::kDynamicStrideOrOffset);
  auto sourceMemRefType = source.getType().cast<MemRefType>();
  // Structuring the implementation this way avoids duplication between
  // builders.
  if (!resultType) {
    resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
                                            staticSizes, staticStrides)
                     .cast<MemRefType>();
  }
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getI64ArrayAttr(staticOffsets),
        b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
  result.addAttributes(attrs);
}

// Build a SubViewOp with mixed static and dynamic entries and inferred result
// type.
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ArrayRef<OpFoldResult> offsets,
                      ArrayRef<OpFoldResult> sizes,
                      ArrayRef<OpFoldResult> strides,
                      ArrayRef<NamedAttribute> attrs) {
  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
}

// Build a SubViewOp with static entries and inferred result type.
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
                      ArrayRef<int64_t> strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> sizeValues =
      llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
}

// Build a SubViewOp with static entries and custom result type. If the type
// passed is nullptr, it is inferred.
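// (Each static int64_t entry below is wrapped as an i64 IntegerAttr
// OpFoldResult and forwarded to the mixed static/dynamic builder above.)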
void SubViewOp::build(OpBuilder &b, OperationState &result,
                      MemRefType resultType, Value source,
                      ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
                      ArrayRef<int64_t> strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> sizeValues =
      llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
        attrs);
}

// Build a SubViewOp with dynamic entries and custom result type. If the type
// passed is nullptr, it is inferred.
void SubViewOp::build(OpBuilder &b, OperationState &result,
                      MemRefType resultType, Value source, ValueRange offsets,
                      ValueRange sizes, ValueRange strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
      llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
        attrs);
}

// Build a SubViewOp with dynamic entries and inferred result type.
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ValueRange offsets, ValueRange sizes, ValueRange strides,
                      ArrayRef<NamedAttribute> attrs) {
  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
}

/// For ViewLikeOpInterface.
Value SubViewOp::getViewSource() { return source(); }

enum SubViewVerificationResult {
  Success,
  RankTooLarge,
  SizeMismatch,
  ElemTypeMismatch,
  MemSpaceMismatch,
  AffineMapMismatch
};

/// Checks whether `originalType` can be rank-reduced to
/// `candidateReducedType`. This is a slight variant of an "is subsequence"
/// check, where every non-matching dimension must have size 1.
static SubViewVerificationResult
isRankReducedType(Type originalType, Type candidateReducedType,
                  std::string *errMsg = nullptr) {
  if (originalType == candidateReducedType)
    return SubViewVerificationResult::Success;
  if (!originalType.isa<MemRefType>())
    return SubViewVerificationResult::Success;
  if (originalType.isa<MemRefType>() && !candidateReducedType.isa<MemRefType>())
    return SubViewVerificationResult::Success;

  ShapedType originalShapedType = originalType.cast<ShapedType>();
  ShapedType candidateReducedShapedType =
      candidateReducedType.cast<ShapedType>();

  // Rank and size logic is valid for all ShapedTypes.
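  // Illustrative example (assumed types): memref<1x3x1x4xf32> can be
  // rank-reduced to memref<3x4xf32>; the size-1 dims 0 and 2 form the
  // "unused dims" mask computed below.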
  ArrayRef<int64_t> originalShape = originalShapedType.getShape();
  ArrayRef<int64_t> candidateReducedShape =
      candidateReducedShapedType.getShape();
  unsigned originalRank = originalShape.size(),
           candidateReducedRank = candidateReducedShape.size();
  if (candidateReducedRank > originalRank)
    return SubViewVerificationResult::RankTooLarge;

  auto optionalUnusedDimsMask =
      computeRankReductionMask(originalShape, candidateReducedShape);

  // Sizes cannot be matched if no rank-reduction mask could be computed.
  if (!optionalUnusedDimsMask.hasValue())
    return SubViewVerificationResult::SizeMismatch;

  if (originalShapedType.getElementType() !=
      candidateReducedShapedType.getElementType())
    return SubViewVerificationResult::ElemTypeMismatch;

  // Strided layout logic is relevant for MemRefType only.
  MemRefType original = originalType.cast<MemRefType>();
  MemRefType candidateReduced = candidateReducedType.cast<MemRefType>();
  if (original.getMemorySpace() != candidateReduced.getMemorySpace())
    return SubViewVerificationResult::MemSpaceMismatch;

  llvm::SmallDenseSet<unsigned> unusedDims = optionalUnusedDimsMask.getValue();
  auto inferredType =
      getProjectedMap(getStridedLinearLayoutMap(original), unusedDims);
  AffineMap candidateLayout;
  if (candidateReduced.getAffineMaps().empty())
    candidateLayout = getStridedLinearLayoutMap(candidateReduced);
  else
    candidateLayout = candidateReduced.getAffineMaps().front();
  assert(inferredType.getNumResults() == 1 &&
         candidateLayout.getNumResults() == 1);
  if (inferredType.getNumSymbols() != candidateLayout.getNumSymbols() ||
      inferredType.getNumDims() != candidateLayout.getNumDims()) {
    if (errMsg) {
      llvm::raw_string_ostream os(*errMsg);
      os << "inferred type: " << inferredType;
    }
    return SubViewVerificationResult::AffineMapMismatch;
  }
  // Check that the difference of the affine maps simplifies to 0.
  AffineExpr diffExpr =
      inferredType.getResult(0) - candidateLayout.getResult(0);
  diffExpr = simplifyAffineExpr(diffExpr, inferredType.getNumDims(),
                                inferredType.getNumSymbols());
  auto cst = diffExpr.dyn_cast<AffineConstantExpr>();
  if (!(cst && cst.getValue() == 0)) {
    if (errMsg) {
      llvm::raw_string_ostream os(*errMsg);
      os << "inferred type: " << inferredType;
    }
    return SubViewVerificationResult::AffineMapMismatch;
  }
  return SubViewVerificationResult::Success;
}

template <typename OpTy>
static LogicalResult produceSubViewErrorMsg(SubViewVerificationResult result,
                                            OpTy op, Type expectedType,
                                            StringRef errMsg = "") {
  auto memrefType = expectedType.cast<ShapedType>();
  switch (result) {
  case SubViewVerificationResult::Success:
    return success();
  case SubViewVerificationResult::RankTooLarge:
    return op.emitError("expected result rank to be smaller or equal to ")
           << "the source rank. " << errMsg;
  case SubViewVerificationResult::SizeMismatch:
    return op.emitError("expected result type to be ")
           << expectedType
           << " or a rank-reduced version. (mismatch of result sizes) "
           << errMsg;
  case SubViewVerificationResult::ElemTypeMismatch:
    return op.emitError("expected result element type to be ")
           << memrefType.getElementType() << errMsg;
  case SubViewVerificationResult::MemSpaceMismatch:
    return op.emitError("expected result and source memory spaces to match.")
           << errMsg;
  case SubViewVerificationResult::AffineMapMismatch:
    return op.emitError("expected result type to be ")
           << expectedType
           << " or a rank-reduced version. (mismatch of result affine map) "
           << errMsg;
  }
  llvm_unreachable("unexpected subview verification result");
}

/// Verifier for SubViewOp.
static LogicalResult verify(SubViewOp op) {
  MemRefType baseType = op.getSourceType();
  MemRefType subViewType = op.getType();

  // The base memref and the view memref should be in the same memory space.
  if (baseType.getMemorySpace() != subViewType.getMemorySpace())
    return op.emitError("different memory spaces specified for base memref "
                        "type ")
           << baseType << " and subview memref type " << subViewType;

  // Verify that the base memref type has a strided layout map.
  if (!isStrided(baseType))
    return op.emitError("base type ") << baseType << " is not strided";

  // Verify result type against inferred type.
  auto expectedType = SubViewOp::inferResultType(
      baseType, extractFromI64ArrayAttr(op.static_offsets()),
      extractFromI64ArrayAttr(op.static_sizes()),
      extractFromI64ArrayAttr(op.static_strides()));

  std::string errMsg;
  auto result = isRankReducedType(expectedType, subViewType, &errMsg);
  return produceSubViewErrorMsg(result, op, expectedType, errMsg);
}

raw_ostream &mlir::operator<<(raw_ostream &os, Range &range) {
  return os << "range " << range.offset << ":" << range.size << ":"
            << range.stride;
}

/// Return the list of Range (i.e. offset, size and stride). Each Range
/// entry contains either the dynamic value or a ConstantIndexOp constructed
/// with `b` at location `loc`.
SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
                                              OpBuilder &b, Location loc) {
  std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
  assert(ranks[0] == ranks[1] && "expected offsets and sizes of equal ranks");
  assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
  SmallVector<Range, 8> res;
  unsigned rank = ranks[0];
  res.reserve(rank);
  for (unsigned idx = 0; idx < rank; ++idx) {
    Value offset =
        op.isDynamicOffset(idx)
            ? op.getDynamicOffset(idx)
            : b.create<ConstantIndexOp>(loc, op.getStaticOffset(idx));
    Value size = op.isDynamicSize(idx)
                     ? op.getDynamicSize(idx)
                     : b.create<ConstantIndexOp>(loc, op.getStaticSize(idx));
    Value stride =
        op.isDynamicStride(idx)
            ? op.getDynamicStride(idx)
            : b.create<ConstantIndexOp>(loc, op.getStaticStride(idx));
    res.emplace_back(Range{offset, size, stride});
  }
  return res;
}

/// Infer the canonical type of the result of a subview operation. Returns the
/// rank-reduced inferred type when its rank equals `resultRank`, and the
/// non-rank-reduced inferred type otherwise.
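/// (Illustrative note: the fallback to the non-rank-reduced type triggers
/// when fewer size-1 dimensions are available to drop than the requested
/// rank reduction, an assumed corner case.)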
static MemRefType
getCanonicalSubViewResultType(unsigned resultRank, MemRefType sourceType,
                              ArrayRef<OpFoldResult> mixedOffsets,
                              ArrayRef<OpFoldResult> mixedSizes,
                              ArrayRef<OpFoldResult> mixedStrides) {
  auto resultType =
      SubViewOp::inferRankReducedResultType(
          resultRank, sourceType, mixedOffsets, mixedSizes, mixedStrides)
          .cast<MemRefType>();
  if (resultType.getRank() != resultRank) {
    resultType = SubViewOp::inferResultType(sourceType, mixedOffsets,
                                            mixedSizes, mixedStrides)
                     .cast<MemRefType>();
  }
  return resultType;
}

namespace {
/// Pattern to rewrite a subview op with MemRefCast arguments.
/// This essentially pushes memref.cast past its consuming subview when
/// `canFoldIntoConsumerOp` is true.
///
/// Example:
/// ```
///   %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
///   %1 = memref.subview %0[0, 0][3, 4][1, 1] :
///     memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]>
/// ```
/// is rewritten into:
/// ```
///   %0 = memref.subview %V : memref<16x16xf32> to memref<3x4xf32, #[[map0]]>
///   %1 = memref.cast %0 : memref<3x4xf32, offset:0, strides:[16, 1]> to
///     memref<3x4xf32, offset:?, strides:[?, 1]>
/// ```
class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
public:
  using OpRewritePattern<SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubViewOp subViewOp,
                                PatternRewriter &rewriter) const override {
    // If any operand is a constant, just return to let
    // SubViewOpConstantFolder kick in.
    if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto castOp = subViewOp.source().getDefiningOp<CastOp>();
    if (!castOp)
      return failure();

    if (!CastOp::canFoldIntoConsumerOp(castOp))
      return failure();

    /// Deduce the result type of the SubViewOp using
    /// `getCanonicalSubViewResultType` on the cast's source operand type and
    /// the SubViewOp's static information. This is the resulting type if the
    /// MemRefCastOp were folded.
    auto resultType = getCanonicalSubViewResultType(
        subViewOp.getType().getRank(),
        castOp.source().getType().cast<MemRefType>(),
        subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
        subViewOp.getMixedStrides());
    Value newSubView = rewriter.create<SubViewOp>(
        subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(),
        subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),
        subViewOp.static_sizes(), subViewOp.static_strides());
    rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
                                        newSubView);
    return success();
  }
};
} // namespace

/// Return the canonical type of the result of a subview.
struct SubViewReturnTypeCanonicalizer {
  MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
                        ArrayRef<OpFoldResult> mixedSizes,
                        ArrayRef<OpFoldResult> mixedStrides) {
    return getCanonicalSubViewResultType(op.getType().getRank(),
                                         op.getSourceType(), mixedOffsets,
                                         mixedSizes, mixedStrides);
  }
};

/// A canonicalizer wrapper to replace SubViewOps.
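/// The constant-folded subview generally has a more static type than the
/// original op, so the wrapper casts it back to the original result type to
/// keep existing uses valid.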
struct SubViewCanonicalizer {
  void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
    rewriter.replaceOpWithNewOp<CastOp>(op, newOp, op.getType());
  }
};

void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results
      .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
               SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
           SubViewOpMemRefCastFolder>(context);
}

OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
  auto resultShapedType = getResult().getType().cast<ShapedType>();
  auto sourceShapedType = source().getType().cast<ShapedType>();

  if (resultShapedType.hasStaticShape() &&
      resultShapedType == sourceShapedType) {
    return getViewSource();
  }

  return {};
}

//===----------------------------------------------------------------------===//
// TensorLoadOp
//===----------------------------------------------------------------------===//

OpFoldResult TensorLoadOp::fold(ArrayRef<Attribute>) {
  if (auto bufferCast = memref().getDefiningOp<BufferCastOp>())
    // Approximate alias analysis by conservatively folding only when there is
    // no interleaved operation.
    if (bufferCast->getBlock() == this->getOperation()->getBlock() &&
        bufferCast->getNextNode() == this->getOperation())
      return bufferCast.tensor();
  return {};
}

//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//

/// Build a strided memref type by applying `permutationMap` to `memRefType`.
static MemRefType inferTransposeResultType(MemRefType memRefType,
                                           AffineMap permutationMap) {
  auto rank = memRefType.getRank();
  auto originalSizes = memRefType.getShape();
  // Compute permuted sizes.
  SmallVector<int64_t, 4> sizes(rank, 0);
  for (auto en : llvm::enumerate(permutationMap.getResults()))
    sizes[en.index()] =
        originalSizes[en.value().cast<AffineDimExpr>().getPosition()];

  // Compute permuted strides.
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto res = getStridesAndOffset(memRefType, strides, offset);
  assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank));
  (void)res;
  auto map =
      makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
  map = permutationMap ? map.compose(permutationMap) : map;
  return MemRefType::Builder(memRefType).setShape(sizes).setAffineMaps(map);
}

void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
                        AffineMapAttr permutation,
                        ArrayRef<NamedAttribute> attrs) {
  auto permutationMap = permutation.getValue();
  assert(permutationMap);

  auto memRefType = in.getType().cast<MemRefType>();
  // Compute result type.
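  // Illustrative example (assumed operands): transposing memref<2x3xf32> with
  // permutation (d0, d1) -> (d1, d0) yields shape 3x2 with strides [1, 3],
  // i.e. memref<3x2xf32, affine_map<(d0, d1) -> (d0 + d1 * 3)>>.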
  MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);

  build(b, result, resultType, in, attrs);
  result.addAttribute(TransposeOp::getPermutationAttrName(), permutation);
}

// transpose $in $permutation attr-dict : type($in) `to` type(results)
static void print(OpAsmPrinter &p, TransposeOp op) {
  p << "memref.transpose " << op.in() << " " << op.permutation();
  p.printOptionalAttrDict(op->getAttrs(),
                          {TransposeOp::getPermutationAttrName()});
  p << " : " << op.in().getType() << " to " << op.getType();
}

static ParseResult parseTransposeOp(OpAsmParser &parser,
                                    OperationState &result) {
  OpAsmParser::OperandType in;
  AffineMap permutation;
  MemRefType srcType, dstType;
  if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(srcType) ||
      parser.resolveOperand(in, srcType, result.operands) ||
      parser.parseKeywordType("to", dstType) ||
      parser.addTypeToList(dstType, result.types))
    return failure();

  result.addAttribute(TransposeOp::getPermutationAttrName(),
                      AffineMapAttr::get(permutation));
  return success();
}

static LogicalResult verify(TransposeOp op) {
  if (!op.permutation().isPermutation())
    return op.emitOpError("expected a permutation map");
  if (op.permutation().getNumDims() != op.getShapedType().getRank())
    return op.emitOpError(
        "expected a permutation map of same rank as the input");

  auto srcType = op.in().getType().cast<MemRefType>();
  auto dstType = op.getType().cast<MemRefType>();
  auto transposedType = inferTransposeResultType(srcType, op.permutation());
  if (dstType != transposedType)
    return op.emitOpError("output type ")
           << dstType << " does not match transposed input type " << srcType
           << ", " << transposedType;
  return success();
}

OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  return {};
}

//===----------------------------------------------------------------------===//
// ViewOp
//===----------------------------------------------------------------------===//

static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::OperandType srcInfo;
  SmallVector<OpAsmParser::OperandType, 1> offsetInfo;
  SmallVector<OpAsmParser::OperandType, 4> sizesInfo;
  auto indexType = parser.getBuilder().getIndexType();
  Type srcType, dstType;
  llvm::SMLoc offsetLoc;
  if (parser.parseOperand(srcInfo) || parser.getCurrentLocation(&offsetLoc) ||
      parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square))
    return failure();

  if (offsetInfo.size() != 1)
    return parser.emitError(offsetLoc) << "expects 1 offset operand";

  return failure(
      parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(srcType) ||
      parser.resolveOperand(srcInfo, srcType, result.operands) ||
      parser.resolveOperands(offsetInfo, indexType, result.operands) ||
      parser.resolveOperands(sizesInfo, indexType, result.operands) ||
      parser.parseKeywordType("to", dstType) ||
      parser.addTypeToList(dstType, result.types));
}

static void print(OpAsmPrinter &p, ViewOp op) {
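  // Printed form (illustrative, assumed types):
  //   memref.view %src[%byte_shift][%size] :
  //     memref<2048xi8> to memref<?x4xf32>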
  p << op.getOperationName() << ' ' << op.getOperand(0) << '[';
  p.printOperand(op.byte_shift());
  p << "][" << op.sizes() << ']';
  p.printOptionalAttrDict(op->getAttrs());
  p << " : " << op.getOperand(0).getType() << " to " << op.getType();
}

static LogicalResult verify(ViewOp op) {
  auto baseType = op.getOperand(0).getType().cast<MemRefType>();
  auto viewType = op.getType();

  // The base memref should have an identity layout map (or none).
  if (baseType.getAffineMaps().size() > 1 ||
      (baseType.getAffineMaps().size() == 1 &&
       !baseType.getAffineMaps()[0].isIdentity()))
    return op.emitError("unsupported map for base memref type ") << baseType;

  // The result memref should have an identity layout map (or none).
  if (viewType.getAffineMaps().size() > 1 ||
      (viewType.getAffineMaps().size() == 1 &&
       !viewType.getAffineMaps()[0].isIdentity()))
    return op.emitError("unsupported map for result memref type ") << viewType;

  // The base memref and the view memref should be in the same memory space.
  if (baseType.getMemorySpace() != viewType.getMemorySpace())
    return op.emitError("different memory spaces specified for base memref "
                        "type ")
           << baseType << " and view memref type " << viewType;

  // Verify that we have the correct number of sizes for the result type.
  unsigned numDynamicDims = viewType.getNumDynamicDims();
  if (op.sizes().size() != numDynamicDims)
    return op.emitError("incorrect number of size operands for type ")
           << viewType;

  return success();
}

Value ViewOp::getViewSource() { return source(); }

namespace {

struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
  using OpRewritePattern<ViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    // Return if none of the operands are constants.
    if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    // Get result memref type.
    auto memrefType = viewOp.getType();

    // Get offset from old memref view type 'memrefType'.
    int64_t oldOffset;
    SmallVector<int64_t, 4> oldStrides;
    if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
      return failure();
    assert(oldOffset == 0 && "Expected 0 offset");

    SmallVector<Value, 4> newOperands;

    // Offset cannot be folded into result type.

    // Fold any dynamic dim operands which are produced by a constant.
    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());

    unsigned dynamicDimPos = 0;
    unsigned rank = memrefType.getRank();
    for (unsigned dim = 0, e = rank; dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already a static dimension, keep it.
      if (!ShapedType::isDynamic(dimSize)) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto *defOp = viewOp.sizes()[dynamicDimPos].getDefiningOp();
      if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
        // Dynamic shape dimension will be folded.
        newShapeConstants.push_back(constantIndexOp.getValue());
      } else {
        // Dynamic shape dimension not folded; copy the operand from the old
        // memref.
        newShapeConstants.push_back(dimSize);
        newOperands.push_back(viewOp.sizes()[dynamicDimPos]);
      }
      dynamicDimPos++;
    }

    // Create new memref type with constant folded dims.
    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    // Nothing new, don't fold.
    if (newMemRefType == memrefType)
      return failure();

    // Create new ViewOp.
    auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType,
                                             viewOp.getOperand(0),
                                             viewOp.byte_shift(), newOperands);
    // Insert a cast so we have the same type as the old memref type.
    rewriter.replaceOpWithNewOp<CastOp>(viewOp, newViewOp, viewOp.getType());
    return success();
  }
};

struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
  using OpRewritePattern<ViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    Value memrefOperand = viewOp.getOperand(0);
    CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
    if (!memrefCastOp)
      return failure();
    Value allocOperand = memrefCastOp.getOperand();
    AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
    if (!allocOp)
      return failure();
    rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
                                        viewOp.byte_shift(), viewOp.sizes());
    return success();
  }
};

} // end anonymous namespace

void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"