//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;
using namespace mlir::memref;

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  // The memref dialect has no constant op of its own; delegate to the
  // standard dialect's ConstantOp.
  return builder.create<mlir::ConstantOp>(loc, type, value);
}

/// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
  return llvm::to_vector<4>(
      llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
        return a.cast<IntegerAttr>().getInt();
      }));
}

/// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if
/// it is a Value or into `staticVec` if it is an IntegerAttr.
/// In the case of a Value, a copy of the `sentinel` value is also pushed to
/// `staticVec`. This is useful to extract mixed static and dynamic entries that
/// come from an AttrSizedOperandSegments trait.
static void dispatchIndexOpFoldResult(OpFoldResult ofr,
                                      SmallVectorImpl<Value> &dynamicVec,
                                      SmallVectorImpl<int64_t> &staticVec,
                                      int64_t sentinel) {
  // Dynamic entry: record the SSA value and mark its slot in the static list
  // with the sentinel (typically ShapedType::kDynamic-style marker chosen by
  // the caller).
  if (auto v = ofr.dyn_cast<Value>()) {
    dynamicVec.push_back(v);
    staticVec.push_back(sentinel);
    return;
  }
  // Static entry: the OpFoldResult is expected to hold an IntegerAttr.
  APInt apInt = ofr.dyn_cast<Attribute>().cast<IntegerAttr>().getValue();
  staticVec.push_back(apInt.getSExtValue());
}

/// Batched form of `dispatchIndexOpFoldResult` above; dispatches each entry
/// of `ofrs` in order.
static void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
                                       SmallVectorImpl<Value> &dynamicVec,
                                       SmallVectorImpl<int64_t> &staticVec,
                                       int64_t sentinel) {
  for (auto ofr : ofrs)
    dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec, sentinel);
}

//===----------------------------------------------------------------------===//
// Common canonicalization pattern support logic
//===----------------------------------------------------------------------===//

/// This is a common class used for patterns of the form
/// "someop(memrefcast) -> someop". It folds the source of any memref.cast
/// into the root operation directly.
static LogicalResult foldMemRefCast(Operation *op, Value inner = nullptr) {
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto cast = operand.get().getDefiningOp<CastOp>();
    // Skip the operand designated by `inner`, and never fold casts from
    // unranked memrefs (the consumer would see a rank change).
    if (cast && operand.get() != inner &&
        !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
      operand.set(cast.getOperand());
      folded = true;
    }
  }
  // success() only if at least one operand was rewritten.
  return success(folded);
}

//===----------------------------------------------------------------------===//
// Helpers for GlobalOp
//===----------------------------------------------------------------------===//

/// Return the tensor type equivalent of `type`: ranked memref -> ranked
/// tensor, unranked memref -> unranked tensor, anything else -> NoneType.
static Type getTensorTypeFromMemRefType(Type type) {
  if (auto memref = type.dyn_cast<MemRefType>())
    return RankedTensorType::get(memref.getShape(), memref.getElementType());
  if (auto memref = type.dyn_cast<UnrankedMemRefType>())
    return UnrankedTensorType::get(memref.getElementType());
  return NoneType::get(type.getContext());
}

//===----------------------------------------------------------------------===//
// AllocOp / AllocaOp
//===----------------------------------------------------------------------===//

/// Shared verifier for memref.alloc and memref.alloca: checks result type,
/// dynamic-size operand count, and symbol operand count.
template <typename AllocLikeOp>
static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
  static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
                "applies to only alloc or alloca");
  auto memRefType = op.getResult().getType().template dyn_cast<MemRefType>();
  if (!memRefType)
    return op.emitOpError("result must be a memref");

  // One size operand is required per dynamic dimension of the result type.
  if (static_cast<int64_t>(op.dynamicSizes().size()) !=
      memRefType.getNumDynamicDims())
    return op.emitOpError("dimension operand count does not equal memref "
                          "dynamic dimension count");

  // Symbol operands feed the symbols of the first affine map of the layout
  // (if any).
  unsigned numSymbols = 0;
  if (!memRefType.getAffineMaps().empty())
    numSymbols = memRefType.getAffineMaps().front().getNumSymbols();
  if (op.symbolOperands().size() != numSymbols)
    return op.emitOpError(
        "symbol operand count does not equal memref symbol count");

  return success();
}

static LogicalResult verify(AllocOp op) { return verifyAllocLikeOp(op); }

static LogicalResult verify(AllocaOp op) {
  // An alloca op needs to have an ancestor with an allocation scope trait.
  if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
    return op.emitOpError(
        "requires an ancestor op with AutomaticAllocationScope trait");

  return verifyAllocLikeOp(op);
}

namespace {
/// Fold constant dimensions into an alloc like operation.
template <typename AllocLikeOp>
struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
  using OpRewritePattern<AllocLikeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocLikeOp alloc,
                                PatternRewriter &rewriter) const override {
    // Check to see if any dimensions operands are constants.  If so, we can
    // substitute and drop them.
    if (llvm::none_of(alloc.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto memrefType = alloc.getType();

    // Ok, we have one or more constant operands.  Collect the non-constant ones
    // and keep track of the resultant memref type to build.
    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());
    SmallVector<Value, 4> newOperands;

    // Walk the shape; `dynamicDimPos` tracks which size operand corresponds
    // to the current dynamic dimension.
    unsigned dynamicDimPos = 0;
    for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already static dimension, keep it.
      if (dimSize != -1) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto *defOp = alloc.getOperand(dynamicDimPos).getDefiningOp();
      if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
        // Dynamic shape dimension will be folded.
        newShapeConstants.push_back(constantIndexOp.getValue());
      } else {
        // Dynamic shape dimension not folded; copy operand from old memref.
        newShapeConstants.push_back(-1);
        newOperands.push_back(alloc.getOperand(dynamicDimPos));
      }
      dynamicDimPos++;
    }

    // Create new memref type (which will have fewer dynamic dimensions).
    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    assert(static_cast<int64_t>(newOperands.size()) ==
           newMemRefType.getNumDynamicDims());

    // Create and insert the alloc op for the new memref.
    auto newAlloc = rewriter.create<AllocLikeOp>(
        alloc.getLoc(), newMemRefType, newOperands, alloc.alignmentAttr());
    // Insert a cast so we have the same type as the old alloc.
    auto resultCast =
        rewriter.create<CastOp>(alloc.getLoc(), newAlloc, alloc.getType());

    rewriter.replaceOp(alloc, {resultCast});
    return success();
  }
};

/// Fold alloc operations with no users or only store and dealloc uses.
template <typename T>
struct SimplifyDeadAlloc : public OpRewritePattern<T> {
  using OpRewritePattern<T>::OpRewritePattern;

  LogicalResult matchAndRewrite(T alloc,
                                PatternRewriter &rewriter) const override {
    // Bail out if any user is something other than a store or dealloc; the
    // allocation is then observable.
    if (llvm::any_of(alloc->getUsers(), [](Operation *op) {
          return !isa<StoreOp, DeallocOp>(op);
        }))
      return failure();

    // Erase the (dead) users first, then the allocation itself.
    for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
      rewriter.eraseOp(user);

    rewriter.eraseOp(alloc);
    return success();
  }
};
} // end anonymous namespace.
void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
}

void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
      context);
}

//===----------------------------------------------------------------------===//
// AllocaScopeOp
//===----------------------------------------------------------------------===//

