//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;
using namespace mlir::memref;

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  return builder.create<mlir::ConstantOp>(loc, type, value);
}

/// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
  return llvm::to_vector<4>(
      llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
        return a.cast<IntegerAttr>().getInt();
      }));
}

/// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if
/// it is a Value or into `staticVec` if it is an IntegerAttr.
/// In the case of a Value, a copy of the `sentinel` value is also pushed to
/// `staticVec`. This is useful to extract mixed static and dynamic entries
/// that come from an AttrSizedOperandSegments trait.
static void dispatchIndexOpFoldResult(OpFoldResult ofr,
                                      SmallVectorImpl<Value> &dynamicVec,
                                      SmallVectorImpl<int64_t> &staticVec,
                                      int64_t sentinel) {
  if (auto v = ofr.dyn_cast<Value>()) {
    dynamicVec.push_back(v);
    staticVec.push_back(sentinel);
    return;
  }
  APInt apInt = ofr.dyn_cast<Attribute>().cast<IntegerAttr>().getValue();
  staticVec.push_back(apInt.getSExtValue());
}

static void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
                                       SmallVectorImpl<Value> &dynamicVec,
                                       SmallVectorImpl<int64_t> &staticVec,
                                       int64_t sentinel) {
  for (auto ofr : ofrs)
    dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec, sentinel);
}

//===----------------------------------------------------------------------===//
// Common canonicalization pattern support logic
//===----------------------------------------------------------------------===//

/// This is a common utility used for patterns of the form
/// "someop(memrefcast) -> someop". It folds the source of any memref.cast
/// into the root operation directly.
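///
/// Illustrative sketch: for IR such as
///   %0 = memref.cast %arg : memref<8xf32> to memref<?xf32>
///   memref.store %v, %0[%i] : memref<?xf32>
/// the store is rewritten to use %arg directly, dropping the cast when it
/// only adds dynamic information.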
static LogicalResult foldMemRefCast(Operation *op) {
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto cast = operand.get().getDefiningOp<CastOp>();
    if (cast && !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
      operand.set(cast.getOperand());
      folded = true;
    }
  }
  return success(folded);
}

//===----------------------------------------------------------------------===//
// Helpers for GlobalOp
//===----------------------------------------------------------------------===//

static Type getTensorTypeFromMemRefType(Type type) {
  if (auto memref = type.dyn_cast<MemRefType>())
    return RankedTensorType::get(memref.getShape(), memref.getElementType());
  if (auto memref = type.dyn_cast<UnrankedMemRefType>())
    return UnrankedTensorType::get(memref.getElementType());
  return NoneType::get(type.getContext());
}

//===----------------------------------------------------------------------===//
// AllocOp / AllocaOp
//===----------------------------------------------------------------------===//

template <typename AllocLikeOp>
static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
  static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
                "applies to only alloc or alloca");
  auto memRefType = op.getResult().getType().template dyn_cast<MemRefType>();
  if (!memRefType)
    return op.emitOpError("result must be a memref");

  if (static_cast<int64_t>(op.dynamicSizes().size()) !=
      memRefType.getNumDynamicDims())
    return op.emitOpError("dimension operand count does not equal memref "
                          "dynamic dimension count");

  unsigned numSymbols = 0;
  if (!memRefType.getAffineMaps().empty())
    numSymbols = memRefType.getAffineMaps().front().getNumSymbols();
  if (op.symbolOperands().size() != numSymbols)
    return op.emitOpError(
        "symbol operand count does not equal memref symbol count");

  return success();
}

static LogicalResult verify(AllocOp op) { return verifyAllocLikeOp(op); }

static LogicalResult verify(AllocaOp op) {
  // An alloca op needs to have an ancestor with an allocation scope trait.
  if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
    return op.emitOpError(
        "requires an ancestor op with AutomaticAllocationScope trait");

  return verifyAllocLikeOp(op);
}

namespace {
/// Fold constant dimensions into an alloc-like operation.
template <typename AllocLikeOp>
struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
  using OpRewritePattern<AllocLikeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocLikeOp alloc,
                                PatternRewriter &rewriter) const override {
    // Check to see if any dimension operands are constants. If so, we can
    // substitute and drop them.
    if (llvm::none_of(alloc.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto memrefType = alloc.getType();

    // Ok, we have one or more constant operands. Collect the non-constant ones
    // and keep track of the resultant memref type to build.
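    // Illustrative sketch of the rewrite (assuming %c4 is a constant index):
    //   %0 = memref.alloc(%c4, %n) : memref<?x?xf32>
    // becomes
    //   %1 = memref.alloc(%n) : memref<4x?xf32>
    //   %0 = memref.cast %1 : memref<4x?xf32> to memref<?x?xf32>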
    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());
    SmallVector<Value, 4> newOperands;

    unsigned dynamicDimPos = 0;
    for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already a static dimension, keep it.
      if (dimSize != -1) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto *defOp = alloc.getOperand(dynamicDimPos).getDefiningOp();
      if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
        // Dynamic shape dimension will be folded.
        newShapeConstants.push_back(constantIndexOp.getValue());
      } else {
        // Dynamic shape dimension not folded; copy operand from old memref.
        newShapeConstants.push_back(-1);
        newOperands.push_back(alloc.getOperand(dynamicDimPos));
      }
      dynamicDimPos++;
    }

    // Create new memref type (which will have fewer dynamic dimensions).
    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    assert(static_cast<int64_t>(newOperands.size()) ==
           newMemRefType.getNumDynamicDims());

    // Create and insert the alloc op for the new memref.
    auto newAlloc = rewriter.create<AllocLikeOp>(
        alloc.getLoc(), newMemRefType, newOperands, alloc.alignmentAttr());
    // Insert a cast so we have the same type as the old alloc.
    auto resultCast =
        rewriter.create<CastOp>(alloc.getLoc(), newAlloc, alloc.getType());

    rewriter.replaceOp(alloc, {resultCast});
    return success();
  }
};

/// Fold alloc operations with no users or only store and dealloc uses.
template <typename T>
struct SimplifyDeadAlloc : public OpRewritePattern<T> {
  using OpRewritePattern<T>::OpRewritePattern;

  LogicalResult matchAndRewrite(T alloc,
                                PatternRewriter &rewriter) const override {
    if (llvm::any_of(alloc->getUsers(), [](Operation *op) {
          return !isa<StoreOp, DeallocOp>(op);
        }))
      return failure();

    for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
      rewriter.eraseOp(user);

    rewriter.eraseOp(alloc);
    return success();
  }
};
} // end anonymous namespace.
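
// Illustrative sketch of the two patterns registered below: SimplifyAllocConst
// folds constant sizes into the alloc's type (inserting a cast back to the
// original type), and SimplifyDeadAlloc erases an alloc whose only users are
// stores and deallocs, together with those users.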

void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
}

void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
      context);
}

//===----------------------------------------------------------------------===//
// AssumeAlignmentOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(AssumeAlignmentOp op) {
  unsigned alignment = op.alignment();
  if (!llvm::isPowerOf2_32(alignment))
    return op.emitOpError("alignment must be power of 2");
  return success();
}

//===----------------------------------------------------------------------===//
// BufferCastOp
//===----------------------------------------------------------------------===//

OpFoldResult BufferCastOp::fold(ArrayRef<Attribute>) {
  if (auto tensorLoad = tensor().getDefiningOp<TensorLoadOp>())
    if (tensorLoad.memref().getType() == getType())
      return tensorLoad.memref();
  return {};
}

namespace {
/// Replace tensor_cast + buffer_cast by buffer_cast + memref_cast.
struct BufferCast : public OpRewritePattern<BufferCastOp> {
  using OpRewritePattern<BufferCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BufferCastOp bufferCast,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand =
        bufferCast.getOperand().getDefiningOp<tensor::CastOp>();
    if (!tensorCastOperand)
      return failure();
    auto srcTensorType =
        tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
    if (!srcTensorType)
      return failure();
    auto memrefType = MemRefType::get(srcTensorType.getShape(),
                                      srcTensorType.getElementType());
    Value memref = rewriter.create<BufferCastOp>(
        bufferCast.getLoc(), memrefType, tensorCastOperand.getOperand());
    rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(),
                                        memref);
    return success();
  }
};

/// Canonicalize memref.tensor_load + memref.buffer_cast to memref.cast when
/// a type mismatch prevents `BufferCastOp::fold` from kicking in.
struct TensorLoadToMemRef : public OpRewritePattern<BufferCastOp> {
  using OpRewritePattern<BufferCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BufferCastOp bufferCast,
                                PatternRewriter &rewriter) const final {
    auto tensorLoad = bufferCast.tensor().getDefiningOp<TensorLoadOp>();
    // Bail unless we have a tensor_load + memref.buffer_cast with different
    // types. `BufferCastOp::fold` handles the same-type case.
    if (!tensorLoad || tensorLoad.memref().getType() == bufferCast.getType())
      return failure();
    // If the types are not cast-compatible, bail.
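    // (E.g. when the original memref and the buffer_cast result differ only
    // in layout or in static vs. dynamic sizes, the pair folds to a single
    // memref.cast below.)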
    if (!CastOp::areCastCompatible(tensorLoad.memref().getType(),
                                   bufferCast.getType()))
      return failure();
    rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(),
                                        tensorLoad.memref());
    return success();
  }
};

} // namespace

void BufferCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<BufferCast, TensorLoadToMemRef>(context);
}

//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//

/// Determines whether MemRef_CastOp casts to a more dynamic version of the
/// source memref. This is useful to fold a memref.cast into a consuming op
/// and implement canonicalization patterns for ops in different dialects that
/// may consume the results of memref.cast operations. Such foldable
/// memref.cast operations are typically inserted as `view` and `subview` ops
/// are canonicalized, to preserve the type compatibility of their uses.
///
/// Returns true when all conditions are met:
/// 1. source and result are ranked memrefs with strided semantics and same
///    element type and rank.
/// 2. each of the source's size, offset or stride has more static information
///    than the corresponding result's size, offset or stride.
///
/// Example 1:
/// ```mlir
///   %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
///   %2 = consumer %1 ... : memref<?x?xf32> ...
/// ```
///
/// may fold into:
///
/// ```mlir
///   %2 = consumer %0 ... : memref<8x16xf32> ...
/// ```
///
/// Example 2:
/// ```
///   %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
///          to memref<?x?xf32>
///   consumer %1 : memref<?x?xf32> ...
/// ```
///
/// may fold into:
///
/// ```
///   consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
/// ```
bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
  MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>();
  MemRefType resultType = castOp.getType().dyn_cast<MemRefType>();

  // Requires ranked MemRefType.
  if (!sourceType || !resultType)
    return false;

  // Requires same elemental type.
  if (sourceType.getElementType() != resultType.getElementType())
    return false;

  // Requires same rank.
  if (sourceType.getRank() != resultType.getRank())
    return false;

  // Only fold casts between strided memref forms.
  int64_t sourceOffset, resultOffset;
  SmallVector<int64_t, 4> sourceStrides, resultStrides;
  if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) ||
      failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
    return false;

  // If cast is towards more static sizes along any dimension, don't fold.
  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
    auto ss = std::get<0>(it), st = std::get<1>(it);
    if (ss != st)
      if (MemRefType::isDynamic(ss) && !MemRefType::isDynamic(st))
        return false;
  }

  // If the cast is towards a more static offset, don't fold.
  if (sourceOffset != resultOffset)
    if (MemRefType::isDynamicStrideOrOffset(sourceOffset) &&
        !MemRefType::isDynamicStrideOrOffset(resultOffset))
      return false;

  // If cast is towards more static strides along any dimension, don't fold.
  for (auto it : llvm::zip(sourceStrides, resultStrides)) {
    auto ss = std::get<0>(it), st = std::get<1>(it);
    if (ss != st)
      if (MemRefType::isDynamicStrideOrOffset(ss) &&
          !MemRefType::isDynamicStrideOrOffset(st))
        return false;
  }

  return true;
}

bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = a.dyn_cast<MemRefType>();
  auto bT = b.dyn_cast<MemRefType>();

  auto uaT = a.dyn_cast<UnrankedMemRefType>();
  auto ubT = b.dyn_cast<UnrankedMemRefType>();

  if (aT && bT) {
    if (aT.getElementType() != bT.getElementType())
      return false;
    if (aT.getAffineMaps() != bT.getAffineMaps()) {
      int64_t aOffset, bOffset;
      SmallVector<int64_t, 4> aStrides, bStrides;
      if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
          failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
          aStrides.size() != bStrides.size())
        return false;

      // Strides along a dimension/offset are compatible if the value in the
      // source memref is static and the value in the target memref is the
      // same. They are also compatible if either one is dynamic (see
      // description of MemRefCastOp for details).
      auto checkCompatible = [](int64_t a, int64_t b) {
        return (a == MemRefType::getDynamicStrideOrOffset() ||
                b == MemRefType::getDynamicStrideOrOffset() || a == b);
      };
      if (!checkCompatible(aOffset, bOffset))
        return false;
      for (auto aStride : enumerate(aStrides))
        if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
          return false;
    }
    if (aT.getMemorySpace() != bT.getMemorySpace())
      return false;

    // They must have the same rank, and any specified dimensions must match.
    if (aT.getRank() != bT.getRank())
      return false;

    for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
      int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
      if (aDim != -1 && bDim != -1 && aDim != bDim)
        return false;
    }
    return true;
  } else {
    if (!aT && !uaT)
      return false;
    if (!bT && !ubT)
      return false;
    // Unranked to unranked casting is unsupported.
    if (uaT && ubT)
      return false;

    auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
    auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
    if (aEltType != bEltType)
      return false;

    auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
    auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
    if (aMemSpace != bMemSpace)
      return false;

    return true;
  }

  return false;
}

OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
  return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
}

//===----------------------------------------------------------------------===//
// CloneOp
//===----------------------------------------------------------------------===//

void CloneOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  effects.emplace_back(MemoryEffects::Read::get(), input(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Write::get(), output(),
                       SideEffects::DefaultResource::get());
}

namespace {
/// Merge the clone and its source (by converting the clone to a cast) when
/// possible, erasing the redundant dealloc.
struct SimplifyClones : public OpRewritePattern<CloneOp> {
  using OpRewritePattern<CloneOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CloneOp cloneOp,
                                PatternRewriter &rewriter) const override {
    if (cloneOp.use_empty()) {
      rewriter.eraseOp(cloneOp);
      return success();
    }

    Value source = cloneOp.input();

    // This only finds dealloc operations for the immediate value. It should
    // also consider aliases. That would also make the safety check below
    // redundant.
    Operation *cloneDeallocOp = findDealloc(cloneOp.output());
    Operation *sourceDeallocOp = findDealloc(source);

    // If both are deallocated in the same block, their in-block lifetimes
    // might not fully overlap, so we cannot decide which one to drop.
    if (cloneDeallocOp && sourceDeallocOp &&
        cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
      return failure();

    Block *currentBlock = cloneOp->getBlock();
    Operation *redundantDealloc = nullptr;
    if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = cloneDeallocOp;
    } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = sourceDeallocOp;
    }

    if (!redundantDealloc)
      return failure();

    // Safety check that there are no other deallocations in between cloneOp
    // and redundantDealloc, as otherwise we might deallocate an alias of
    // source before the uses of the clone. With alias information, we could
    // restrict this to only fail if the dealloc's operand is an alias of the
    // source.
    for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
         pos = pos->getNextNode()) {
      auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
      if (!effectInterface)
        continue;
      if (effectInterface.hasEffect<MemoryEffects::Free>())
        return failure();
    }

    rewriter.replaceOpWithNewOp<memref::CastOp>(cloneOp, cloneOp.getType(),
                                                source);
    rewriter.eraseOp(redundantDealloc);
    return success();
  }
};

} // end anonymous namespace.

void CloneOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                          MLIRContext *context) {
  results.insert<SimplifyClones>(context);
}

OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) {
  return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
}

//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//

LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands,
                              SmallVectorImpl<OpFoldResult> &results) {
  /// dealloc(memrefcast) -> dealloc
  return foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// DimOp
//===----------------------------------------------------------------------===//

void DimOp::build(OpBuilder &builder, OperationState &result, Value memref,
                  int64_t index) {
  auto loc = result.location;
  Value indexValue = builder.create<ConstantIndexOp>(loc, index);
  build(builder, result, memref, indexValue);
}

void DimOp::build(OpBuilder &builder, OperationState &result, Value memref,
                  Value index) {
  auto indexTy = builder.getIndexType();
  build(builder, result, indexTy, memref, index);
}

Optional<int64_t> DimOp::getConstantIndex() {
  if (auto constantOp = index().getDefiningOp<ConstantOp>())
    return constantOp.getValue().cast<IntegerAttr>().getInt();
  return {};
}

static LogicalResult verify(DimOp op) {
  // Assume unknown index to be in range.
  Optional<int64_t> index = op.getConstantIndex();
  if (!index.hasValue())
    return success();

  // Check that constant index is not knowingly out of range.
  auto type = op.memrefOrTensor().getType();
  if (auto memrefType = type.dyn_cast<MemRefType>()) {
    if (index.getValue() >= memrefType.getRank())
      return op.emitOpError("index is out of range");
  } else if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
    if (index.getValue() >= tensorType.getRank())
      return op.emitOpError("index is out of range");
  } else if (type.isa<UnrankedMemRefType>() || type.isa<UnrankedTensorType>()) {
    // Assume index to be in range.
  } else {
    llvm_unreachable("expected operand with memref type");
  }
  return success();
}

OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
  auto index = operands[1].dyn_cast_or_null<IntegerAttr>();

  // All forms of folding require a known index.
  if (!index)
    return {};

  auto argTy = memrefOrTensor().getType();
  // Fold if the shape extent along the given index is known.
  if (auto shapedTy = argTy.dyn_cast<ShapedType>()) {
    // Folding for unranked types (UnrankedMemRefType) is not supported.
    if (!shapedTy.hasRank())
      return {};
    if (!shapedTy.isDynamicDim(index.getInt())) {
      Builder builder(getContext());
      return builder.getIndexAttr(shapedTy.getShape()[index.getInt()]);
    }
  }

  Operation *definingOp = memrefOrTensor().getDefiningOp();

  // dim(memref.tensor_load(memref)) -> dim(memref)
  if (auto tensorLoadOp = dyn_cast_or_null<TensorLoadOp>(definingOp)) {
    setOperand(0, tensorLoadOp.memref());
    return getResult();
  }

  // Fold dim to the operand of tensor.generate.
  if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
    auto resultType =
        fromElements.getResult().getType().cast<RankedTensorType>();
    // The case where the type encodes the size of the dimension is handled
    // above.
    assert(resultType.getShape()[index.getInt()] ==
           RankedTensorType::kDynamicSize);

    // Find the operand of the fromElements that corresponds to this index.
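    // E.g. for a result type tensor<4x?x?xf32> and a constant index of 2, one
    // dynamic dimension precedes the queried one, so the iterator below stops
    // at the second dynamic extent operand.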
    auto dynExtents = fromElements.dynamicExtents().begin();
    for (auto dim : resultType.getShape().take_front(index.getInt()))
      if (dim == RankedTensorType::kDynamicSize)
        dynExtents++;

    return Value{*dynExtents};
  }

  // The size at the given index is now known to be a dynamic size.
  unsigned unsignedIndex = index.getValue().getZExtValue();

  if (auto subtensor = dyn_cast_or_null<mlir::SubTensorOp>(definingOp)) {
    assert(subtensor.isDynamicSize(unsignedIndex) &&
           "Expected dynamic subtensor size");
    return subtensor.getDynamicSize(unsignedIndex);
  }

  // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
  auto memrefType = argTy.dyn_cast<MemRefType>();
  if (!memrefType)
    return {};

  if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
    return *(alloc.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
    return *(alloca.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
    return *(view.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto sizeInterface =
          dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
    assert(sizeInterface.isDynamicSize(unsignedIndex) &&
           "Expected dynamic subview size");
    return sizeInterface.getDynamicSize(unsignedIndex);
  }

  // dim(memrefcast) -> dim
  if (succeeded(foldMemRefCast(*this)))
    return getResult();

  return {};
}

namespace {
/// Fold dim of a memref reshape operation to a load into the reshape's shape
/// operand.
struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dim,
                                PatternRewriter &rewriter) const override {
    auto reshape = dim.memrefOrTensor().getDefiningOp<ReshapeOp>();

    if (!reshape)
      return failure();

    // Place the load directly after the reshape to ensure that the shape
    // memref was not mutated.
    rewriter.setInsertionPointAfter(reshape);
    rewriter.replaceOpWithNewOp<LoadOp>(dim, reshape.shape(),
                                        llvm::makeArrayRef({dim.index()}));
    return success();
  }
};

/// Fold dim of a cast into the dim of the source of the cast.
template <typename CastOpTy>
struct DimOfCastOp : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.memrefOrTensor().getDefiningOp<CastOpTy>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.index());
    return success();
  }
};

/// Helper method to get the `Value` that is the shape of the `resultIdx`-th
/// result at dimension `dimIndex` from the `ShapedTypeOpInterface`.
/// TODO(ravishankarm): This is better put as an interface utility method
/// somewhere, but that would imply the interface will depend on the `tensor`
/// dialect. Ideally maybe a utility method in the `tensor` dialect.
static Value getResultDimFromShapeInterface(OpBuilder &builder, OpResult result,
                                            int64_t dimIndex) {
  unsigned resultNumber = result.getResultNumber();
  auto shapedTypeOp = dyn_cast<InferShapedTypeOpInterface>(result.getOwner());
  Location loc = result.getOwner()->getLoc();
  if (!shapedTypeOp)
    return nullptr;

  // The interface exposes two methods, one that returns the shape of all the
  // results as `Value` and another that returns the shape as a list of
  // `SmallVector<Value>`. The former takes precedence over the latter. So
  // first check if the op implements the first interface method or the
  // second, and get the value to use appropriately.
  SmallVector<Value> reifiedResultShapes;
  if (succeeded(shapedTypeOp.reifyReturnTypeShapes(
          builder, result.getOwner()->getOperands(), reifiedResultShapes))) {
    if (reifiedResultShapes.size() <= resultNumber)
      return nullptr;
    Value resultShape = reifiedResultShapes[resultNumber];
    auto resultShapeType = resultShape.getType().dyn_cast<RankedTensorType>();
    if (!resultShapeType || !resultShapeType.getElementType().isa<IndexType>())
      return nullptr;
    return builder.create<tensor::ExtractOp>(
        loc, resultShape, builder.createOrFold<ConstantIndexOp>(loc, dimIndex));
  }

  SmallVector<SmallVector<Value>> reifiedResultShapesPerDim;
  if (failed(shapedTypeOp.reifyReturnTypeShapesPerResultDim(
          builder, reifiedResultShapesPerDim)))
    return nullptr;
  if (reifiedResultShapesPerDim.size() <= resultNumber ||
      reifiedResultShapesPerDim[resultNumber].size() !=
          static_cast<size_t>(result.getType().cast<ShapedType>().getRank()))
    return nullptr;
  OpFoldResult valueOrAttr = reifiedResultShapesPerDim[resultNumber][dimIndex];
  if (auto attr = valueOrAttr.dyn_cast<Attribute>())
    return builder.createOrFold<ConstantIndexOp>(
        loc, attr.cast<IntegerAttr>().getInt());
  return valueOrAttr.get<Value>();
}

/// Fold dim of an operation that implements the InferShapedTypeOpInterface.
struct DimOfShapedTypeOpInterface : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    OpResult dimValue = dimOp.memrefOrTensor().dyn_cast<OpResult>();
    if (!dimValue)
      return failure();
    auto shapedTypeOp =
        dyn_cast<InferShapedTypeOpInterface>(dimValue.getOwner());
    if (!shapedTypeOp)
      return failure();

    Optional<int64_t> dimIndex = dimOp.getConstantIndex();
    if (!dimIndex)
      return failure();
    Value replacement =
        getResultDimFromShapeInterface(rewriter, dimValue, *dimIndex);
    if (!replacement)
      return failure();
    rewriter.replaceOp(dimOp, replacement);
    return success();
  }
};
} // end anonymous namespace.
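
// Illustrative sketch of DimOfMemRefReshape (types abbreviated): a memref.dim
// of the result of `%r = memref.reshape %src(%shape)` becomes a memref.load
// of `%shape` at the queried index, placed right after the reshape.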

void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<DimOfMemRefReshape, DimOfCastOp<BufferCastOp>,
              DimOfCastOp<tensor::CastOp>, DimOfShapedTypeOpInterface>(context);
}

// ---------------------------------------------------------------------------
// DmaStartOp
// ---------------------------------------------------------------------------

void DmaStartOp::build(OpBuilder &builder, OperationState &result,
                       Value srcMemRef, ValueRange srcIndices, Value destMemRef,
                       ValueRange destIndices, Value numElements,
                       Value tagMemRef, ValueRange tagIndices, Value stride,
                       Value elementsPerStride) {
  result.addOperands(srcMemRef);
  result.addOperands(srcIndices);
  result.addOperands(destMemRef);
  result.addOperands(destIndices);
  result.addOperands({numElements, tagMemRef});
  result.addOperands(tagIndices);
  if (stride)
    result.addOperands({stride, elementsPerStride});
}

void DmaStartOp::print(OpAsmPrinter &p) {
  p << getOperationName() << " " << getSrcMemRef() << '[' << getSrcIndices()
    << "], " << getDstMemRef() << '[' << getDstIndices() << "], "
    << getNumElements() << ", " << getTagMemRef() << '[' << getTagIndices()
    << ']';
  if (isStrided())
    p << ", " << getStride() << ", " << getNumElementsPerStride();

  p.printOptionalAttrDict((*this)->getAttrs());
  p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
    << ", " << getTagMemRef().getType();
}

// Parse DmaStartOp.
// Ex:
//   %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
//                       %tag[%index], %stride, %num_elt_per_stride
//                     : memref<3076 x f32, 0>,
//                       memref<1024 x f32, 2>,
//                       memref<1 x i32>
//
ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::OperandType srcMemRefInfo;
  SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos;
  OpAsmParser::OperandType dstMemRefInfo;
  SmallVector<OpAsmParser::OperandType, 4> dstIndexInfos;
  OpAsmParser::OperandType numElementsInfo;
  OpAsmParser::OperandType tagMemrefInfo;
  SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos;
  SmallVector<OpAsmParser::OperandType, 2> strideInfo;

  SmallVector<Type, 3> types;
  auto indexType = parser.getBuilder().getIndexType();

  // Parse and resolve the following list of operands:
  // *) source memref followed by its indices (in square brackets).
  // *) destination memref followed by its indices (in square brackets).
  // *) dma size in KiB.
  if (parser.parseOperand(srcMemRefInfo) ||
      parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
      parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(numElementsInfo) ||
      parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
      parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
    return failure();

  // Parse optional stride and elements per stride.
  if (parser.parseTrailingOperandList(strideInfo))
    return failure();

  bool isStrided = strideInfo.size() == 2;
  if (!strideInfo.empty() && !isStrided) {
    return parser.emitError(parser.getNameLoc(),
                            "expected two stride related operands");
  }

  if (parser.parseColonTypeList(types))
    return failure();
  if (types.size() != 3)
    return parser.emitError(parser.getNameLoc(), "fewer/more types expected");

  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
      parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
      parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
      parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
      // size should be an index.
      parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
      parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
      // tag indices should be index.
      parser.resolveOperands(tagIndexInfos, indexType, result.operands))
    return failure();

  if (isStrided) {
    if (parser.resolveOperands(strideInfo, indexType, result.operands))
      return failure();
  }

  return success();
}

LogicalResult DmaStartOp::verify() {
  unsigned numOperands = getNumOperands();

  // Mandatory non-variadic operands are: src memref, dst memref, tag memref
  // and the number of elements.
  if (numOperands < 4)
    return emitOpError("expected at least 4 operands");

  // Check types of operands. The order of these calls is important: the later
  // calls rely on some type properties to compute the operand position.
  // 1. Source memref.
  if (!getSrcMemRef().getType().isa<MemRefType>())
    return emitOpError("expected source to be of memref type");
  if (numOperands < getSrcMemRefRank() + 4)
    return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
                         << " operands";
  if (!getSrcIndices().empty() &&
      !llvm::all_of(getSrcIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected source indices to be of index type");

  // 2. Destination memref.
  if (!getDstMemRef().getType().isa<MemRefType>())
    return emitOpError("expected destination to be of memref type");
  unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
                         << " operands";
  if (!getDstIndices().empty() &&
      !llvm::all_of(getDstIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected destination indices to be of index type");

  // 3. Number of elements.
  if (!getNumElements().getType().isIndex())
    return emitOpError("expected num elements to be of index type");

  // 4. Tag memref.
  if (!getTagMemRef().getType().isa<MemRefType>())
    return emitOpError("expected tag to be of memref type");
  numExpectedOperands += getTagMemRefRank();
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
                         << " operands";
  if (!getTagIndices().empty() &&
      !llvm::all_of(getTagIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected tag indices to be of index type");

  // Optional stride-related operands must be either both present or both
  // absent.
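  // E.g. with rank-2 source and destination memrefs and a rank-1 tag memref,
  // numExpectedOperands is 2 + 2 + 1 + 4 = 9, so 9 (unstrided) or 11
  // (strided) operands are accepted.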
  if (numOperands != numExpectedOperands &&
      numOperands != numExpectedOperands + 2)
    return emitOpError("incorrect number of operands");

  // 5. Strides.
  if (isStrided()) {
    if (!getStride().getType().isIndex() ||
        !getNumElementsPerStride().getType().isIndex())
      return emitOpError(
          "expected stride and num elements per stride to be of type index");
  }

  return success();
}

LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands,
                               SmallVectorImpl<OpFoldResult> &results) {
  /// dma_start(memrefcast) -> dma_start
  return foldMemRefCast(*this);
}

// ---------------------------------------------------------------------------
// DmaWaitOp
// ---------------------------------------------------------------------------

void DmaWaitOp::build(OpBuilder &builder, OperationState &result,
                      Value tagMemRef, ValueRange tagIndices,
                      Value numElements) {
  result.addOperands(tagMemRef);
  result.addOperands(tagIndices);
  result.addOperands(numElements);
}

void DmaWaitOp::print(OpAsmPrinter &p) {
  p << getOperationName() << " " << getTagMemRef() << '[' << getTagIndices()
    << "], " << getNumElements();
  p.printOptionalAttrDict((*this)->getAttrs());
  p << " : " << getTagMemRef().getType();
}

// Parse DmaWaitOp.
// Eg:
//   dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 4>
//
ParseResult DmaWaitOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::OperandType tagMemrefInfo;
  SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos;
  Type type;
  auto indexType = parser.getBuilder().getIndexType();
  OpAsmParser::OperandType numElementsInfo;

  // Parse tag memref, its indices, and dma size.
  if (parser.parseOperand(tagMemrefInfo) ||
      parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(numElementsInfo) ||
      parser.parseColonType(type) ||
      parser.resolveOperand(tagMemrefInfo, type, result.operands) ||
      parser.resolveOperands(tagIndexInfos, indexType, result.operands) ||
      parser.resolveOperand(numElementsInfo, indexType, result.operands))
    return failure();

  return success();
}

LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
                              SmallVectorImpl<OpFoldResult> &results) {
  /// dma_wait(memrefcast) -> dma_wait
  return foldMemRefCast(*this);
}

LogicalResult DmaWaitOp::verify() {
  // Mandatory non-variadic operands are tag and the number of elements.
  if (getNumOperands() < 2)
    return emitOpError() << "expected at least 2 operands";

  // Check types of operands. The order of these calls is important: the later
  // calls rely on some type properties to compute the operand position.
  if (!getTagMemRef().getType().isa<MemRefType>())
    return emitOpError() << "expected tag to be of memref type";

  if (getNumOperands() != 2 + getTagMemRefRank())
    return emitOpError() << "expected " << 2 + getTagMemRefRank()
                         << " operands";

  if (!getTagIndices().empty() &&
      !llvm::all_of(getTagIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError() << "expected tag indices to be of index type";

  if (!getNumElements().getType().isIndex())
    return emitOpError()
           << "expected the number of elements to be of index type";

  return success();
}

//===----------------------------------------------------------------------===//
// GlobalOp
//===----------------------------------------------------------------------===//

static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
                                                   TypeAttr type,
                                                   Attribute initialValue) {
  p << type;
  if (!op.isExternal()) {
    p << " = ";
    if (op.isUninitialized())
      p << "uninitialized";
    else
      p.printAttributeWithoutType(initialValue);
  }
}

static ParseResult
parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
                                       Attribute &initialValue) {
  Type type;
  if (parser.parseType(type))
    return failure();

  auto memrefType = type.dyn_cast<MemRefType>();
  if (!memrefType || !memrefType.hasStaticShape())
    return parser.emitError(parser.getNameLoc())
           << "type should be static shaped memref, but got " << type;
  typeAttr = TypeAttr::get(type);

  if (parser.parseOptionalEqual())
    return success();

  if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
    initialValue = UnitAttr::get(parser.getBuilder().getContext());
    return success();
  }

  Type tensorType = getTensorTypeFromMemRefType(memrefType);
  if (parser.parseAttribute(initialValue, tensorType))
    return failure();
  if (!initialValue.isa<ElementsAttr>())
    return parser.emitError(parser.getNameLoc())
           << "initial value should be a unit or elements attribute";
  return success();
}

static LogicalResult verify(GlobalOp op) {
  auto memrefType = op.type().dyn_cast<MemRefType>();
  if (!memrefType || !memrefType.hasStaticShape())
    return op.emitOpError("type should be static shaped memref, but got ")
           << op.type();

  // Verify that the initial value, if present, is either a unit attribute or
  // an elements attribute.
  if (op.initial_value().hasValue()) {
    Attribute initValue = op.initial_value().getValue();
    if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>())
      return op.emitOpError("initial value should be a unit or elements "
                            "attribute, but got ")
             << initValue;

    // Check that the type of the initial value is compatible with the type of
    // the global variable.
    if (initValue.isa<ElementsAttr>()) {
      Type initType = initValue.getType();
      Type tensorType = getTensorTypeFromMemRefType(memrefType);
      if (initType != tensorType)
        return op.emitOpError("initial value expected to be of type ")
               << tensorType << ", but was of type " << initType;
    }
  }

  // TODO: verify visibility for declarations.
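  // (For reference, the checks above admit three forms of initial value: none
  // for external globals, a UnitAttr produced by the `uninitialized` keyword,
  // and an ElementsAttr whose type matches the tensor form of the memref
  // type.)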
  return success();
}

//===----------------------------------------------------------------------===//
// GetGlobalOp
//===----------------------------------------------------------------------===//

LogicalResult
GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  // Verify that the result type is the same as the type of the referenced
  // memref.global op.
  auto global =
      symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, nameAttr());
  if (!global)
    return emitOpError("'")
           << name() << "' does not reference a valid global memref";

  Type resultType = result().getType();
  if (global.type() != resultType)
    return emitOpError("result type ")
           << resultType << " does not match type " << global.type()
           << " of the global memref @" << name();
  return success();
}

//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(LoadOp op) {
  if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
    return op.emitOpError("incorrect number of indices for load");
  return success();
}

OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
  /// load(memrefcast) -> load
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  return OpFoldResult();
}

namespace {
/// Fold a load on a buffer_cast operation into a tensor.extract on the
/// corresponding tensor.
struct LoadOfBufferCast : public OpRewritePattern<LoadOp> {
  using OpRewritePattern<LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(LoadOp load,
                                PatternRewriter &rewriter) const override {
    auto buffercast = load.memref().getDefiningOp<BufferCastOp>();
    if (!buffercast)
      return failure();

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, buffercast.tensor(),
                                                   load.indices());
    return success();
  }
};
} // end anonymous namespace.

void LoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<LoadOfBufferCast>(context);
}

//===----------------------------------------------------------------------===//
// PrefetchOp
//===----------------------------------------------------------------------===//

static void print(OpAsmPrinter &p, PrefetchOp op) {
  p << PrefetchOp::getOperationName() << " " << op.memref() << '[';
  p.printOperands(op.indices());
  p << ']' << ", " << (op.isWrite() ? "write" : "read");
  p << ", locality<" << op.localityHint();
  p << ">, " << (op.isDataCache() ? "data" : "instr");
  p.printOptionalAttrDict(
      op->getAttrs(),
      /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
  p << " : " << op.getMemRefType();
}

static ParseResult parsePrefetchOp(OpAsmParser &parser,
                                   OperationState &result) {
  OpAsmParser::OperandType memrefInfo;
  SmallVector<OpAsmParser::OperandType, 4> indexInfo;
  IntegerAttr localityHint;
  MemRefType type;
  StringRef readOrWrite, cacheType;

  auto indexTy = parser.getBuilder().getIndexType();
  auto i32Type = parser.getBuilder().getIntegerType(32);
  if (parser.parseOperand(memrefInfo) ||
      parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
      parser.parseComma() || parser.parseKeyword("locality") ||
      parser.parseLess() ||
      parser.parseAttribute(localityHint, i32Type, "localityHint",
                            result.attributes) ||
      parser.parseGreater() || parser.parseComma() ||
      parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
      parser.resolveOperand(memrefInfo, type, result.operands) ||
      parser.resolveOperands(indexInfo, indexTy, result.operands))
    return failure();

  if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
    return parser.emitError(parser.getNameLoc(),
                            "rw specifier has to be 'read' or 'write'");
  result.addAttribute(
      PrefetchOp::getIsWriteAttrName(),
      parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));

  if (!cacheType.equals("data") && !cacheType.equals("instr"))
    return parser.emitError(parser.getNameLoc(),
                            "cache type has to be 'data' or 'instr'");

  result.addAttribute(
      PrefetchOp::getIsDataCacheAttrName(),
      parser.getBuilder().getBoolAttr(cacheType.equals("data")));

  return success();
}

static LogicalResult verify(PrefetchOp op) {
  if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
    return op.emitOpError("too few indices");

  return success();
}

LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands,
                               SmallVectorImpl<OpFoldResult> &results) {
  // prefetch(memrefcast) -> prefetch
  return foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// ReinterpretCastOp
//===----------------------------------------------------------------------===//

/// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`,
/// `staticSizes` and `staticStrides` are automatically filled with
/// source-memref-rank sentinel values that encode dynamic entries.
void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                              MemRefType resultType, Value source,
                              OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
                              ArrayRef<OpFoldResult> strides,
                              ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets,
                             ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
                             ShapedType::kDynamicStrideOrOffset);
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getI64ArrayAttr(staticOffsets),
        b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
  result.addAttributes(attrs);
}

void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                              MemRefType resultType, Value source,
                              int64_t offset, ArrayRef<int64_t> sizes,
                              ArrayRef<int64_t> strides,
                              ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> sizeValues =
      llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
        strideValues, attrs);
}

void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                              MemRefType resultType, Value source, Value offset,
                              ValueRange sizes, ValueRange strides,
                              ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
      llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
}

// TODO: ponder whether we want to allow missing trailing sizes/strides that
// are completed automatically, like we have for subview and subtensor.
static LogicalResult verify(ReinterpretCastOp op) {
  // The source and result memrefs should be in the same memory space.
  auto srcType = op.source().getType().cast<BaseMemRefType>();
  auto resultType = op.getType().cast<MemRefType>();
  if (srcType.getMemorySpace() != resultType.getMemorySpace())
    return op.emitError("different memory spaces specified for source type ")
           << srcType << " and result memref type " << resultType;
  if (srcType.getElementType() != resultType.getElementType())
    return op.emitError("different element types specified for source type ")
           << srcType << " and result memref type " << resultType;

  // Match sizes in result memref type and in static_sizes attribute.
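  // E.g. a result type memref<4x?xf32> must come with static_sizes
  // [4, kDynamicSize]; a static_sizes entry of 8 in the second position would
  // be rejected below.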
  for (auto &en :
       llvm::enumerate(llvm::zip(resultType.getShape(),
                                 extractFromI64ArrayAttr(op.static_sizes())))) {
    int64_t resultSize = std::get<0>(en.value());
    int64_t expectedSize = std::get<1>(en.value());
    if (resultSize != expectedSize)
      return op.emitError("expected result type with size = ")
             << expectedSize << " instead of " << resultSize
             << " in dim = " << en.index();
  }

  // Match offset and strides in static_offset and static_strides attributes if
  // result memref type has an affine map specified.
  if (!resultType.getAffineMaps().empty()) {
    int64_t resultOffset;
    SmallVector<int64_t, 4> resultStrides;
    if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
      return failure();

    // Match offset in result memref type and in static_offsets attribute.
    int64_t expectedOffset =
        extractFromI64ArrayAttr(op.static_offsets()).front();
    if (resultOffset != expectedOffset)
      return op.emitError("expected result type with offset = ")
             << resultOffset << " instead of " << expectedOffset;

    // Match strides in result memref type and in static_strides attribute.
    for (auto &en : llvm::enumerate(llvm::zip(
             resultStrides, extractFromI64ArrayAttr(op.static_strides())))) {
      int64_t resultStride = std::get<0>(en.value());
      int64_t expectedStride = std::get<1>(en.value());
      if (resultStride != expectedStride)
        return op.emitError("expected result type with stride = ")
               << expectedStride << " instead of " << resultStride
               << " in dim = " << en.index();
    }
  }
  return success();
}

//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(ReshapeOp op) {
  Type operandType = op.source().getType();
  Type resultType = op.result().getType();

  Type operandElementType = operandType.cast<ShapedType>().getElementType();
  Type resultElementType = resultType.cast<ShapedType>().getElementType();
  if (operandElementType != resultElementType)
    return op.emitOpError("element types of source and destination memref "
                          "types should be the same");

  if (auto operandMemRefType = operandType.dyn_cast<MemRefType>())
    if (!operandMemRefType.getAffineMaps().empty())
      return op.emitOpError(
          "source memref type should have identity affine map");

  int64_t shapeSize = op.shape().getType().cast<MemRefType>().getDimSize(0);
  auto resultMemRefType = resultType.dyn_cast<MemRefType>();
  if (resultMemRefType) {
    if (!resultMemRefType.getAffineMaps().empty())
      return op.emitOpError(
          "result memref type should have identity affine map");
    if (shapeSize == ShapedType::kDynamicSize)
      return op.emitOpError("cannot use shape operand with dynamic length to "
                            "reshape to statically-ranked memref type");
    if (shapeSize != resultMemRefType.getRank())
      return op.emitOpError(
          "length of shape operand differs from the result's memref rank");
  }
  return success();
}

//===----------------------------------------------------------------------===//
// StoreOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(StoreOp op) {
  if (op.getNumOperands() != 2 + op.getMemRefType().getRank())
    return op.emitOpError("store index operand count not equal to memref rank");

  return success();
}

LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands,
                            SmallVectorImpl<OpFoldResult> &results) {
  /// store(memrefcast) -> store
  return foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// SubViewOp
//===----------------------------------------------------------------------===//

namespace {
/// Helpers to write more idiomatic operations.
namespace saturated_arith {
struct Wrapper {
  explicit Wrapper(int64_t v) : v(v) {}
  operator int64_t() { return v; }
  int64_t v;
};
Wrapper operator+(Wrapper a, int64_t b) {
  if (ShapedType::isDynamicStrideOrOffset(a) ||
      ShapedType::isDynamicStrideOrOffset(b))
    return Wrapper(ShapedType::kDynamicStrideOrOffset);
  return Wrapper(a.v + b);
}
Wrapper operator*(Wrapper a, int64_t b) {
  if (ShapedType::isDynamicStrideOrOffset(a) ||
      ShapedType::isDynamicStrideOrOffset(b))
    return Wrapper(ShapedType::kDynamicStrideOrOffset);
  return Wrapper(a.v * b);
}
} // end namespace saturated_arith
} // end namespace

/// A subview result type can be fully inferred from the source type and the
/// static representation of offsets, sizes and strides. Special sentinels
/// encode the dynamic case.
Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
                                ArrayRef<int64_t> leadingStaticOffsets,
                                ArrayRef<int64_t> leadingStaticSizes,
                                ArrayRef<int64_t> leadingStaticStrides) {
  // A subview may specify only a leading subset of offset/sizes/strides in
  // which case we complete with offset=0, sizes from memref type and
  // strides=1.
  unsigned rank = sourceMemRefType.getRank();
  assert(leadingStaticOffsets.size() <= rank &&
         "unexpected leadingStaticOffsets overflow");
  assert(leadingStaticSizes.size() <= rank &&
         "unexpected leadingStaticSizes overflow");
  assert(leadingStaticStrides.size() <= rank &&
         "unexpected leadingStaticStrides overflow");
  auto staticOffsets = llvm::to_vector<4>(leadingStaticOffsets);
  auto staticSizes = llvm::to_vector<4>(leadingStaticSizes);
  auto staticStrides = llvm::to_vector<4>(leadingStaticStrides);
  unsigned numTrailingOffsets = rank - staticOffsets.size();
  unsigned numTrailingSizes = rank - staticSizes.size();
  unsigned numTrailingStrides = rank - staticStrides.size();
  staticOffsets.append(numTrailingOffsets, 0);
  llvm::append_range(staticSizes,
                     sourceMemRefType.getShape().take_back(numTrailingSizes));
  staticStrides.append(numTrailingStrides, 1);

  // Extract source offset and strides.
  int64_t sourceOffset;
  SmallVector<int64_t, 4> sourceStrides;
  auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset);
  assert(succeeded(res) && "SubViewOp expected strided memref type");
  (void)res;

  // Compute target offset whose value is:
  //   `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
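  // E.g. with sourceOffset = 0, sourceStrides = [16, 1] and
  // staticOffsets = [2, 3], the target offset is 0 + 2 * 16 + 3 * 1 = 35.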

/// A subview result type can be fully inferred from the source type and the
/// static representation of offsets, sizes and strides. Special sentinels
/// encode the dynamic case.
Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
                                ArrayRef<int64_t> leadingStaticOffsets,
                                ArrayRef<int64_t> leadingStaticSizes,
                                ArrayRef<int64_t> leadingStaticStrides) {
  // A subview may specify only a leading subset of offset/sizes/strides in
  // which case we complete with offset=0, sizes from memref type and strides=1.
  unsigned rank = sourceMemRefType.getRank();
  assert(leadingStaticOffsets.size() <= rank &&
         "unexpected leadingStaticOffsets overflow");
  assert(leadingStaticSizes.size() <= rank &&
         "unexpected leadingStaticSizes overflow");
  assert(leadingStaticStrides.size() <= rank &&
         "unexpected leadingStaticStrides overflow");
  auto staticOffsets = llvm::to_vector<4>(leadingStaticOffsets);
  auto staticSizes = llvm::to_vector<4>(leadingStaticSizes);
  auto staticStrides = llvm::to_vector<4>(leadingStaticStrides);
  unsigned numTrailingOffsets = rank - staticOffsets.size();
  unsigned numTrailingSizes = rank - staticSizes.size();
  unsigned numTrailingStrides = rank - staticStrides.size();
  staticOffsets.append(numTrailingOffsets, 0);
  llvm::append_range(staticSizes,
                     sourceMemRefType.getShape().take_back(numTrailingSizes));
  staticStrides.append(numTrailingStrides, 1);

  // Extract source offset and strides.
  int64_t sourceOffset;
  SmallVector<int64_t, 4> sourceStrides;
  auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset);
  assert(succeeded(res) && "SubViewOp expected strided memref type");
  (void)res;

  // Compute target offset whose value is:
  // `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
  int64_t targetOffset = sourceOffset;
  for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
    auto staticOffset = std::get<0>(it), sourceStride = std::get<1>(it);
    using namespace saturated_arith;
    targetOffset = Wrapper(targetOffset) + Wrapper(staticOffset) * sourceStride;
  }

  // Compute target stride whose value is:
  // `sourceStrides_i * staticStrides_i`.
  SmallVector<int64_t, 4> targetStrides;
  targetStrides.reserve(staticOffsets.size());
  for (auto it : llvm::zip(sourceStrides, staticStrides)) {
    auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
    using namespace saturated_arith;
    targetStrides.push_back(Wrapper(sourceStride) * staticStride);
  }

  // The type is now known.
  return MemRefType::get(
      staticSizes, sourceMemRefType.getElementType(),
      makeStridedLinearLayoutMap(targetStrides, targetOffset,
                                 sourceMemRefType.getContext()),
      sourceMemRefType.getMemorySpace());
}

Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
                                ArrayRef<OpFoldResult> leadingStaticOffsets,
                                ArrayRef<OpFoldResult> leadingStaticSizes,
                                ArrayRef<OpFoldResult> leadingStaticStrides) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
                             staticOffsets, ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
                             staticStrides, ShapedType::kDynamicStrideOrOffset);
  return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
                                    staticSizes, staticStrides)
      .cast<MemRefType>();
}

Type SubViewOp::inferRankReducedResultType(
    unsigned resultRank, MemRefType sourceMemRefType,
    ArrayRef<int64_t> leadingStaticOffsets,
    ArrayRef<int64_t> leadingStaticSizes,
    ArrayRef<int64_t> leadingStaticStrides) {
  auto inferredType =
      inferResultType(sourceMemRefType, leadingStaticOffsets,
                      leadingStaticSizes, leadingStaticStrides)
          .cast<MemRefType>();
  assert(inferredType.getRank() >= resultRank &&
         "expected inferred rank to be no smaller than the requested rank");
  int rankDiff = inferredType.getRank() - resultRank;
  if (rankDiff > 0) {
    auto shape = inferredType.getShape();
    llvm::SmallDenseSet<unsigned> dimsToProject;
    mlir::getPositionsOfShapeOne(rankDiff, shape, dimsToProject);
    SmallVector<int64_t> projectedShape;
    for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
      if (!dimsToProject.contains(pos))
        projectedShape.push_back(shape[pos]);

    AffineMap map;
    auto maps = inferredType.getAffineMaps();
    if (!maps.empty() && maps.front())
      map = getProjectedMap(maps.front(), dimsToProject);
    inferredType =
        MemRefType::get(projectedShape, inferredType.getElementType(), map,
                        inferredType.getMemorySpace());
  }
  return inferredType;
}
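
// Illustrative example (added for clarity, not from the upstream comments):
// asking for `resultRank` = 1 on a memref<8x16xf32> source with sizes [1, 16]
// first infers the rank-2 type memref<1x16xf32, ...> and then projects away
// the unit dimension, so the returned type is a memref<16xf32, ...> whose
// layout map has the corresponding dimension dropped.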

Type SubViewOp::inferRankReducedResultType(
    unsigned resultRank, MemRefType sourceMemRefType,
    ArrayRef<OpFoldResult> leadingStaticOffsets,
    ArrayRef<OpFoldResult> leadingStaticSizes,
    ArrayRef<OpFoldResult> leadingStaticStrides) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
                             staticOffsets, ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
                             staticStrides, ShapedType::kDynamicStrideOrOffset);
  return SubViewOp::inferRankReducedResultType(
      resultRank, sourceMemRefType, staticOffsets, staticSizes, staticStrides);
}

// Build a SubViewOp with mixed static and dynamic entries and custom result
// type. If the type passed is nullptr, it is inferred.
void SubViewOp::build(OpBuilder &b, OperationState &result,
                      MemRefType resultType, Value source,
                      ArrayRef<OpFoldResult> offsets,
                      ArrayRef<OpFoldResult> sizes,
                      ArrayRef<OpFoldResult> strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
                             ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
                             ShapedType::kDynamicStrideOrOffset);
  auto sourceMemRefType = source.getType().cast<MemRefType>();
  // Structuring implementation this way avoids duplication between builders.
  if (!resultType) {
    resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
                                            staticSizes, staticStrides)
                     .cast<MemRefType>();
  }
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getI64ArrayAttr(staticOffsets),
        b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
  result.addAttributes(attrs);
}

// Build a SubViewOp with mixed static and dynamic entries and inferred result
// type.
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ArrayRef<OpFoldResult> offsets,
                      ArrayRef<OpFoldResult> sizes,
                      ArrayRef<OpFoldResult> strides,
                      ArrayRef<NamedAttribute> attrs) {
  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
}

// Build a SubViewOp with static entries and inferred result type.
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
                      ArrayRef<int64_t> strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> sizeValues =
      llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
}
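
// Usage sketch (illustrative only; assumes a builder `b`, a location `loc`, a
// memref-typed value `src`, and that the generated declarations default the
// trailing `attrs` argument):
//   SmallVector<int64_t> offsets = {0, 0}, sizes = {4, 4}, strides = {1, 1};
//   auto subview = b.create<SubViewOp>(loc, src, offsets, sizes, strides);
// The result type is inferred via SubViewOp::inferResultType because no
// explicit MemRefType is passed.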

// Build a SubViewOp with static entries and custom result type. If the
// type passed is nullptr, it is inferred.
void SubViewOp::build(OpBuilder &b, OperationState &result,
                      MemRefType resultType, Value source,
                      ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
                      ArrayRef<int64_t> strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> sizeValues =
      llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
        attrs);
}

// Build a SubViewOp with dynamic entries and custom result type. If the type
// passed is nullptr, it is inferred.
void SubViewOp::build(OpBuilder &b, OperationState &result,
                      MemRefType resultType, Value source, ValueRange offsets,
                      ValueRange sizes, ValueRange strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
      llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
        attrs);
}

// Build a SubViewOp with dynamic entries and inferred result type.
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ValueRange offsets, ValueRange sizes, ValueRange strides,
                      ArrayRef<NamedAttribute> attrs) {
  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
}

/// For ViewLikeOpInterface.
Value SubViewOp::getViewSource() { return source(); }

enum SubViewVerificationResult {
  Success,
  RankTooLarge,
  SizeMismatch,
  ElemTypeMismatch,
  MemSpaceMismatch,
  AffineMapMismatch
};

/// Checks if `originalType` can be rank-reduced to `candidateReducedType`.
/// This function is a slight variant of the `is subsequence` algorithm where
/// the non-matching dimensions must have a static size of 1.
static SubViewVerificationResult
isRankReducedType(Type originalType, Type candidateReducedType,
                  std::string *errMsg = nullptr) {
  if (originalType == candidateReducedType)
    return SubViewVerificationResult::Success;
  if (!originalType.isa<MemRefType>())
    return SubViewVerificationResult::Success;
  if (originalType.isa<MemRefType>() && !candidateReducedType.isa<MemRefType>())
    return SubViewVerificationResult::Success;

  ShapedType originalShapedType = originalType.cast<ShapedType>();
  ShapedType candidateReducedShapedType =
      candidateReducedType.cast<ShapedType>();

  // Rank and size logic is valid for all ShapedTypes.
  ArrayRef<int64_t> originalShape = originalShapedType.getShape();
  ArrayRef<int64_t> candidateReducedShape =
      candidateReducedShapedType.getShape();
  unsigned originalRank = originalShape.size(),
           candidateReducedRank = candidateReducedShape.size();
  if (candidateReducedRank > originalRank)
    return SubViewVerificationResult::RankTooLarge;

  auto optionalUnusedDimsMask =
      computeRankReductionMask(originalShape, candidateReducedShape);

  // Sizes cannot be matched if no rank-reduction mask could be computed.
  if (!optionalUnusedDimsMask.hasValue())
    return SubViewVerificationResult::SizeMismatch;

  if (originalShapedType.getElementType() !=
      candidateReducedShapedType.getElementType())
    return SubViewVerificationResult::ElemTypeMismatch;

  // Strided layout logic is relevant for MemRefType only.
  MemRefType original = originalType.cast<MemRefType>();
  MemRefType candidateReduced = candidateReducedType.cast<MemRefType>();
  if (original.getMemorySpace() != candidateReduced.getMemorySpace())
    return SubViewVerificationResult::MemSpaceMismatch;

  llvm::SmallDenseSet<unsigned> unusedDims = optionalUnusedDimsMask.getValue();
  auto inferredType =
      getProjectedMap(getStridedLinearLayoutMap(original), unusedDims);
  AffineMap candidateLayout;
  if (candidateReduced.getAffineMaps().empty())
    candidateLayout = getStridedLinearLayoutMap(candidateReduced);
  else
    candidateLayout = candidateReduced.getAffineMaps().front();
  assert(inferredType.getNumResults() == 1 &&
         candidateLayout.getNumResults() == 1);
  if (inferredType.getNumSymbols() != candidateLayout.getNumSymbols() ||
      inferredType.getNumDims() != candidateLayout.getNumDims()) {
    if (errMsg) {
      llvm::raw_string_ostream os(*errMsg);
      os << "inferred type: " << inferredType;
    }
    return SubViewVerificationResult::AffineMapMismatch;
  }
  // Check that the difference of the affine maps simplifies to 0.
  AffineExpr diffExpr =
      inferredType.getResult(0) - candidateLayout.getResult(0);
  diffExpr = simplifyAffineExpr(diffExpr, inferredType.getNumDims(),
                                inferredType.getNumSymbols());
  auto cst = diffExpr.dyn_cast<AffineConstantExpr>();
  if (!(cst && cst.getValue() == 0)) {
    if (errMsg) {
      llvm::raw_string_ostream os(*errMsg);
      os << "inferred type: " << inferredType;
    }
    return SubViewVerificationResult::AffineMapMismatch;
  }
  return SubViewVerificationResult::Success;
}
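
// Illustrative example (added for clarity, not from the upstream comments): a
// rank-reducing subview such as
//   %1 = memref.subview %0[0, 0, 0] [1, 16, 4] [1, 1, 1]
//       : memref<8x16x4xf32> to memref<16x4xf32>
// passes the check above because the only dropped dimension has static size 1
// and the projected layout map matches the result layout.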

template <typename OpTy>
static LogicalResult produceSubViewErrorMsg(SubViewVerificationResult result,
                                            OpTy op, Type expectedType,
                                            StringRef errMsg = "") {
  auto memrefType = expectedType.cast<ShapedType>();
  switch (result) {
  case SubViewVerificationResult::Success:
    return success();
  case SubViewVerificationResult::RankTooLarge:
    return op.emitError("expected result rank to be smaller or equal to ")
           << "the source rank. " << errMsg;
  case SubViewVerificationResult::SizeMismatch:
    return op.emitError("expected result type to be ")
           << expectedType
           << " or a rank-reduced version. (mismatch of result sizes) "
           << errMsg;
  case SubViewVerificationResult::ElemTypeMismatch:
    return op.emitError("expected result element type to be ")
           << memrefType.getElementType() << errMsg;
  case SubViewVerificationResult::MemSpaceMismatch:
    return op.emitError("expected result and source memory spaces to match.")
           << errMsg;
  case SubViewVerificationResult::AffineMapMismatch:
    return op.emitError("expected result type to be ")
           << expectedType
           << " or a rank-reduced version. (mismatch of result affine map) "
           << errMsg;
  }
  llvm_unreachable("unexpected subview verification result");
}

/// Verifier for SubViewOp.
static LogicalResult verify(SubViewOp op) {
  MemRefType baseType = op.getSourceType();
  MemRefType subViewType = op.getType();

  // The base memref and the view memref should be in the same memory space.
  if (baseType.getMemorySpace() != subViewType.getMemorySpace())
    return op.emitError("different memory spaces specified for base memref "
                        "type ")
           << baseType << " and subview memref type " << subViewType;

  // Verify that the base memref type has a strided layout map.
  if (!isStrided(baseType))
    return op.emitError("base type ") << baseType << " is not strided";

  // Verify result type against inferred type.
  auto expectedType = SubViewOp::inferResultType(
      baseType, extractFromI64ArrayAttr(op.static_offsets()),
      extractFromI64ArrayAttr(op.static_sizes()),
      extractFromI64ArrayAttr(op.static_strides()));

  std::string errMsg;
  auto result = isRankReducedType(expectedType, subViewType, &errMsg);
  return produceSubViewErrorMsg(result, op, expectedType, errMsg);
}

raw_ostream &mlir::operator<<(raw_ostream &os, Range &range) {
  return os << "range " << range.offset << ":" << range.size << ":"
            << range.stride;
}

/// Return the list of Range (i.e. offset, size, stride). Each Range
/// entry contains either the dynamic value or a ConstantIndexOp constructed
/// with `b` at location `loc`.
SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
                                              OpBuilder &b, Location loc) {
  std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
  assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
  assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
  SmallVector<Range, 8> res;
  unsigned rank = ranks[0];
  res.reserve(rank);
  for (unsigned idx = 0; idx < rank; ++idx) {
    Value offset =
        op.isDynamicOffset(idx)
            ? op.getDynamicOffset(idx)
            : b.create<ConstantIndexOp>(loc, op.getStaticOffset(idx));
    Value size = op.isDynamicSize(idx)
                     ? op.getDynamicSize(idx)
                     : b.create<ConstantIndexOp>(loc, op.getStaticSize(idx));
    Value stride =
        op.isDynamicStride(idx)
            ? op.getDynamicStride(idx)
            : b.create<ConstantIndexOp>(loc, op.getStaticStride(idx));
    res.emplace_back(Range{offset, size, stride});
  }
  return res;
}
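
// Worked example (added for clarity): for an op whose mixed offsets, sizes and
// strides are the static values [0, 4], [4, 8] and [1, 1], the helper above
// materializes ConstantIndexOps and returns the ranges 0:4:1 and 4:8:1, using
// the offset:size:stride notation of the stream operator above.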

/// Infer the canonical type of the result of a subview operation. Returns the
/// rank-reduced type when its rank matches `resultRank`, and the
/// non-rank-reduced type otherwise.
static MemRefType
getCanonicalSubViewResultType(unsigned resultRank, MemRefType sourceType,
                              ArrayRef<OpFoldResult> mixedOffsets,
                              ArrayRef<OpFoldResult> mixedSizes,
                              ArrayRef<OpFoldResult> mixedStrides) {
  auto resultType =
      SubViewOp::inferRankReducedResultType(
          resultRank, sourceType, mixedOffsets, mixedSizes, mixedStrides)
          .cast<MemRefType>();
  if (resultType.getRank() != resultRank) {
    resultType = SubViewOp::inferResultType(sourceType, mixedOffsets,
                                            mixedSizes, mixedStrides)
                     .cast<MemRefType>();
  }
  return resultType;
}

namespace {
/// Pattern to rewrite a subview op with MemRefCast arguments.
/// This essentially pushes memref.cast past its consuming subview when
/// `canFoldIntoConsumerOp` is true.
///
/// Example:
/// ```
/// %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
/// %1 = memref.subview %0[0, 0][3, 4][1, 1] :
///   memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]>
/// ```
/// is rewritten into:
/// ```
/// %0 = memref.subview %V : memref<16x16xf32> to memref<3x4xf32, #[[map0]]>
/// %1 = memref.cast %0 : memref<3x4xf32, offset:0, strides:[16, 1]> to
///   memref<3x4xf32, offset:?, strides:[?, 1]>
/// ```
class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
public:
  using OpRewritePattern<SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubViewOp subViewOp,
                                PatternRewriter &rewriter) const override {
    // If any operand is a constant index, return failure and let the
    // constant-argument folder pattern kick in first.
    if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto castOp = subViewOp.source().getDefiningOp<CastOp>();
    if (!castOp)
      return failure();

    if (!CastOp::canFoldIntoConsumerOp(castOp))
      return failure();

    /// Deduce the result type of the SubViewOp by applying
    /// `getCanonicalSubViewResultType` to the cast source operand type and the
    /// SubViewOp static information. This is the type the op would have if the
    /// MemRefCastOp were folded.
    auto resultType = getCanonicalSubViewResultType(
        subViewOp.getType().getRank(),
        castOp.source().getType().cast<MemRefType>(),
        subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
        subViewOp.getMixedStrides());
    Value newSubView = rewriter.create<SubViewOp>(
        subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(),
        subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),
        subViewOp.static_sizes(), subViewOp.static_strides());
    rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
                                        newSubView);
    return success();
  }
};
} // namespace

/// Return the canonical type of the result of a subview.
struct SubViewReturnTypeCanonicalizer {
  MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
                        ArrayRef<OpFoldResult> mixedSizes,
                        ArrayRef<OpFoldResult> mixedStrides) {
    return getCanonicalSubViewResultType(op.getType().getRank(),
                                         op.getSourceType(), mixedOffsets,
                                         mixedSizes, mixedStrides);
  }
};
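
// Note (added for clarity): the two functors above and below plug into
// OpWithOffsetSizesAndStridesConstantArgumentFolder, registered in
// getCanonicalizationPatterns below. Roughly, once constant index operands are
// folded into the static attributes, SubViewReturnTypeCanonicalizer recomputes
// the result type and SubViewCanonicalizer casts the new subview back to the
// original result type.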

/// A canonicalizer wrapper to replace SubViewOps.
struct SubViewCanonicalizer {
  void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
    rewriter.replaceOpWithNewOp<CastOp>(op, newOp, op.getType());
  }
};

void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results
      .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
               SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
           SubViewOpMemRefCastFolder>(context);
}

OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
  auto resultShapedType = getResult().getType().cast<ShapedType>();
  auto sourceShapedType = source().getType().cast<ShapedType>();

  if (resultShapedType.hasStaticShape() &&
      resultShapedType == sourceShapedType) {
    return getViewSource();
  }

  return {};
}

//===----------------------------------------------------------------------===//
// TensorLoadOp
//===----------------------------------------------------------------------===//

OpFoldResult TensorLoadOp::fold(ArrayRef<Attribute>) {
  if (auto bufferCast = memref().getDefiningOp<BufferCastOp>())
    // Approximate alias analysis by conservatively folding only when there is
    // no interleaved operation.
    if (bufferCast->getBlock() == this->getOperation()->getBlock() &&
        bufferCast->getNextNode() == this->getOperation())
      return bufferCast.tensor();
  return {};
}

//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//

/// Build a strided memref type by applying `permutationMap` to `memRefType`.
static MemRefType inferTransposeResultType(MemRefType memRefType,
                                           AffineMap permutationMap) {
  auto rank = memRefType.getRank();
  auto originalSizes = memRefType.getShape();
  // Compute permuted sizes.
  SmallVector<int64_t, 4> sizes(rank, 0);
  for (auto en : llvm::enumerate(permutationMap.getResults()))
    sizes[en.index()] =
        originalSizes[en.value().cast<AffineDimExpr>().getPosition()];

  // Compute permuted strides.
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto res = getStridesAndOffset(memRefType, strides, offset);
  assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank));
  (void)res;
  auto map =
      makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
  map = permutationMap ? map.compose(permutationMap) : map;
  return MemRefType::Builder(memRefType).setShape(sizes).setAffineMaps(map);
}
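
// Worked example (added for clarity, not from the upstream comments): with the
// permutation (d0, d1) -> (d1, d0), a memref<3x4xf32> source (strides [4, 1],
// offset 0) infers the result type
//   memref<4x3xf32, affine_map<(d0, d1) -> (d0 + d1 * 4)>>
// i.e. the shape is permuted to 4x3 and the composed layout keeps the original
// element order with strides [1, 4].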

void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
                        AffineMapAttr permutation,
                        ArrayRef<NamedAttribute> attrs) {
  auto permutationMap = permutation.getValue();
  assert(permutationMap);

  auto memRefType = in.getType().cast<MemRefType>();
  // Compute result type.
  MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);

  build(b, result, resultType, in, attrs);
  result.addAttribute(TransposeOp::getPermutationAttrName(), permutation);
}

// transpose $in $permutation attr-dict : type($in) `to` type(results)
static void print(OpAsmPrinter &p, TransposeOp op) {
  p << "memref.transpose " << op.in() << " " << op.permutation();
  p.printOptionalAttrDict(op->getAttrs(),
                          {TransposeOp::getPermutationAttrName()});
  p << " : " << op.in().getType() << " to " << op.getType();
}

static ParseResult parseTransposeOp(OpAsmParser &parser,
                                    OperationState &result) {
  OpAsmParser::OperandType in;
  AffineMap permutation;
  MemRefType srcType, dstType;
  if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(srcType) ||
      parser.resolveOperand(in, srcType, result.operands) ||
      parser.parseKeywordType("to", dstType) ||
      parser.addTypeToList(dstType, result.types))
    return failure();

  result.addAttribute(TransposeOp::getPermutationAttrName(),
                      AffineMapAttr::get(permutation));
  return success();
}

static LogicalResult verify(TransposeOp op) {
  if (!op.permutation().isPermutation())
    return op.emitOpError("expected a permutation map");
  if (op.permutation().getNumDims() != op.getShapedType().getRank())
    return op.emitOpError(
        "expected a permutation map of same rank as the input");

  auto srcType = op.in().getType().cast<MemRefType>();
  auto dstType = op.getType().cast<MemRefType>();
  auto transposedType = inferTransposeResultType(srcType, op.permutation());
  if (dstType != transposedType)
    return op.emitOpError("output type ")
           << dstType << " does not match transposed input type " << srcType
           << ", " << transposedType;
  return success();
}

OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  return {};
}

//===----------------------------------------------------------------------===//
// ViewOp
//===----------------------------------------------------------------------===//

static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::OperandType srcInfo;
  SmallVector<OpAsmParser::OperandType, 1> offsetInfo;
  SmallVector<OpAsmParser::OperandType, 4> sizesInfo;
  auto indexType = parser.getBuilder().getIndexType();
  Type srcType, dstType;
  llvm::SMLoc offsetLoc;
  if (parser.parseOperand(srcInfo) || parser.getCurrentLocation(&offsetLoc) ||
      parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square))
    return failure();

  if (offsetInfo.size() != 1)
    return parser.emitError(offsetLoc) << "expects 1 offset operand";

  return failure(
      parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(srcType) ||
      parser.resolveOperand(srcInfo, srcType, result.operands) ||
      parser.resolveOperands(offsetInfo, indexType, result.operands) ||
      parser.resolveOperands(sizesInfo, indexType, result.operands) ||
      parser.parseKeywordType("to", dstType) ||
      parser.addTypeToList(dstType, result.types));
}
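
// Illustrative form of the custom assembly handled by parseViewOp above and
// the printer below (example only):
//   %1 = memref.view %0[%byte_shift][%size0, %size1]
//       : memref<2048xi8> to memref<?x?xf32>
// with one byte-shift operand and one size operand per dynamic result
// dimension.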

static void print(OpAsmPrinter &p, ViewOp op) {
  p << op.getOperationName() << ' ' << op.getOperand(0) << '[';
  p.printOperand(op.byte_shift());
  p << "][" << op.sizes() << ']';
  p.printOptionalAttrDict(op->getAttrs());
  p << " : " << op.getOperand(0).getType() << " to " << op.getType();
}

static LogicalResult verify(ViewOp op) {
  auto baseType = op.getOperand(0).getType().cast<MemRefType>();
  auto viewType = op.getType();

  // The base memref should have an identity layout map (or none).
  if (baseType.getAffineMaps().size() > 1 ||
      (baseType.getAffineMaps().size() == 1 &&
       !baseType.getAffineMaps()[0].isIdentity()))
    return op.emitError("unsupported map for base memref type ") << baseType;

  // The result memref should have an identity layout map (or none).
  if (viewType.getAffineMaps().size() > 1 ||
      (viewType.getAffineMaps().size() == 1 &&
       !viewType.getAffineMaps()[0].isIdentity()))
    return op.emitError("unsupported map for result memref type ") << viewType;

  // The base memref and the view memref should be in the same memory space.
  if (baseType.getMemorySpace() != viewType.getMemorySpace())
    return op.emitError("different memory spaces specified for base memref "
                        "type ")
           << baseType << " and view memref type " << viewType;

  // Verify that we have the correct number of sizes for the result type.
  unsigned numDynamicDims = viewType.getNumDynamicDims();
  if (op.sizes().size() != numDynamicDims)
    return op.emitError("incorrect number of size operands for type ")
           << viewType;

  return success();
}

Value ViewOp::getViewSource() { return source(); }
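
// Illustrative example (added for clarity, not from the upstream comments) of
// the shape folding performed by ViewOpShapeFolder below: if %size is produced
// by a constant index 16,
//   %v = memref.view %src[%offset][%size] : memref<2048xi8> to memref<?x4xf32>
// canonicalizes to a static view plus a cast back to the original type:
//   %s = memref.view %src[%offset][] : memref<2048xi8> to memref<16x4xf32>
//   %v = memref.cast %s : memref<16x4xf32> to memref<?x4xf32>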

namespace {

/// Fold constant size operands of a ViewOp into its result type and insert a
/// cast back to the original view type.
struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
  using OpRewritePattern<ViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    // Return if none of the operands are constants.
    if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    // Get result memref type.
    auto memrefType = viewOp.getType();

    // Get offset from old memref view type `memrefType`.
    int64_t oldOffset;
    SmallVector<int64_t, 4> oldStrides;
    if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
      return failure();
    assert(oldOffset == 0 && "Expected 0 offset");

    SmallVector<Value, 4> newOperands;

    // Offset cannot be folded into result type.

    // Fold any dynamic dim operands which are produced by a constant.
    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());

    unsigned dynamicDimPos = 0;
    unsigned rank = memrefType.getRank();
    for (unsigned dim = 0, e = rank; dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already a static dimension, keep it.
      if (!ShapedType::isDynamic(dimSize)) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto *defOp = viewOp.sizes()[dynamicDimPos].getDefiningOp();
      if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
        // Dynamic shape dimension will be folded.
        newShapeConstants.push_back(constantIndexOp.getValue());
      } else {
        // Dynamic shape dimension not folded; copy operand from old memref.
        newShapeConstants.push_back(dimSize);
        newOperands.push_back(viewOp.sizes()[dynamicDimPos]);
      }
      dynamicDimPos++;
    }

    // Create new memref type with constant folded dims.
    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    // Nothing new, don't fold.
    if (newMemRefType == memrefType)
      return failure();

    // Create new ViewOp.
    auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType,
                                             viewOp.getOperand(0),
                                             viewOp.byte_shift(), newOperands);
    // Insert a cast so we have the same type as the old memref type.
    rewriter.replaceOpWithNewOp<CastOp>(viewOp, newViewOp, viewOp.getType());
    return success();
  }
};

/// Rewrite a ViewOp whose source is a memref.cast of an AllocOp so that the
/// view is built directly on the allocated memref.
struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
  using OpRewritePattern<ViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    Value memrefOperand = viewOp.getOperand(0);
    CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
    if (!memrefCastOp)
      return failure();
    Value allocOperand = memrefCastOp.getOperand();
    AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
    if (!allocOp)
      return failure();
    rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
                                        viewOp.byte_shift(), viewOp.sizes());
    return success();
  }
};

} // end anonymous namespace

void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"