/// Custom printer: prints the optional result type list, the body region
/// (terminator elided when there are no results), and the attribute dict.
static void print(OpAsmPrinter &p, AllocaScopeOp &op) {
  bool printBlockTerminators = false;

  p << AllocaScopeOp::getOperationName() << " ";
  if (!op.results().empty()) {
    p << " -> (" << op.getResultTypes() << ")";
    // With results, the terminator carries the yielded values and must be
    // printed.
    printBlockTerminators = true;
  }
  p.printRegion(op.bodyRegion(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/printBlockTerminators);
  p.printOptionalAttrDict(op->getAttrs());
}

/// Custom parser matching the printer above.
static ParseResult parseAllocaScopeOp(OpAsmParser &parser,
                                      OperationState &result) {
  // Create a region for the body.
  result.regions.reserve(1);
  Region *bodyRegion = result.addRegion();

  // Parse optional results type list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Parse the body region.
  if (parser.parseRegion(*bodyRegion, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();
  // Insert an implicit terminator if the parsed region elided it.
  AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
                                  result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}

static LogicalResult verify(AllocaScopeOp op) {
  // Region/result type agreement is checked by the interface helper.
  if (failed(RegionBranchOpInterface::verifyTypes(op)))
    return failure();

  return success();
}

void AllocaScopeOp::getSuccessorRegions(
    Optional<unsigned> index, ArrayRef<Attribute> operands,
    SmallVectorImpl<RegionSuccessor> &regions) {
  // From inside the body region (index set), control returns to the parent's
  // results; from the parent (no index), control enters the body region.
  if (index.hasValue()) {
    regions.push_back(RegionSuccessor(getResults()));
    return;
  }

  regions.push_back(RegionSuccessor(&bodyRegion()));
}

//===----------------------------------------------------------------------===//
// AssumeAlignmentOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(AssumeAlignmentOp op) {
  unsigned alignment = op.alignment();
  if (!llvm::isPowerOf2_32(alignment))
    return op.emitOpError("alignment must be power of 2");
  return success();
}

//===----------------------------------------------------------------------===//
// BufferCastOp
//===----------------------------------------------------------------------===//

OpFoldResult BufferCastOp::fold(ArrayRef<Attribute>) {
  // buffer_cast(tensor_load(m)) -> m, but only when the types line up
  // exactly; the mismatching case is handled by a canonicalization pattern.
  if (auto tensorLoad = tensor().getDefiningOp<TensorLoadOp>())
    if (tensorLoad.memref().getType() == getType())
      return tensorLoad.memref();
  return {};
}

namespace {
/// Replace tensor_cast + buffer_cast by buffer_cast + memref_cast.
315 struct BufferCast : public OpRewritePattern<BufferCastOp> { 316 using OpRewritePattern<BufferCastOp>::OpRewritePattern; 317 318 LogicalResult matchAndRewrite(BufferCastOp bufferCast, 319 PatternRewriter &rewriter) const final { 320 auto tensorCastOperand = 321 bufferCast.getOperand().getDefiningOp<tensor::CastOp>(); 322 if (!tensorCastOperand) 323 return failure(); 324 auto srcTensorType = 325 tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>(); 326 if (!srcTensorType) 327 return failure(); 328 auto memrefType = MemRefType::get(srcTensorType.getShape(), 329 srcTensorType.getElementType()); 330 Value memref = rewriter.create<BufferCastOp>( 331 bufferCast.getLoc(), memrefType, tensorCastOperand.getOperand()); 332 rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(), 333 memref); 334 return success(); 335 } 336 }; 337 338 /// Canonicalize memref.tensor_load + memref.buffer_cast to memref.cast when 339 /// type mismatches prevent `BufferCastOp::fold` to kick in. 340 struct TensorLoadToMemRef : public OpRewritePattern<BufferCastOp> { 341 using OpRewritePattern<BufferCastOp>::OpRewritePattern; 342 343 LogicalResult matchAndRewrite(BufferCastOp bufferCast, 344 PatternRewriter &rewriter) const final { 345 auto tensorLoad = bufferCast.tensor().getDefiningOp<TensorLoadOp>(); 346 // Bail unless we have a tensor_load + memref.buffer_cast with different 347 // types. `BufferCastOp::fold` handles the same type case. 348 if (!tensorLoad || tensorLoad.memref().getType() == bufferCast.getType()) 349 return failure(); 350 // If types are not cast-compatible, bail. 
351 if (!CastOp::areCastCompatible(tensorLoad.memref().getType(), 352 bufferCast.getType())) 353 return failure(); 354 rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(), 355 tensorLoad.memref()); 356 return success(); 357 } 358 }; 359 360 } // namespace 361 362 void BufferCastOp::getCanonicalizationPatterns(RewritePatternSet &results, 363 MLIRContext *context) { 364 results.add<BufferCast, TensorLoadToMemRef>(context); 365 } 366 367 //===----------------------------------------------------------------------===// 368 // CastOp 369 //===----------------------------------------------------------------------===// 370 371 /// Determines whether MemRef_CastOp casts to a more dynamic version of the 372 /// source memref. This is useful to to fold a memref.cast into a consuming op 373 /// and implement canonicalization patterns for ops in different dialects that 374 /// may consume the results of memref.cast operations. Such foldable memref.cast 375 /// operations are typically inserted as `view` and `subview` ops are 376 /// canonicalized, to preserve the type compatibility of their uses. 377 /// 378 /// Returns true when all conditions are met: 379 /// 1. source and result are ranked memrefs with strided semantics and same 380 /// element type and rank. 381 /// 2. each of the source's size, offset or stride has more static information 382 /// than the corresponding result's size, offset or stride. 383 /// 384 /// Example 1: 385 /// ```mlir 386 /// %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32> 387 /// %2 = consumer %1 ... : memref<?x?xf32> ... 388 /// ``` 389 /// 390 /// may fold into: 391 /// 392 /// ```mlir 393 /// %2 = consumer %0 ... : memref<8x16xf32> ... 394 /// ``` 395 /// 396 /// Example 2: 397 /// ``` 398 /// %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>> 399 /// to memref<?x?xf32> 400 /// consumer %1 : memref<?x?xf32> ... 
401 /// ``` 402 /// 403 /// may fold into: 404 /// 405 /// ``` 406 /// consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>> 407 /// ``` 408 bool CastOp::canFoldIntoConsumerOp(CastOp castOp) { 409 MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>(); 410 MemRefType resultType = castOp.getType().dyn_cast<MemRefType>(); 411 412 // Requires ranked MemRefType. 413 if (!sourceType || !resultType) 414 return false; 415 416 // Requires same elemental type. 417 if (sourceType.getElementType() != resultType.getElementType()) 418 return false; 419 420 // Requires same rank. 421 if (sourceType.getRank() != resultType.getRank()) 422 return false; 423 424 // Only fold casts between strided memref forms. 425 int64_t sourceOffset, resultOffset; 426 SmallVector<int64_t, 4> sourceStrides, resultStrides; 427 if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) || 428 failed(getStridesAndOffset(resultType, resultStrides, resultOffset))) 429 return false; 430 431 // If cast is towards more static sizes along any dimension, don't fold. 432 for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) { 433 auto ss = std::get<0>(it), st = std::get<1>(it); 434 if (ss != st) 435 if (MemRefType::isDynamic(ss) && !MemRefType::isDynamic(st)) 436 return false; 437 } 438 439 // If cast is towards more static offset along any dimension, don't fold. 440 if (sourceOffset != resultOffset) 441 if (MemRefType::isDynamicStrideOrOffset(sourceOffset) && 442 !MemRefType::isDynamicStrideOrOffset(resultOffset)) 443 return false; 444 445 // If cast is towards more static strides along any dimension, don't fold. 
446 for (auto it : llvm::zip(sourceStrides, resultStrides)) { 447 auto ss = std::get<0>(it), st = std::get<1>(it); 448 if (ss != st) 449 if (MemRefType::isDynamicStrideOrOffset(ss) && 450 !MemRefType::isDynamicStrideOrOffset(st)) 451 return false; 452 } 453 454 return true; 455 } 456 457 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 458 if (inputs.size() != 1 || outputs.size() != 1) 459 return false; 460 Type a = inputs.front(), b = outputs.front(); 461 auto aT = a.dyn_cast<MemRefType>(); 462 auto bT = b.dyn_cast<MemRefType>(); 463 464 auto uaT = a.dyn_cast<UnrankedMemRefType>(); 465 auto ubT = b.dyn_cast<UnrankedMemRefType>(); 466 467 if (aT && bT) { 468 if (aT.getElementType() != bT.getElementType()) 469 return false; 470 if (aT.getAffineMaps() != bT.getAffineMaps()) { 471 int64_t aOffset, bOffset; 472 SmallVector<int64_t, 4> aStrides, bStrides; 473 if (failed(getStridesAndOffset(aT, aStrides, aOffset)) || 474 failed(getStridesAndOffset(bT, bStrides, bOffset)) || 475 aStrides.size() != bStrides.size()) 476 return false; 477 478 // Strides along a dimension/offset are compatible if the value in the 479 // source memref is static and the value in the target memref is the 480 // same. They are also compatible if either one is dynamic (see 481 // description of MemRefCastOp for details). 482 auto checkCompatible = [](int64_t a, int64_t b) { 483 return (a == MemRefType::getDynamicStrideOrOffset() || 484 b == MemRefType::getDynamicStrideOrOffset() || a == b); 485 }; 486 if (!checkCompatible(aOffset, bOffset)) 487 return false; 488 for (auto aStride : enumerate(aStrides)) 489 if (!checkCompatible(aStride.value(), bStrides[aStride.index()])) 490 return false; 491 } 492 if (aT.getMemorySpace() != bT.getMemorySpace()) 493 return false; 494 495 // They must have the same rank, and any specified dimensions must match. 
496 if (aT.getRank() != bT.getRank()) 497 return false; 498 499 for (unsigned i = 0, e = aT.getRank(); i != e; ++i) { 500 int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i); 501 if (aDim != -1 && bDim != -1 && aDim != bDim) 502 return false; 503 } 504 return true; 505 } else { 506 if (!aT && !uaT) 507 return false; 508 if (!bT && !ubT) 509 return false; 510 // Unranked to unranked casting is unsupported 511 if (uaT && ubT) 512 return false; 513 514 auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType(); 515 auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType(); 516 if (aEltType != bEltType) 517 return false; 518 519 auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace(); 520 auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace(); 521 if (aMemSpace != bMemSpace) 522 return false; 523 524 return true; 525 } 526 527 return false; 528 } 529 530 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) { 531 return succeeded(foldMemRefCast(*this)) ? getResult() : Value(); 532 } 533 534 //===----------------------------------------------------------------------===// 535 // CloneOp 536 //===----------------------------------------------------------------------===// 537 538 void CloneOp::getEffects( 539 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> 540 &effects) { 541 effects.emplace_back(MemoryEffects::Read::get(), input(), 542 SideEffects::DefaultResource::get()); 543 effects.emplace_back(MemoryEffects::Write::get(), output(), 544 SideEffects::DefaultResource::get()); 545 } 546 547 namespace { 548 /// Fold Dealloc operations that are deallocating an AllocOp that is only used 549 /// by other Dealloc operations. 
struct SimplifyClones : public OpRewritePattern<CloneOp> {
  using OpRewritePattern<CloneOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CloneOp cloneOp,
                                PatternRewriter &rewriter) const override {
    // An unused clone is trivially dead.
    if (cloneOp.use_empty()) {
      rewriter.eraseOp(cloneOp);
      return success();
    }

    Value source = cloneOp.input();

    // This only finds dealloc operations for the immediate value. It should
    // also consider aliases. That would also make the safety check below
    // redundant.
    Operation *cloneDeallocOp = findDealloc(cloneOp.output());
    Operation *sourceDeallocOp = findDealloc(source);

    // If both are deallocated in the same block, their in-block lifetimes
    // might not fully overlap, so we cannot decide which one to drop.
    if (cloneDeallocOp && sourceDeallocOp &&
        cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
      return failure();

    // Pick whichever dealloc lives in the clone's own block as the one to
    // erase once the clone is replaced by a cast.
    Block *currentBlock = cloneOp->getBlock();
    Operation *redundantDealloc = nullptr;
    if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = cloneDeallocOp;
    } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = sourceDeallocOp;
    }

    if (!redundantDealloc)
      return failure();

    // Safety check that there are no other deallocations inbetween
    // cloneOp and redundantDealloc, as otherwise we might deallocate an alias
    // of source before the uses of the clone. With alias information, we could
    // restrict this to only fail of the dealloc's operand is an alias
    // of the source.
    for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
         pos = pos->getNextNode()) {
      auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
      if (!effectInterface)
        continue;
      if (effectInterface.hasEffect<MemoryEffects::Free>())
        return failure();
    }

    // Replace the clone by a cast of its source and drop the now-redundant
    // dealloc.
    rewriter.replaceOpWithNewOp<memref::CastOp>(cloneOp, cloneOp.getType(),
                                                source);
    rewriter.eraseOp(redundantDealloc);
    return success();
  }
};

} // end anonymous namespace.

void CloneOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                          MLIRContext *context) {
  results.insert<SimplifyClones>(context);
}

OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) {
  return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
}

//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//

LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands,
                              SmallVectorImpl<OpFoldResult> &results) {
  /// dealloc(memrefcast) -> dealloc
  return foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// DimOp
//===----------------------------------------------------------------------===//

/// Build a dim op with a constant index; materializes the index as a
/// ConstantIndexOp.
void DimOp::build(OpBuilder &builder, OperationState &result, Value memref,
                  int64_t index) {
  auto loc = result.location;
  Value indexValue = builder.create<ConstantIndexOp>(loc, index);
  build(builder, result, memref, indexValue);
}

/// Build a dim op with an SSA index; the result type is always `index`.
void DimOp::build(OpBuilder &builder, OperationState &result, Value memref,
                  Value index) {
  auto indexTy = builder.getIndexType();
  build(builder, result, indexTy, memref, index);
}

/// Return the index if it is defined by a constant, None otherwise.
Optional<int64_t> DimOp::getConstantIndex() {
  if (auto constantOp = index().getDefiningOp<ConstantOp>())
    return constantOp.getValue().cast<IntegerAttr>().getInt();
  return {};
}

static LogicalResult verify(DimOp op) {
  // Assume unknown index to be in range.
  Optional<int64_t> index = op.getConstantIndex();
  if (!index.hasValue())
    return success();

  // Check that constant index is not knowingly out of range.
  auto type = op.memrefOrTensor().getType();
  if (auto memrefType = type.dyn_cast<MemRefType>()) {
    if (index.getValue() >= memrefType.getRank())
      return op.emitOpError("index is out of range");
  } else if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
    if (index.getValue() >= tensorType.getRank())
      return op.emitOpError("index is out of range");
  } else if (type.isa<UnrankedMemRefType>() || type.isa<UnrankedTensorType>()) {
    // Assume index to be in range.
  } else {
    llvm_unreachable("expected operand with memref type");
  }
  return success();
}

OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
  auto index = operands[1].dyn_cast_or_null<IntegerAttr>();

  // All forms of folding require a known index.
  if (!index)
    return {};

  auto argTy = memrefOrTensor().getType();
  // Fold if the shape extent along the given index is known.
  if (auto shapedTy = argTy.dyn_cast<ShapedType>()) {
    // Folding for unranked types (UnrankedMemRefType) is not supported.
    if (!shapedTy.hasRank())
      return {};
    if (!shapedTy.isDynamicDim(index.getInt())) {
      Builder builder(getContext());
      return builder.getIndexAttr(shapedTy.getShape()[index.getInt()]);
    }
  }

  Operation *definingOp = memrefOrTensor().getDefiningOp();

  // dim(memref.tensor_load(memref)) -> dim(memref)
  if (auto tensorLoadOp = dyn_cast_or_null<TensorLoadOp>(definingOp)) {
    // In-place update: point the dim at the underlying memref instead.
    setOperand(0, tensorLoadOp.memref());
    return getResult();
  }

  // Fold dim to the operand of tensor.generate.
  if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
    auto resultType =
        fromElements.getResult().getType().cast<RankedTensorType>();
    // The case where the type encodes the size of the dimension is handled
    // above.
    assert(resultType.getShape()[index.getInt()] ==
           RankedTensorType::kDynamicSize);

    // Find the operand of the fromElements that corresponds to this index.
    auto dynExtents = fromElements.dynamicExtents().begin();
    for (auto dim : resultType.getShape().take_front(index.getInt()))
      if (dim == RankedTensorType::kDynamicSize)
        dynExtents++;

    return Value{*dynExtents};
  }

  // The size at the given index is now known to be a dynamic size.
  unsigned unsignedIndex = index.getValue().getZExtValue();

  if (auto subtensor = dyn_cast_or_null<mlir::SubTensorOp>(definingOp)) {
    assert(subtensor.isDynamicSize(unsignedIndex) &&
           "Expected dynamic subtensor size");
    return subtensor.getDynamicSize(unsignedIndex);
  }

  // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
  auto memrefType = argTy.dyn_cast<MemRefType>();
  if (!memrefType)
    return {};

  if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
    return *(alloc.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
    return *(alloca.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
    return *(view.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto sizeInterface =
          dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
    assert(sizeInterface.isDynamicSize(unsignedIndex) &&
           "Expected dynamic subview size");
    return sizeInterface.getDynamicSize(unsignedIndex);
  }

  // dim(memrefcast) -> dim
  if (succeeded(foldMemRefCast(*this)))
    return getResult();

  return {};
}

namespace {
/// Fold dim of a memref reshape operation to a load into the reshape's shape
/// operand.
struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dim,
                                PatternRewriter &rewriter) const override {
    auto reshape = dim.memrefOrTensor().getDefiningOp<ReshapeOp>();

    if (!reshape)
      return failure();

    // Place the load directly after the reshape to ensure that the shape memref
    // was not mutated.
    rewriter.setInsertionPointAfter(reshape);
    Location loc = dim.getLoc();
    Value load = rewriter.create<LoadOp>(loc, reshape.shape(), dim.index());
    // The shape memref's element type may differ from `index`; bridge the
    // gap with an index_cast.
    if (load.getType() != dim.getType())
      load = rewriter.create<IndexCastOp>(loc, dim.getType(), load);
    rewriter.replaceOp(dim, load);
    return success();
  }
};

/// Fold dim of a cast into the dim of the source of the cast.
template <typename CastOpTy>
struct DimOfCastOp : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.memrefOrTensor().getDefiningOp<CastOpTy>();
    if (!castOp)
      return failure();
    // dim(cast(x)) -> dim(x): the queried dimension is invariant under the
    // cast.
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.index());
    return success();
  }
};

/// Helper method to get the `Value` that is the shape of the `resultIdx`-th
/// result at dimension `dimIndex` from the `ShapedTypeOpInterface`.
/// TODO(ravishankarm): This is better put as a interface utility method
/// somewhere, but that would imply the interface will depend on the `tensor`
/// dialect. Ideally maybe a utility method in the `tensor` dialect.
static Value getResultDimFromShapeInterface(OpBuilder &builder, OpResult result,
                                            int64_t dimIndex) {
  unsigned resultNumber = result.getResultNumber();
  auto shapedTypeOp = dyn_cast<InferShapedTypeOpInterface>(result.getOwner());
  Location loc = result.getOwner()->getLoc();
  if (!shapedTypeOp)
    return nullptr;

  // The interface exposes two methods, one that returns the shape of all the
  // results as `Value` and other that returns the shape as a list of
  // `SmallVector<Value>`. The former takes precedence over the latter. So first
  // check if the op implements the first interface method or the second, and
  // get the value to use appropriately.
  SmallVector<Value> reifiedResultShapes;
  if (succeeded(shapedTypeOp.reifyReturnTypeShapes(
          builder, result.getOwner()->getOperands(), reifiedResultShapes))) {
    if (reifiedResultShapes.size() <= resultNumber)
      return nullptr;
    Value resultShape = reifiedResultShapes[resultNumber];
    // The reified shape must be a ranked tensor of `index` values.
    auto resultShapeType = resultShape.getType().dyn_cast<RankedTensorType>();
    if (!resultShapeType || !resultShapeType.getElementType().isa<IndexType>())
      return nullptr;
    // Extract the requested dimension from the shape tensor.
    return builder.create<tensor::ExtractOp>(
        loc, resultShape, builder.createOrFold<ConstantIndexOp>(loc, dimIndex));
  }

  // Fall back to the per-result-dimension reification method.
  SmallVector<SmallVector<Value>> reifiedResultShapesPerDim;
  if (failed(shapedTypeOp.reifyReturnTypeShapesPerResultDim(
          builder, reifiedResultShapesPerDim)))
    return nullptr;
  if (reifiedResultShapesPerDim.size() <= resultNumber ||
      reifiedResultShapesPerDim[resultNumber].size() !=
          static_cast<size_t>(result.getType().cast<ShapedType>().getRank()))
    return nullptr;
  OpFoldResult valueOrAttr = reifiedResultShapesPerDim[resultNumber][dimIndex];
  // Materialize a constant when the dimension was reified as an attribute.
  if (auto attr = valueOrAttr.dyn_cast<Attribute>())
    return builder.createOrFold<ConstantIndexOp>(
        loc, attr.cast<IntegerAttr>().getInt());
  return valueOrAttr.get<Value>();
}

/// Fold dim of an operation that implements the InferShapedTypeOpInterface
struct DimOfShapedTypeOpInterface : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    OpResult dimValue = dimOp.memrefOrTensor().dyn_cast<OpResult>();
    if (!dimValue)
      return failure();
    auto shapedTypeOp =
        dyn_cast<InferShapedTypeOpInterface>(dimValue.getOwner());
    if (!shapedTypeOp)
      return failure();

    // Only constant dim indices can be resolved against the reified shape.
    Optional<int64_t> dimIndex = dimOp.getConstantIndex();
    if (!dimIndex)
      return failure();
    Value replacement =
        getResultDimFromShapeInterface(rewriter, dimValue, *dimIndex);
    if (!replacement)
      return failure();
    rewriter.replaceOp(dimOp, replacement);
    return success();
  }
};
} // end anonymous namespace.

void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<DimOfMemRefReshape, DimOfCastOp<BufferCastOp>,
              DimOfCastOp<tensor::CastOp>, DimOfShapedTypeOpInterface>(context);
}

// ---------------------------------------------------------------------------
// DmaStartOp
// ---------------------------------------------------------------------------

void DmaStartOp::build(OpBuilder &builder, OperationState &result,
                       Value srcMemRef, ValueRange srcIndices, Value destMemRef,
                       ValueRange destIndices, Value numElements,
                       Value tagMemRef, ValueRange tagIndices, Value stride,
                       Value elementsPerStride) {
  result.addOperands(srcMemRef);
  result.addOperands(srcIndices);
  result.addOperands(destMemRef);
  result.addOperands(destIndices);
  result.addOperands({numElements, tagMemRef});
  result.addOperands(tagIndices);
  // Stride and elements-per-stride are an optional trailing pair.
  if (stride)
    result.addOperands({stride, elementsPerStride});
}

void DmaStartOp::print(OpAsmPrinter &p) {
  p << getOperationName() << " " << getSrcMemRef() << '[' << getSrcIndices()
    << "], " << getDstMemRef() << '[' << getDstIndices() << "], "
    << getNumElements() << ", " << getTagMemRef() << '[' << getTagIndices()
    << ']';
  if (isStrided())
    p << ", " << getStride() << ", " << getNumElementsPerStride();

  p.printOptionalAttrDict((*this)->getAttrs());
  p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
    << ", " << getTagMemRef().getType();
}

// Parse DmaStartOp.
// Ex:
//   %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
//                       %tag[%index], %stride, %num_elt_per_stride :
//                     : memref<3076 x f32, 0>,
//                       memref<1024 x f32, 2>,
//                       memref<1 x i32>
//
ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::OperandType srcMemRefInfo;
  SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos;
  OpAsmParser::OperandType dstMemRefInfo;
  SmallVector<OpAsmParser::OperandType, 4> dstIndexInfos;
  OpAsmParser::OperandType numElementsInfo;
  OpAsmParser::OperandType tagMemrefInfo;
  SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos;
  SmallVector<OpAsmParser::OperandType, 2> strideInfo;

  SmallVector<Type, 3> types;
  auto indexType = parser.getBuilder().getIndexType();

  // Parse and resolve the following list of operands:
  // *) source memref followed by its indices (in square brackets).
  // *) destination memref followed by its indices (in square brackets).
  // *) dma size in KiB.
  if (parser.parseOperand(srcMemRefInfo) ||
      parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
      parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(numElementsInfo) ||
      parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
      parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
    return failure();

  // Parse optional stride and elements per stride.
  if (parser.parseTrailingOperandList(strideInfo))
    return failure();

  // Strides, when present, are exactly two operands: the stride and the
  // number of elements per stride.
  bool isStrided = strideInfo.size() == 2;
  if (!strideInfo.empty() && !isStrided) {
    return parser.emitError(parser.getNameLoc(),
                            "expected two stride related operands");
  }

  // Exactly three types are expected: source, destination and tag memref
  // types (resolved as types[0], types[1], types[2] below).
  if (parser.parseColonTypeList(types))
    return failure();
  if (types.size() != 3)
    return parser.emitError(parser.getNameLoc(), "fewer/more types expected");

  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
      parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
      parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
      parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
      // size should be an index.
      parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
      parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
      // tag indices should be index.
      parser.resolveOperands(tagIndexInfos, indexType, result.operands))
    return failure();

  if (isStrided) {
    if (parser.resolveOperands(strideInfo, indexType, result.operands))
      return failure();
  }

  return success();
}

/// Verifies operand counts and types of a dma_start op. The operand lists are
/// variadic, so each section first checks that enough operands are present
/// before using accessors whose position depends on the earlier types.
LogicalResult DmaStartOp::verify() {
  unsigned numOperands = getNumOperands();

  // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
  // the number of elements.
  if (numOperands < 4)
    return emitOpError("expected at least 4 operands");

  // Check types of operands. The order of these calls is important: the later
  // calls rely on some type properties to compute the operand position.
  // 1. Source memref.
  if (!getSrcMemRef().getType().isa<MemRefType>())
    return emitOpError("expected source to be of memref type");
  if (numOperands < getSrcMemRefRank() + 4)
    return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
                         << " operands";
  if (!getSrcIndices().empty() &&
      !llvm::all_of(getSrcIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected source indices to be of index type");

  // 2. Destination memref.
  if (!getDstMemRef().getType().isa<MemRefType>())
    return emitOpError("expected destination to be of memref type");
  unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
                         << " operands";
  if (!getDstIndices().empty() &&
      !llvm::all_of(getDstIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected destination indices to be of index type");

  // 3. Number of elements.
  if (!getNumElements().getType().isIndex())
    return emitOpError("expected num elements to be of index type");

  // 4. Tag memref.
  if (!getTagMemRef().getType().isa<MemRefType>())
    return emitOpError("expected tag to be of memref type");
  numExpectedOperands += getTagMemRefRank();
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
                         << " operands";
  if (!getTagIndices().empty() &&
      !llvm::all_of(getTagIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected tag indices to be of index type");

  // Optional stride-related operands must be either both present or both
  // absent.
  if (numOperands != numExpectedOperands &&
      numOperands != numExpectedOperands + 2)
    return emitOpError("incorrect number of operands");

  // 5. Strides.
  if (isStrided()) {
    if (!getStride().getType().isIndex() ||
        !getNumElementsPerStride().getType().isIndex())
      return emitOpError(
          "expected stride and num elements per stride to be of type index");
  }

  return success();
}

LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands,
                               SmallVectorImpl<OpFoldResult> &results) {
  /// dma_start(memrefcast) -> dma_start
  return foldMemRefCast(*this);
}

// ---------------------------------------------------------------------------
// DmaWaitOp
// ---------------------------------------------------------------------------

/// Builds a dma_wait op. Operand order: tag memref, tag indices, number of
/// elements.
void DmaWaitOp::build(OpBuilder &builder, OperationState &result,
                      Value tagMemRef, ValueRange tagIndices,
                      Value numElements) {
  result.addOperands(tagMemRef);
  result.addOperands(tagIndices);
  result.addOperands(numElements);
}

/// Prints a dma_wait op: tag[indices], num-elements attr-dict : tag-type
void DmaWaitOp::print(OpAsmPrinter &p) {
  p << getOperationName() << " " << getTagMemRef() << '[' << getTagIndices()
    << "], " << getNumElements();
  p.printOptionalAttrDict((*this)->getAttrs());
  p << " : " << getTagMemRef().getType();
}

// Parse DmaWaitOp.
// Eg:
//   dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 4>
//
ParseResult DmaWaitOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::OperandType tagMemrefInfo;
  SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos;
  Type type;
  auto indexType = parser.getBuilder().getIndexType();
  OpAsmParser::OperandType numElementsInfo;

  // Parse tag memref, its indices, and dma size.
  if (parser.parseOperand(tagMemrefInfo) ||
      parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(numElementsInfo) ||
      parser.parseColonType(type) ||
      parser.resolveOperand(tagMemrefInfo, type, result.operands) ||
      parser.resolveOperands(tagIndexInfos, indexType, result.operands) ||
      parser.resolveOperand(numElementsInfo, indexType, result.operands))
    return failure();

  return success();
}

LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
                              SmallVectorImpl<OpFoldResult> &results) {
  /// dma_wait(memrefcast) -> dma_wait
  return foldMemRefCast(*this);
}

/// Verifies operand counts and types of a dma_wait op: a tag memref, one
/// index per tag dimension, and the number of elements.
LogicalResult DmaWaitOp::verify() {
  // Mandatory non-variadic operands are tag and the number of elements.
  if (getNumOperands() < 2)
    return emitOpError() << "expected at least 2 operands";

  // Check types of operands. The order of these calls is important: the later
  // calls rely on some type properties to compute the operand position.
  if (!getTagMemRef().getType().isa<MemRefType>())
    return emitOpError() << "expected tag to be of memref type";

  if (getNumOperands() != 2 + getTagMemRefRank())
    return emitOpError() << "expected " << 2 + getTagMemRefRank()
                         << " operands";

  if (!getTagIndices().empty() &&
      !llvm::all_of(getTagIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError() << "expected tag indices to be of index type";

  if (!getNumElements().getType().isIndex())
    return emitOpError()
           << "expected the number of elements to be of index type";

  return success();
}

//===----------------------------------------------------------------------===//
// GlobalOp
//===----------------------------------------------------------------------===//

/// Prints the custom `type (= initial-value)?` assembly segment of a
/// memref.global: only the type for external globals, `= uninitialized` for
/// uninitialized definitions, and the initial value attribute (without its
/// type, which is implied by the memref type) otherwise.
static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
                                                   TypeAttr type,
                                                   Attribute initialValue) {
  p << type;
  if (!op.isExternal()) {
    p << " = ";
    if (op.isUninitialized())
      p << "uninitialized";
    else
      p.printAttributeWithoutType(initialValue);
  }
}

/// Parses the custom `type (= initial-value)?` assembly segment of a
/// memref.global. The type must be a statically shaped memref. The optional
/// initializer is either the keyword `uninitialized` (stored as a UnitAttr)
/// or an elements attribute parsed against the equivalent tensor type.
static ParseResult
parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
                                       Attribute &initialValue) {
  Type type;
  if (parser.parseType(type))
    return failure();

  auto memrefType = type.dyn_cast<MemRefType>();
  if (!memrefType || !memrefType.hasStaticShape())
    return parser.emitError(parser.getNameLoc())
           << "type should be static shaped memref, but got " << type;
  typeAttr = TypeAttr::get(type);

  // No `=` means an external global: no initial value at all.
  if (parser.parseOptionalEqual())
    return success();

  if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
    initialValue = UnitAttr::get(parser.getBuilder().getContext());
    return success();
  }

  Type tensorType = getTensorTypeFromMemRefType(memrefType);
  if (parser.parseAttribute(initialValue, tensorType))
    return failure();
  if (!initialValue.isa<ElementsAttr>())
    return parser.emitError(parser.getNameLoc())
           << "initial value should be a unit or elements attribute";
  return success();
}

static LogicalResult verify(GlobalOp op) {
  auto memrefType = op.type().dyn_cast<MemRefType>();
  if (!memrefType || !memrefType.hasStaticShape())
    return op.emitOpError("type should be static shaped memref, but got ")
           << op.type();

  // Verify that the initial value, if present, is either a unit attribute or
  // an elements attribute.
  if (op.initial_value().hasValue()) {
    Attribute initValue = op.initial_value().getValue();
    if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>())
      return op.emitOpError("initial value should be a unit or elements "
                            "attribute, but got ")
             << initValue;

    // Check that the type of the initial value is compatible with the type of
    // the global variable.
    if (initValue.isa<ElementsAttr>()) {
      Type initType = initValue.getType();
      Type tensorType = getTensorTypeFromMemRefType(memrefType);
      if (initType != tensorType)
        return op.emitOpError("initial value expected to be of type ")
               << tensorType << ", but was of type " << initType;
    }
  }

  // TODO: verify visibility for declarations.
  return success();
}

//===----------------------------------------------------------------------===//
// GetGlobalOp
//===----------------------------------------------------------------------===//

LogicalResult
GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  // Verify that the result type is same as the type of the referenced
  // memref.global op.
  auto global =
      symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, nameAttr());
  if (!global)
    return emitOpError("'")
           << name() << "' does not reference a valid global memref";

  Type resultType = result().getType();
  if (global.type() != resultType)
    return emitOpError("result type ")
           << resultType << " does not match type " << global.type()
           << " of the global memref @" << name();
  return success();
}

//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//

/// Verifies that a load provides exactly one index per memref dimension.
static LogicalResult verify(LoadOp op) {
  if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
    return op.emitOpError("incorrect number of indices for load");
  return success();
}

OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
  /// load(memrefcast) -> load
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  return OpFoldResult();
}

namespace {
/// Fold a load on a buffer_cast operation into an tensor.extract on the
/// corresponding tensor.
struct LoadOfBufferCast : public OpRewritePattern<LoadOp> {
  using OpRewritePattern<LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(LoadOp load,
                                PatternRewriter &rewriter) const override {
    auto buffercast = load.memref().getDefiningOp<BufferCastOp>();
    if (!buffercast)
      return failure();

    // Reuse the load's indices directly on the cast's source tensor.
    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, buffercast.tensor(),
                                                   load.indices());
    return success();
  }
};
} // end anonymous namespace.

void LoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<LoadOfBufferCast>(context);
}

//===----------------------------------------------------------------------===//
// PrefetchOp
//===----------------------------------------------------------------------===//

/// Prints a prefetch op:
///   memref[indices], (read|write), locality<hint>, (data|instr)
///   attr-dict : memref-type
/// The three attributes printed inline are elided from the attr-dict.
static void print(OpAsmPrinter &p, PrefetchOp op) {
  p << PrefetchOp::getOperationName() << " " << op.memref() << '[';
  p.printOperands(op.indices());
  p << ']' << ", " << (op.isWrite() ? "write" : "read");
  p << ", locality<" << op.localityHint();
  p << ">, " << (op.isDataCache() ? "data" : "instr");
  p.printOptionalAttrDict(
      op->getAttrs(),
      /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
  p << " : " << op.getMemRefType();
}

/// Parses a prefetch op in the form printed above. The read/write and
/// data/instr keywords are validated after parsing and stored as the boolean
/// attributes `isWrite` and `isDataCache`; the locality hint is parsed as an
/// i32 attribute.
static ParseResult parsePrefetchOp(OpAsmParser &parser,
                                   OperationState &result) {
  OpAsmParser::OperandType memrefInfo;
  SmallVector<OpAsmParser::OperandType, 4> indexInfo;
  IntegerAttr localityHint;
  MemRefType type;
  StringRef readOrWrite, cacheType;

  auto indexTy = parser.getBuilder().getIndexType();
  auto i32Type = parser.getBuilder().getIntegerType(32);
  if (parser.parseOperand(memrefInfo) ||
      parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
      parser.parseComma() || parser.parseKeyword("locality") ||
      parser.parseLess() ||
      parser.parseAttribute(localityHint, i32Type, "localityHint",
                            result.attributes) ||
      parser.parseGreater() || parser.parseComma() ||
      parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
      parser.resolveOperand(memrefInfo, type, result.operands) ||
      parser.resolveOperands(indexInfo, indexTy, result.operands))
    return failure();

  if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
    return parser.emitError(parser.getNameLoc(),
                            "rw specifier has to be 'read' or 'write'");
  result.addAttribute(
      PrefetchOp::getIsWriteAttrName(),
      parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));

  if (!cacheType.equals("data") && !cacheType.equals("instr"))
    return parser.emitError(parser.getNameLoc(),
                            "cache type has to be 'data' or 'instr'");

  result.addAttribute(
      PrefetchOp::getIsDataCacheAttrName(),
      parser.getBuilder().getBoolAttr(cacheType.equals("data")));

  return success();
}

/// Verifies that a prefetch provides exactly one index per memref dimension.
static LogicalResult verify(PrefetchOp op) {
  if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
    return op.emitOpError("too few indices");

  return success();
}

LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands,
                               SmallVectorImpl<OpFoldResult> &results) {
  // prefetch(memrefcast) -> prefetch
  return foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// ReinterpretCastOp
//===----------------------------------------------------------------------===//

/// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`,
/// `staticSizes` and `staticStrides` are automatically filled with
/// source-memref-rank sentinel values that encode dynamic entries.
1341 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, 1342 MemRefType resultType, Value source, 1343 OpFoldResult offset, ArrayRef<OpFoldResult> sizes, 1344 ArrayRef<OpFoldResult> strides, 1345 ArrayRef<NamedAttribute> attrs) { 1346 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 1347 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 1348 dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets, 1349 ShapedType::kDynamicStrideOrOffset); 1350 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, 1351 ShapedType::kDynamicSize); 1352 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, 1353 ShapedType::kDynamicStrideOrOffset); 1354 build(b, result, resultType, source, dynamicOffsets, dynamicSizes, 1355 dynamicStrides, b.getI64ArrayAttr(staticOffsets), 1356 b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides)); 1357 result.addAttributes(attrs); 1358 } 1359 1360 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, 1361 MemRefType resultType, Value source, 1362 int64_t offset, ArrayRef<int64_t> sizes, 1363 ArrayRef<int64_t> strides, 1364 ArrayRef<NamedAttribute> attrs) { 1365 SmallVector<OpFoldResult> sizeValues = 1366 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult { 1367 return b.getI64IntegerAttr(v); 1368 })); 1369 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1370 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult { 1371 return b.getI64IntegerAttr(v); 1372 })); 1373 build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues, 1374 strideValues, attrs); 1375 } 1376 1377 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, 1378 MemRefType resultType, Value source, Value offset, 1379 ValueRange sizes, ValueRange strides, 1380 ArrayRef<NamedAttribute> attrs) { 1381 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>( 1382 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; 
})); 1383 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1384 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; })); 1385 build(b, result, resultType, source, offset, sizeValues, strideValues, attrs); 1386 } 1387 1388 // TODO: ponder whether we want to allow missing trailing sizes/strides that are 1389 // completed automatically, like we have for subview and subtensor. 1390 static LogicalResult verify(ReinterpretCastOp op) { 1391 // The source and result memrefs should be in the same memory space. 1392 auto srcType = op.source().getType().cast<BaseMemRefType>(); 1393 auto resultType = op.getType().cast<MemRefType>(); 1394 if (srcType.getMemorySpace() != resultType.getMemorySpace()) 1395 return op.emitError("different memory spaces specified for source type ") 1396 << srcType << " and result memref type " << resultType; 1397 if (srcType.getElementType() != resultType.getElementType()) 1398 return op.emitError("different element types specified for source type ") 1399 << srcType << " and result memref type " << resultType; 1400 1401 // Match sizes in result memref type and in static_sizes attribute. 1402 for (auto &en : 1403 llvm::enumerate(llvm::zip(resultType.getShape(), 1404 extractFromI64ArrayAttr(op.static_sizes())))) { 1405 int64_t resultSize = std::get<0>(en.value()); 1406 int64_t expectedSize = std::get<1>(en.value()); 1407 if (resultSize != expectedSize) 1408 return op.emitError("expected result type with size = ") 1409 << expectedSize << " instead of " << resultSize 1410 << " in dim = " << en.index(); 1411 } 1412 1413 // Match offset and strides in static_offset and static_strides attributes if 1414 // result memref type has an affine map specified. 
1415 if (!resultType.getAffineMaps().empty()) { 1416 int64_t resultOffset; 1417 SmallVector<int64_t, 4> resultStrides; 1418 if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset))) 1419 return failure(); 1420 1421 // Match offset in result memref type and in static_offsets attribute. 1422 int64_t expectedOffset = 1423 extractFromI64ArrayAttr(op.static_offsets()).front(); 1424 if (resultOffset != expectedOffset) 1425 return op.emitError("expected result type with offset = ") 1426 << resultOffset << " instead of " << expectedOffset; 1427 1428 // Match strides in result memref type and in static_strides attribute. 1429 for (auto &en : llvm::enumerate(llvm::zip( 1430 resultStrides, extractFromI64ArrayAttr(op.static_strides())))) { 1431 int64_t resultStride = std::get<0>(en.value()); 1432 int64_t expectedStride = std::get<1>(en.value()); 1433 if (resultStride != expectedStride) 1434 return op.emitError("expected result type with stride = ") 1435 << expectedStride << " instead of " << resultStride 1436 << " in dim = " << en.index(); 1437 } 1438 } 1439 return success(); 1440 } 1441 1442 //===----------------------------------------------------------------------===// 1443 // ReshapeOp 1444 //===----------------------------------------------------------------------===// 1445 1446 static LogicalResult verify(ReshapeOp op) { 1447 Type operandType = op.source().getType(); 1448 Type resultType = op.result().getType(); 1449 1450 Type operandElementType = operandType.cast<ShapedType>().getElementType(); 1451 Type resultElementType = resultType.cast<ShapedType>().getElementType(); 1452 if (operandElementType != resultElementType) 1453 return op.emitOpError("element types of source and destination memref " 1454 "types should be the same"); 1455 1456 if (auto operandMemRefType = operandType.dyn_cast<MemRefType>()) 1457 if (!operandMemRefType.getAffineMaps().empty()) 1458 return op.emitOpError( 1459 "source memref type should have identity affine map"); 1460 1461 
int64_t shapeSize = op.shape().getType().cast<MemRefType>().getDimSize(0); 1462 auto resultMemRefType = resultType.dyn_cast<MemRefType>(); 1463 if (resultMemRefType) { 1464 if (!resultMemRefType.getAffineMaps().empty()) 1465 return op.emitOpError( 1466 "result memref type should have identity affine map"); 1467 if (shapeSize == ShapedType::kDynamicSize) 1468 return op.emitOpError("cannot use shape operand with dynamic length to " 1469 "reshape to statically-ranked memref type"); 1470 if (shapeSize != resultMemRefType.getRank()) 1471 return op.emitOpError( 1472 "length of shape operand differs from the result's memref rank"); 1473 } 1474 return success(); 1475 } 1476 1477 //===----------------------------------------------------------------------===// 1478 // StoreOp 1479 //===----------------------------------------------------------------------===// 1480 1481 static LogicalResult verify(StoreOp op) { 1482 if (op.getNumOperands() != 2 + op.getMemRefType().getRank()) 1483 return op.emitOpError("store index operand count not equal to memref rank"); 1484 1485 return success(); 1486 } 1487 1488 LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands, 1489 SmallVectorImpl<OpFoldResult> &results) { 1490 /// store(memrefcast) -> store 1491 return foldMemRefCast(*this, getValueToStore()); 1492 } 1493 1494 //===----------------------------------------------------------------------===// 1495 // SubViewOp 1496 //===----------------------------------------------------------------------===// 1497 1498 namespace { 1499 /// Helpers to write more idiomatic operations. 
1500 namespace saturated_arith { 1501 struct Wrapper { 1502 explicit Wrapper(int64_t v) : v(v) {} 1503 operator int64_t() { return v; } 1504 int64_t v; 1505 }; 1506 Wrapper operator+(Wrapper a, int64_t b) { 1507 if (ShapedType::isDynamicStrideOrOffset(a) || 1508 ShapedType::isDynamicStrideOrOffset(b)) 1509 return Wrapper(ShapedType::kDynamicStrideOrOffset); 1510 return Wrapper(a.v + b); 1511 } 1512 Wrapper operator*(Wrapper a, int64_t b) { 1513 if (ShapedType::isDynamicStrideOrOffset(a) || 1514 ShapedType::isDynamicStrideOrOffset(b)) 1515 return Wrapper(ShapedType::kDynamicStrideOrOffset); 1516 return Wrapper(a.v * b); 1517 } 1518 } // end namespace saturated_arith 1519 } // end namespace 1520 1521 /// A subview result type can be fully inferred from the source type and the 1522 /// static representation of offsets, sizes and strides. Special sentinels 1523 /// encode the dynamic case. 1524 Type SubViewOp::inferResultType(MemRefType sourceMemRefType, 1525 ArrayRef<int64_t> leadingStaticOffsets, 1526 ArrayRef<int64_t> leadingStaticSizes, 1527 ArrayRef<int64_t> leadingStaticStrides) { 1528 // A subview may specify only a leading subset of offset/sizes/strides in 1529 // which case we complete with offset=0, sizes from memref type and strides=1. 
1530 unsigned rank = sourceMemRefType.getRank(); 1531 assert(leadingStaticOffsets.size() <= rank && 1532 "unexpected leadingStaticOffsets overflow"); 1533 assert(leadingStaticSizes.size() <= rank && 1534 "unexpected leadingStaticSizes overflow"); 1535 assert(leadingStaticStrides.size() <= rank && 1536 "unexpected leadingStaticStrides overflow"); 1537 auto staticOffsets = llvm::to_vector<4>(leadingStaticOffsets); 1538 auto staticSizes = llvm::to_vector<4>(leadingStaticSizes); 1539 auto staticStrides = llvm::to_vector<4>(leadingStaticStrides); 1540 unsigned numTrailingOffsets = rank - staticOffsets.size(); 1541 unsigned numTrailingSizes = rank - staticSizes.size(); 1542 unsigned numTrailingStrides = rank - staticStrides.size(); 1543 staticOffsets.append(numTrailingOffsets, 0); 1544 llvm::append_range(staticSizes, 1545 sourceMemRefType.getShape().take_back(numTrailingSizes)); 1546 staticStrides.append(numTrailingStrides, 1); 1547 1548 // Extract source offset and strides. 1549 int64_t sourceOffset; 1550 SmallVector<int64_t, 4> sourceStrides; 1551 auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset); 1552 assert(succeeded(res) && "SubViewOp expected strided memref type"); 1553 (void)res; 1554 1555 // Compute target offset whose value is: 1556 // `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`. 1557 int64_t targetOffset = sourceOffset; 1558 for (auto it : llvm::zip(staticOffsets, sourceStrides)) { 1559 auto staticOffset = std::get<0>(it), targetStride = std::get<1>(it); 1560 using namespace saturated_arith; 1561 targetOffset = Wrapper(targetOffset) + Wrapper(staticOffset) * targetStride; 1562 } 1563 1564 // Compute target stride whose value is: 1565 // `sourceStrides_i * staticStrides_i`. 
1566 SmallVector<int64_t, 4> targetStrides; 1567 targetStrides.reserve(staticOffsets.size()); 1568 for (auto it : llvm::zip(sourceStrides, staticStrides)) { 1569 auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it); 1570 using namespace saturated_arith; 1571 targetStrides.push_back(Wrapper(sourceStride) * staticStride); 1572 } 1573 1574 // The type is now known. 1575 return MemRefType::get( 1576 staticSizes, sourceMemRefType.getElementType(), 1577 makeStridedLinearLayoutMap(targetStrides, targetOffset, 1578 sourceMemRefType.getContext()), 1579 sourceMemRefType.getMemorySpace()); 1580 } 1581 1582 Type SubViewOp::inferResultType(MemRefType sourceMemRefType, 1583 ArrayRef<OpFoldResult> leadingStaticOffsets, 1584 ArrayRef<OpFoldResult> leadingStaticSizes, 1585 ArrayRef<OpFoldResult> leadingStaticStrides) { 1586 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 1587 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 1588 dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets, 1589 staticOffsets, ShapedType::kDynamicStrideOrOffset); 1590 dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes, 1591 ShapedType::kDynamicSize); 1592 dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides, 1593 staticStrides, ShapedType::kDynamicStrideOrOffset); 1594 return SubViewOp::inferResultType(sourceMemRefType, staticOffsets, 1595 staticSizes, staticStrides) 1596 .cast<MemRefType>(); 1597 } 1598 1599 Type SubViewOp::inferRankReducedResultType( 1600 unsigned resultRank, MemRefType sourceRankedTensorType, 1601 ArrayRef<int64_t> leadingStaticOffsets, 1602 ArrayRef<int64_t> leadingStaticSizes, 1603 ArrayRef<int64_t> leadingStaticStrides) { 1604 auto inferredType = 1605 inferResultType(sourceRankedTensorType, leadingStaticOffsets, 1606 leadingStaticSizes, leadingStaticStrides) 1607 .cast<MemRefType>(); 1608 assert(inferredType.getRank() >= resultRank && "expected "); 1609 int rankDiff = inferredType.getRank() - 
resultRank; 1610 if (rankDiff > 0) { 1611 auto shape = inferredType.getShape(); 1612 llvm::SmallDenseSet<unsigned> dimsToProject; 1613 mlir::getPositionsOfShapeOne(rankDiff, shape, dimsToProject); 1614 SmallVector<int64_t> projectedShape; 1615 for (unsigned pos = 0, e = shape.size(); pos < e; ++pos) 1616 if (!dimsToProject.contains(pos)) 1617 projectedShape.push_back(shape[pos]); 1618 1619 AffineMap map; 1620 auto maps = inferredType.getAffineMaps(); 1621 if (!maps.empty() && maps.front()) 1622 map = getProjectedMap(maps.front(), dimsToProject); 1623 inferredType = 1624 MemRefType::get(projectedShape, inferredType.getElementType(), map, 1625 inferredType.getMemorySpace()); 1626 } 1627 return inferredType; 1628 } 1629 1630 Type SubViewOp::inferRankReducedResultType( 1631 unsigned resultRank, MemRefType sourceRankedTensorType, 1632 ArrayRef<OpFoldResult> leadingStaticOffsets, 1633 ArrayRef<OpFoldResult> leadingStaticSizes, 1634 ArrayRef<OpFoldResult> leadingStaticStrides) { 1635 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 1636 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 1637 dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets, 1638 staticOffsets, ShapedType::kDynamicStrideOrOffset); 1639 dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes, 1640 ShapedType::kDynamicSize); 1641 dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides, 1642 staticStrides, ShapedType::kDynamicStrideOrOffset); 1643 return SubViewOp::inferRankReducedResultType( 1644 resultRank, sourceRankedTensorType, staticOffsets, staticSizes, 1645 staticStrides); 1646 } 1647 // Build a SubViewOp with mixed static and dynamic entries and custom result 1648 // type. If the type passed is nullptr, it is inferred. 
1649 void SubViewOp::build(OpBuilder &b, OperationState &result, 1650 MemRefType resultType, Value source, 1651 ArrayRef<OpFoldResult> offsets, 1652 ArrayRef<OpFoldResult> sizes, 1653 ArrayRef<OpFoldResult> strides, 1654 ArrayRef<NamedAttribute> attrs) { 1655 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 1656 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 1657 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets, 1658 ShapedType::kDynamicStrideOrOffset); 1659 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, 1660 ShapedType::kDynamicSize); 1661 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, 1662 ShapedType::kDynamicStrideOrOffset); 1663 auto sourceMemRefType = source.getType().cast<MemRefType>(); 1664 // Structuring implementation this way avoids duplication between builders. 1665 if (!resultType) { 1666 resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets, 1667 staticSizes, staticStrides) 1668 .cast<MemRefType>(); 1669 } 1670 build(b, result, resultType, source, dynamicOffsets, dynamicSizes, 1671 dynamicStrides, b.getI64ArrayAttr(staticOffsets), 1672 b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides)); 1673 result.addAttributes(attrs); 1674 } 1675 1676 // Build a SubViewOp with mixed static and dynamic entries and inferred result 1677 // type. 1678 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source, 1679 ArrayRef<OpFoldResult> offsets, 1680 ArrayRef<OpFoldResult> sizes, 1681 ArrayRef<OpFoldResult> strides, 1682 ArrayRef<NamedAttribute> attrs) { 1683 build(b, result, MemRefType(), source, offsets, sizes, strides, attrs); 1684 } 1685 1686 // Build a SubViewOp with static entries and inferred result type. 
1687 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source, 1688 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes, 1689 ArrayRef<int64_t> strides, 1690 ArrayRef<NamedAttribute> attrs) { 1691 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>( 1692 llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult { 1693 return b.getI64IntegerAttr(v); 1694 })); 1695 SmallVector<OpFoldResult> sizeValues = 1696 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult { 1697 return b.getI64IntegerAttr(v); 1698 })); 1699 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1700 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult { 1701 return b.getI64IntegerAttr(v); 1702 })); 1703 build(b, result, source, offsetValues, sizeValues, strideValues, attrs); 1704 } 1705 1706 // Build a SubViewOp with dynamic entries and custom result type. If the 1707 // type passed is nullptr, it is inferred. 1708 void SubViewOp::build(OpBuilder &b, OperationState &result, 1709 MemRefType resultType, Value source, 1710 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes, 1711 ArrayRef<int64_t> strides, 1712 ArrayRef<NamedAttribute> attrs) { 1713 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>( 1714 llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult { 1715 return b.getI64IntegerAttr(v); 1716 })); 1717 SmallVector<OpFoldResult> sizeValues = 1718 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult { 1719 return b.getI64IntegerAttr(v); 1720 })); 1721 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1722 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult { 1723 return b.getI64IntegerAttr(v); 1724 })); 1725 build(b, result, resultType, source, offsetValues, sizeValues, strideValues, 1726 attrs); 1727 } 1728 1729 // Build a SubViewOp with dynamic entries and custom result type. If the type 1730 // passed is nullptr, it is inferred. 
1731 void SubViewOp::build(OpBuilder &b, OperationState &result, 1732 MemRefType resultType, Value source, ValueRange offsets, 1733 ValueRange sizes, ValueRange strides, 1734 ArrayRef<NamedAttribute> attrs) { 1735 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>( 1736 llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; })); 1737 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>( 1738 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; })); 1739 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1740 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; })); 1741 build(b, result, resultType, source, offsetValues, sizeValues, strideValues); 1742 } 1743 1744 // Build a SubViewOp with dynamic entries and inferred result type. 1745 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source, 1746 ValueRange offsets, ValueRange sizes, ValueRange strides, 1747 ArrayRef<NamedAttribute> attrs) { 1748 build(b, result, MemRefType(), source, offsets, sizes, strides, attrs); 1749 } 1750 1751 /// For ViewLikeOpInterface. 1752 Value SubViewOp::getViewSource() { return source(); } 1753 1754 enum SubViewVerificationResult { 1755 Success, 1756 RankTooLarge, 1757 SizeMismatch, 1758 ElemTypeMismatch, 1759 MemSpaceMismatch, 1760 AffineMapMismatch 1761 }; 1762 1763 /// Checks if `original` Type type can be rank reduced to `reduced` type. 1764 /// This function is slight variant of `is subsequence` algorithm where 1765 /// not matching dimension must be 1. 
static SubViewVerificationResult
isRankReducedType(Type originalType, Type candidateReducedType,
                  std::string *errMsg = nullptr) {
  // Identical types are trivially compatible.
  if (originalType == candidateReducedType)
    return SubViewVerificationResult::Success;
  // The remaining checks only apply when the original is a memref; any other
  // type is accepted as-is.
  if (!originalType.isa<MemRefType>())
    return SubViewVerificationResult::Success;
  if (originalType.isa<MemRefType>() && !candidateReducedType.isa<MemRefType>())
    return SubViewVerificationResult::Success;

  ShapedType originalShapedType = originalType.cast<ShapedType>();
  ShapedType candidateReducedShapedType =
      candidateReducedType.cast<ShapedType>();

  // Rank and size logic is valid for all ShapedTypes.
  ArrayRef<int64_t> originalShape = originalShapedType.getShape();
  ArrayRef<int64_t> candidateReducedShape =
      candidateReducedShapedType.getShape();
  unsigned originalRank = originalShape.size(),
           candidateReducedRank = candidateReducedShape.size();
  if (candidateReducedRank > originalRank)
    return SubViewVerificationResult::RankTooLarge;

  // Compute which original dimensions are dropped by the reduction (the
  // "unused dims" mask); fails if shapes cannot be matched up.
  auto optionalUnusedDimsMask =
      computeRankReductionMask(originalShape, candidateReducedShape);

  // Sizes cannot be matched in case empty vector is returned.
  if (!optionalUnusedDimsMask.hasValue())
    return SubViewVerificationResult::SizeMismatch;

  if (originalShapedType.getElementType() !=
      candidateReducedShapedType.getElementType())
    return SubViewVerificationResult::ElemTypeMismatch;

  // Strided layout logic is relevant for MemRefType only.
  MemRefType original = originalType.cast<MemRefType>();
  MemRefType candidateReduced = candidateReducedType.cast<MemRefType>();
  if (original.getMemorySpace() != candidateReduced.getMemorySpace())
    return SubViewVerificationResult::MemSpaceMismatch;

  // Project the dropped dimensions out of the original strided layout map and
  // compare it against the candidate's layout map.
  llvm::SmallDenseSet<unsigned> unusedDims = optionalUnusedDimsMask.getValue();
  auto inferredType =
      getProjectedMap(getStridedLinearLayoutMap(original), unusedDims);
  AffineMap candidateLayout;
  if (candidateReduced.getAffineMaps().empty())
    candidateLayout = getStridedLinearLayoutMap(candidateReduced);
  else
    candidateLayout = candidateReduced.getAffineMaps().front();
  assert(inferredType.getNumResults() == 1 &&
         candidateLayout.getNumResults() == 1);
  if (inferredType.getNumSymbols() != candidateLayout.getNumSymbols() ||
      inferredType.getNumDims() != candidateLayout.getNumDims()) {
    if (errMsg) {
      llvm::raw_string_ostream os(*errMsg);
      os << "inferred type: " << inferredType;
    }
    return SubViewVerificationResult::AffineMapMismatch;
  }
  // Check that the difference of the affine maps simplifies to 0.
  AffineExpr diffExpr =
      inferredType.getResult(0) - candidateLayout.getResult(0);
  diffExpr = simplifyAffineExpr(diffExpr, inferredType.getNumDims(),
                                inferredType.getNumSymbols());
  auto cst = diffExpr.dyn_cast<AffineConstantExpr>();
  if (!(cst && cst.getValue() == 0)) {
    if (errMsg) {
      llvm::raw_string_ostream os(*errMsg);
      os << "inferred type: " << inferredType;
    }
    return SubViewVerificationResult::AffineMapMismatch;
  }
  return SubViewVerificationResult::Success;
}

/// Translate an isRankReducedType result code into a diagnostic emitted on
/// `op`; `errMsg` carries extra detail (the inferred type) when available.
template <typename OpTy>
static LogicalResult produceSubViewErrorMsg(SubViewVerificationResult result,
                                            OpTy op, Type expectedType,
                                            StringRef errMsg = "") {
  auto memrefType = expectedType.cast<ShapedType>();
  switch (result) {
  case SubViewVerificationResult::Success:
    return success();
  case SubViewVerificationResult::RankTooLarge:
    return op.emitError("expected result rank to be smaller or equal to ")
           << "the source rank. " << errMsg;
  case SubViewVerificationResult::SizeMismatch:
    return op.emitError("expected result type to be ")
           << expectedType
           << " or a rank-reduced version. (mismatch of result sizes) "
           << errMsg;
  case SubViewVerificationResult::ElemTypeMismatch:
    return op.emitError("expected result element type to be ")
           << memrefType.getElementType() << errMsg;
  case SubViewVerificationResult::MemSpaceMismatch:
    return op.emitError("expected result and source memory spaces to match.")
           << errMsg;
  case SubViewVerificationResult::AffineMapMismatch:
    return op.emitError("expected result type to be ")
           << expectedType
           << " or a rank-reduced version. (mismatch of result affine map) "
           << errMsg;
  }
  llvm_unreachable("unexpected subview verification result");
}

/// Verifier for SubViewOp.
static LogicalResult verify(SubViewOp op) {
  MemRefType baseType = op.getSourceType();
  MemRefType subViewType = op.getType();

  // The base memref and the view memref should be in the same memory space.
  if (baseType.getMemorySpace() != subViewType.getMemorySpace())
    return op.emitError("different memory spaces specified for base memref "
                        "type ")
           << baseType << " and subview memref type " << subViewType;

  // Verify that the base memref type has a strided layout map.
  if (!isStrided(baseType))
    return op.emitError("base type ") << baseType << " is not strided";

  // Verify result type against inferred type.
  auto expectedType = SubViewOp::inferResultType(
      baseType, extractFromI64ArrayAttr(op.static_offsets()),
      extractFromI64ArrayAttr(op.static_sizes()),
      extractFromI64ArrayAttr(op.static_strides()));

  // The declared result type may legitimately be a rank-reduced version of
  // the inferred type; isRankReducedType decides, and errMsg collects detail
  // for the diagnostic.
  std::string errMsg;
  auto result = isRankReducedType(expectedType, subViewType, &errMsg);
  return produceSubViewErrorMsg(result, op, expectedType, errMsg);
}

/// Debug printing for Range as "range offset:size:stride".
raw_ostream &mlir::operator<<(raw_ostream &os, Range &range) {
  return os << "range " << range.offset << ":" << range.size << ":"
            << range.stride;
}

/// Return the list of Range (i.e. offset, size, stride). Each Range
/// entry contains either the dynamic value or a ConstantIndexOp constructed
/// with `b` at location `loc`.
SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
                                              OpBuilder &b, Location loc) {
  std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
  assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
  assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
  SmallVector<Range, 8> res;
  unsigned rank = ranks[0];
  res.reserve(rank);
  for (unsigned idx = 0; idx < rank; ++idx) {
    // For each position, either reuse the dynamic SSA value or materialize
    // the static constant as a ConstantIndexOp.
    Value offset =
        op.isDynamicOffset(idx)
            ? op.getDynamicOffset(idx)
            : b.create<ConstantIndexOp>(loc, op.getStaticOffset(idx));
    Value size = op.isDynamicSize(idx)
                     ? op.getDynamicSize(idx)
                     : b.create<ConstantIndexOp>(loc, op.getStaticSize(idx));
    Value stride =
        op.isDynamicStride(idx)
            ? op.getDynamicStride(idx)
            : b.create<ConstantIndexOp>(loc, op.getStaticStride(idx));
    res.emplace_back(Range{offset, size, stride});
  }
  return res;
}

/// Infer the canonical type of the result of a subview operation. Returns a
/// type with rank `resultRank` that is either the rank of the rank-reduced
/// type, or the non-rank-reduced type.
static MemRefType
getCanonicalSubViewResultType(unsigned resultRank, MemRefType sourceType,
                              ArrayRef<OpFoldResult> mixedOffsets,
                              ArrayRef<OpFoldResult> mixedSizes,
                              ArrayRef<OpFoldResult> mixedStrides) {
  auto resultType =
      SubViewOp::inferRankReducedResultType(
          resultRank, sourceType, mixedOffsets, mixedSizes, mixedStrides)
          .cast<MemRefType>();
  // Fall back to the non-rank-reduced type when the reduction could not
  // reach the requested rank.
  if (resultType.getRank() != resultRank) {
    resultType = SubViewOp::inferResultType(sourceType, mixedOffsets,
                                            mixedSizes, mixedStrides)
                     .cast<MemRefType>();
  }
  return resultType;
}

namespace {
/// Pattern to rewrite a subview op with MemRefCast arguments.
/// This essentially pushes memref.cast past its consuming subview when
/// `canFoldIntoConsumerOp` is true.
///
/// Example:
/// ```
/// %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
/// %1 = memref.subview %0[0, 0][3, 4][1, 1] :
///   memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]>
/// ```
/// is rewritten into:
/// ```
/// %0 = memref.subview %V: memref<16x16xf32> to memref<3x4xf32, #[[map0]]>
/// %1 = memref.cast %0: memref<3x4xf32, offset:0, strides:[16, 1]> to
///   memref<3x4xf32, offset:?, strides:[?, 1]>
/// ```
class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
public:
  using OpRewritePattern<SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubViewOp subViewOp,
                                PatternRewriter &rewriter) const override {
    // Any constant operand, just return to let SubViewOpConstantFolder kick
    // in first; running both patterns at once would fight over the op.
    if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    // Only applies when the source comes through a memref.cast that is
    // foldable into its consumer.
    auto castOp = subViewOp.source().getDefiningOp<CastOp>();
    if (!castOp)
      return failure();

    if (!CastOp::canFoldIntoConsumerOp(castOp))
      return failure();

    /// Deduce the resultType of the SubViewOp using `inferSubViewResultType` on
    /// the cast source operand type and the SubViewOp static information. This
    /// is the resulting type if the MemRefCastOp were folded.
    auto resultType = getCanonicalSubViewResultType(
        subViewOp.getType().getRank(),
        castOp.source().getType().cast<MemRefType>(),
        subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
        subViewOp.getMixedStrides());
    // Recreate the subview on the cast's source, then cast back to the
    // original (less static) result type so existing users keep type-checking.
    Value newSubView = rewriter.create<SubViewOp>(
        subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(),
        subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),
        subViewOp.static_sizes(), subViewOp.static_strides());
    rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
                                        newSubView);
    return success();
  }
};
} // namespace

/// Return the canonical type of the result of a subview.
struct SubViewReturnTypeCanonicalizer {
  // Delegates to getCanonicalSubViewResultType, preserving the op's rank.
  MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
                        ArrayRef<OpFoldResult> mixedSizes,
                        ArrayRef<OpFoldResult> mixedStrides) {
    return getCanonicalSubViewResultType(op.getType().getRank(),
                                         op.getSourceType(), mixedOffsets,
                                         mixedSizes, mixedStrides);
  }
};

/// A canonicalizer wrapper to replace SubViewOps.
2017 struct SubViewCanonicalizer { 2018 void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) { 2019 rewriter.replaceOpWithNewOp<CastOp>(op, newOp, op.getType()); 2020 } 2021 }; 2022 2023 void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results, 2024 MLIRContext *context) { 2025 results 2026 .add<OpWithOffsetSizesAndStridesConstantArgumentFolder< 2027 SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>, 2028 SubViewOpMemRefCastFolder>(context); 2029 } 2030 2031 OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) { 2032 auto resultShapedType = getResult().getType().cast<ShapedType>(); 2033 auto sourceShapedType = source().getType().cast<ShapedType>(); 2034 2035 if (resultShapedType.hasStaticShape() && 2036 resultShapedType == sourceShapedType) { 2037 return getViewSource(); 2038 } 2039 2040 return {}; 2041 } 2042 2043 //===----------------------------------------------------------------------===// 2044 // TensorLoadOp 2045 //===----------------------------------------------------------------------===// 2046 2047 OpFoldResult TensorLoadOp::fold(ArrayRef<Attribute>) { 2048 if (auto bufferCast = memref().getDefiningOp<BufferCastOp>()) 2049 // Approximate alias analysis by conservatively folding only when no there 2050 // is no interleaved operation. 2051 if (bufferCast->getBlock() == this->getOperation()->getBlock() && 2052 bufferCast->getNextNode() == this->getOperation()) 2053 return bufferCast.tensor(); 2054 return {}; 2055 } 2056 2057 //===----------------------------------------------------------------------===// 2058 // TransposeOp 2059 //===----------------------------------------------------------------------===// 2060 2061 /// Build a strided memref type by applying `permutationMap` tp `memRefType`. 
static MemRefType inferTransposeResultType(MemRefType memRefType,
                                           AffineMap permutationMap) {
  auto rank = memRefType.getRank();
  auto originalSizes = memRefType.getShape();
  // Compute permuted sizes.
  SmallVector<int64_t, 4> sizes(rank, 0);
  for (auto en : llvm::enumerate(permutationMap.getResults()))
    sizes[en.index()] =
        originalSizes[en.value().cast<AffineDimExpr>().getPosition()];

  // Compute permuted strides.
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto res = getStridesAndOffset(memRefType, strides, offset);
  assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank));
  (void)res; // Silence unused-variable warning in release builds.
  // Permute the strided layout by composing it with the permutation map.
  auto map =
      makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
  map = permutationMap ? map.compose(permutationMap) : map;
  return MemRefType::Builder(memRefType).setShape(sizes).setAffineMaps(map);
}

/// Build a TransposeOp, inferring the result type from the input type and
/// the permutation attribute.
void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
                        AffineMapAttr permutation,
                        ArrayRef<NamedAttribute> attrs) {
  auto permutationMap = permutation.getValue();
  assert(permutationMap);

  auto memRefType = in.getType().cast<MemRefType>();
  // Compute result type.
  MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);

  build(b, result, resultType, in, attrs);
  result.addAttribute(TransposeOp::getPermutationAttrName(), permutation);
}

// transpose $in $permutation attr-dict : type($in) `to` type(results)
static void print(OpAsmPrinter &p, TransposeOp op) {
  p << "memref.transpose " << op.in() << " " << op.permutation();
  // The permutation is printed inline above, so elide it from the attr-dict.
  p.printOptionalAttrDict(op->getAttrs(),
                          {TransposeOp::getPermutationAttrName()});
  p << " : " << op.in().getType() << " to " << op.getType();
}

/// Parse the custom assembly form produced by the printer above.
static ParseResult parseTransposeOp(OpAsmParser &parser,
                                    OperationState &result) {
  OpAsmParser::OperandType in;
  AffineMap permutation;
  MemRefType srcType, dstType;
  if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(srcType) ||
      parser.resolveOperand(in, srcType, result.operands) ||
      parser.parseKeywordType("to", dstType) ||
      parser.addTypeToList(dstType, result.types))
    return failure();

  result.addAttribute(TransposeOp::getPermutationAttrName(),
                      AffineMapAttr::get(permutation));
  return success();
}

/// Verifier: the map must be a permutation matching the input rank, and the
/// declared result type must equal the inferred transposed type.
static LogicalResult verify(TransposeOp op) {
  if (!op.permutation().isPermutation())
    return op.emitOpError("expected a permutation map");
  if (op.permutation().getNumDims() != op.getShapedType().getRank())
    return op.emitOpError(
        "expected a permutation map of same rank as the input");

  auto srcType = op.in().getType().cast<MemRefType>();
  auto dstType = op.getType().cast<MemRefType>();
  auto transposedType = inferTransposeResultType(srcType, op.permutation());
  if (dstType != transposedType)
    return op.emitOpError("output type ")
           << dstType << " does not match transposed input type " << srcType
           << ", " << transposedType;
  return success();
}

/// Fold a memref.cast feeding this transpose into the op itself.
OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  return {};
}

//===----------------------------------------------------------------------===//
// ViewOp
//===----------------------------------------------------------------------===//

/// Parse: `view` $source `[` $byte_shift `]` `[` $sizes `]` attr-dict
///        `:` type($source) `to` type($result)
static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::OperandType srcInfo;
  SmallVector<OpAsmParser::OperandType, 1> offsetInfo;
  SmallVector<OpAsmParser::OperandType, 4> sizesInfo;
  auto indexType = parser.getBuilder().getIndexType();
  Type srcType, dstType;
  llvm::SMLoc offsetLoc;
  if (parser.parseOperand(srcInfo) || parser.getCurrentLocation(&offsetLoc) ||
      parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square))
    return failure();

  // Exactly one byte-shift operand is expected inside the first brackets.
  if (offsetInfo.size() != 1)
    return parser.emitError(offsetLoc) << "expects 1 offset operand";

  return failure(
      parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(srcType) ||
      parser.resolveOperand(srcInfo, srcType, result.operands) ||
      parser.resolveOperands(offsetInfo, indexType, result.operands) ||
      parser.resolveOperands(sizesInfo, indexType, result.operands) ||
      parser.parseKeywordType("to", dstType) ||
      parser.addTypeToList(dstType, result.types));
}

static void print(OpAsmPrinter &p, ViewOp op) {
  p << op.getOperationName() << ' ' << op.getOperand(0) << '[';
  p.printOperand(op.byte_shift());
  p << "][" << op.sizes() << ']';
  p.printOptionalAttrDict(op->getAttrs());
  p << " : " << op.getOperand(0).getType() << " to " << op.getType();
}

/// Verifier: base and result must use identity layouts, share a memory
/// space, and the number of size operands must match the result's dynamic
/// dimension count.
static LogicalResult verify(ViewOp op) {
  auto baseType = op.getOperand(0).getType().cast<MemRefType>();
  auto viewType = op.getType();

  // The base memref should have identity layout map (or none).
  if (baseType.getAffineMaps().size() > 1 ||
      (baseType.getAffineMaps().size() == 1 &&
       !baseType.getAffineMaps()[0].isIdentity()))
    return op.emitError("unsupported map for base memref type ") << baseType;

  // The result memref should have identity layout map (or none).
  if (viewType.getAffineMaps().size() > 1 ||
      (viewType.getAffineMaps().size() == 1 &&
       !viewType.getAffineMaps()[0].isIdentity()))
    return op.emitError("unsupported map for result memref type ") << viewType;

  // The base memref and the view memref should be in the same memory space.
  if (baseType.getMemorySpace() != viewType.getMemorySpace())
    return op.emitError("different memory spaces specified for base memref "
                        "type ")
           << baseType << " and view memref type " << viewType;

  // Verify that we have the correct number of sizes for the result type.
  unsigned numDynamicDims = viewType.getNumDynamicDims();
  if (op.sizes().size() != numDynamicDims)
    return op.emitError("incorrect number of size operands for type ")
           << viewType;

  return success();
}

/// For ViewLikeOpInterface.
Value ViewOp::getViewSource() { return source(); }

namespace {

/// Fold constant size operands of a view into its result type, inserting a
/// cast back to the original type for existing users.
struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
  using OpRewritePattern<ViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    // Return if none of the operands are constants.
    if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    // Get result memref type.
    auto memrefType = viewOp.getType();

    // Get offset from old memref view type 'memRefType'.
    int64_t oldOffset;
    SmallVector<int64_t, 4> oldStrides;
    if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
      return failure();
    assert(oldOffset == 0 && "Expected 0 offset");

    SmallVector<Value, 4> newOperands;

    // Offset cannot be folded into result type.

    // Fold any dynamic dim operands which are produced by a constant.
    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());

    // Walk the dynamic sizes in lockstep with the result dims; dynamicDimPos
    // indexes into viewOp.sizes() for each dynamic dimension encountered.
    unsigned dynamicDimPos = 0;
    unsigned rank = memrefType.getRank();
    for (unsigned dim = 0, e = rank; dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already static dimension, keep it.
      if (!ShapedType::isDynamic(dimSize)) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto *defOp = viewOp.sizes()[dynamicDimPos].getDefiningOp();
      if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
        // Dynamic shape dimension will be folded.
        newShapeConstants.push_back(constantIndexOp.getValue());
      } else {
        // Dynamic shape dimension not folded; copy operand from old memref.
        newShapeConstants.push_back(dimSize);
        newOperands.push_back(viewOp.sizes()[dynamicDimPos]);
      }
      dynamicDimPos++;
    }

    // Create new memref type with constant folded dims.
    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    // Nothing new, don't fold.
    if (newMemRefType == memrefType)
      return failure();

    // Create new ViewOp.
    auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType,
                                             viewOp.getOperand(0),
                                             viewOp.byte_shift(), newOperands);
    // Insert a cast so we have the same type as the old memref type.
    rewriter.replaceOpWithNewOp<CastOp>(viewOp, newViewOp, viewOp.getType());
    return success();
  }
};

/// Rewrite view(cast(alloc)) as view(alloc): the view re-interprets the raw
/// buffer, so the intermediate cast is redundant.
struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
  using OpRewritePattern<ViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    Value memrefOperand = viewOp.getOperand(0);
    CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
    if (!memrefCastOp)
      return failure();
    Value allocOperand = memrefCastOp.getOperand();
    AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
    if (!allocOp)
      return failure();
    rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
                                        viewOp.byte_shift(), viewOp.sizes());
    return success();
  }
};

} // end anonymous namespace

void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"