//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;
using namespace mlir::memref;

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type. Delegates to the standard dialect's ConstantOp.
Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  return builder.create<mlir::ConstantOp>(loc, type, value);
}

/// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
  return llvm::to_vector<4>(
      llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
        return a.cast<IntegerAttr>().getInt();
      }));
}

/// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if
/// it is a Value or into `staticVec` if it is an IntegerAttr.
/// In the case of a Value, a copy of the `sentinel` value is also pushed to
/// `staticVec`. This is useful to extract mixed static and dynamic entries that
/// come from an AttrSizedOperandSegments trait.
static void dispatchIndexOpFoldResult(OpFoldResult ofr,
                                      SmallVectorImpl<Value> &dynamicVec,
                                      SmallVectorImpl<int64_t> &staticVec,
                                      int64_t sentinel) {
  if (auto v = ofr.dyn_cast<Value>()) {
    dynamicVec.push_back(v);
    staticVec.push_back(sentinel);
    return;
  }
  // Not a Value: must be an IntegerAttr whose value is recorded statically.
  APInt apInt = ofr.dyn_cast<Attribute>().cast<IntegerAttr>().getValue();
  staticVec.push_back(apInt.getSExtValue());
}

/// Applies `dispatchIndexOpFoldResult` to every entry of `ofrs`, accumulating
/// dynamic values into `dynamicVec` and static values into `staticVec`.
static void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
                                       SmallVectorImpl<Value> &dynamicVec,
                                       SmallVectorImpl<int64_t> &staticVec,
                                       int64_t sentinel) {
  for (auto ofr : ofrs)
    dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec, sentinel);
}

//===----------------------------------------------------------------------===//
// Common canonicalization pattern support logic
//===----------------------------------------------------------------------===//

/// This is a common class used for patterns of the form
/// "someop(memrefcast) -> someop". It folds the source of any memref.cast
/// into the root operation directly.
static LogicalResult foldMemRefCast(Operation *op, Value inner = nullptr) {
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto cast = operand.get().getDefiningOp<CastOp>();
    // Skip the operand equal to `inner`, and any cast whose source is an
    // unranked memref (folding it would change the operand's rankedness).
    if (cast && operand.get() != inner &&
        !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
      operand.set(cast.getOperand());
      folded = true;
    }
  }
  // Success iff at least one operand was rewritten.
  return success(folded);
}

//===----------------------------------------------------------------------===//
// Helpers for GlobalOp
//===----------------------------------------------------------------------===//

/// Returns the tensor type equivalent of `type`: ranked memrefs map to ranked
/// tensors, unranked memrefs to unranked tensors, and any other type to
/// NoneType.
static Type getTensorTypeFromMemRefType(Type type) {
  if (auto memref = type.dyn_cast<MemRefType>())
    return RankedTensorType::get(memref.getShape(), memref.getElementType());
  if (auto memref = type.dyn_cast<UnrankedMemRefType>())
    return UnrankedTensorType::get(memref.getElementType());
  return NoneType::get(type.getContext());
}

//===----------------------------------------------------------------------===//
// AllocOp / AllocaOp
//===----------------------------------------------------------------------===//

/// Shared verifier for alloc-like ops: checks that the result is a memref,
/// that the number of dynamic-size operands matches the number of dynamic
/// dimensions in the result type, and that the number of symbol operands
/// matches the symbol count of the type's (first) layout map.
template <typename AllocLikeOp>
static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
  static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
                "applies to only alloc or alloca");
  auto memRefType = op.getResult().getType().template dyn_cast<MemRefType>();
  if (!memRefType)
    return op.emitOpError("result must be a memref");

  if (static_cast<int64_t>(op.dynamicSizes().size()) !=
      memRefType.getNumDynamicDims())
    return op.emitOpError("dimension operand count does not equal memref "
                          "dynamic dimension count");

  unsigned numSymbols = 0;
  if (!memRefType.getAffineMaps().empty())
    numSymbols = memRefType.getAffineMaps().front().getNumSymbols();
  if (op.symbolOperands().size() != numSymbols)
    return op.emitOpError(
        "symbol operand count does not equal memref symbol count");

  return success();
}

static LogicalResult verify(AllocOp op) { return verifyAllocLikeOp(op); }

static LogicalResult verify(AllocaOp op) {
  // An alloca op needs to have an ancestor with an allocation scope trait.
  if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
    return op.emitOpError(
        "requires an ancestor op with AutomaticAllocationScope trait");

  return verifyAllocLikeOp(op);
}

namespace {
/// Fold constant dimensions into an alloc like operation.
template <typename AllocLikeOp>
struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
  using OpRewritePattern<AllocLikeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocLikeOp alloc,
                                PatternRewriter &rewriter) const override {
    // Check to see if any dimensions operands are constants. If so, we can
    // substitute and drop them.
    if (llvm::none_of(alloc.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto memrefType = alloc.getType();

    // Ok, we have one or more constant operands. Collect the non-constant ones
    // and keep track of the resultant memref type to build.
    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());
    SmallVector<Value, 4> newOperands;

    // Walks the shape and the dynamic-size operand list in lock step;
    // `dynamicDimPos` indexes only the dynamic dimensions.
    unsigned dynamicDimPos = 0;
    for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already static dimension, keep it.
      if (dimSize != -1) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto *defOp = alloc.getOperand(dynamicDimPos).getDefiningOp();
      if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
        // Dynamic shape dimension will be folded.
        newShapeConstants.push_back(constantIndexOp.getValue());
      } else {
        // Dynamic shape dimension not folded; copy operand from old memref.
        newShapeConstants.push_back(-1);
        newOperands.push_back(alloc.getOperand(dynamicDimPos));
      }
      dynamicDimPos++;
    }

    // Create new memref type (which will have fewer dynamic dimensions).
    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    assert(static_cast<int64_t>(newOperands.size()) ==
           newMemRefType.getNumDynamicDims());

    // Create and insert the alloc op for the new memref.
    auto newAlloc = rewriter.create<AllocLikeOp>(
        alloc.getLoc(), newMemRefType, newOperands, alloc.alignmentAttr());
    // Insert a cast so we have the same type as the old alloc.
    auto resultCast =
        rewriter.create<CastOp>(alloc.getLoc(), newAlloc, alloc.getType());

    rewriter.replaceOp(alloc, {resultCast});
    return success();
  }
};

/// Fold alloc operations with no users or only store and dealloc uses.
template <typename T>
struct SimplifyDeadAlloc : public OpRewritePattern<T> {
  using OpRewritePattern<T>::OpRewritePattern;

  LogicalResult matchAndRewrite(T alloc,
                                PatternRewriter &rewriter) const override {
    // Bail if any user is something other than a store or dealloc.
    if (llvm::any_of(alloc->getUsers(), [](Operation *op) {
          return !isa<StoreOp, DeallocOp>(op);
        }))
      return failure();

    // Erase all (store/dealloc) users first, then the allocation itself.
    for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
      rewriter.eraseOp(user);

    rewriter.eraseOp(alloc);
    return success();
  }
};
} // end anonymous namespace.
void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
}

void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
      context);
}

//===----------------------------------------------------------------------===//
// AllocaScopeOp
//===----------------------------------------------------------------------===//

/// Prints an alloca_scope op: optional result types, body region (terminators
/// only printed when the op has results), then the attribute dictionary.
static void print(OpAsmPrinter &p, AllocaScopeOp &op) {
  bool printBlockTerminators = false;

  p << AllocaScopeOp::getOperationName() << " ";
  if (!op.results().empty()) {
    p << " -> (" << op.getResultTypes() << ")";
    printBlockTerminators = true;
  }
  p.printRegion(op.bodyRegion(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/printBlockTerminators);
  p.printOptionalAttrDict(op->getAttrs());
}

/// Parses an alloca_scope op; inverse of the printer above. Ensures the body
/// region has a terminator even when the custom form elided it.
static ParseResult parseAllocaScopeOp(OpAsmParser &parser,
                                      OperationState &result) {
  // Create a region for the body.
  result.regions.reserve(1);
  Region *bodyRegion = result.addRegion();

  // Parse optional results type list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Parse the body region.
  if (parser.parseRegion(*bodyRegion, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();
  AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
                                  result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}

static LogicalResult verify(AllocaScopeOp op) {
  if (failed(RegionBranchOpInterface::verifyTypes(op)))
    return failure();

  return success();
}

/// RegionBranchOpInterface hook: control from the parent enters the body
/// region; control leaving the body region flows to the op's results.
void AllocaScopeOp::getSuccessorRegions(
    Optional<unsigned> index, ArrayRef<Attribute> operands,
    SmallVectorImpl<RegionSuccessor> &regions) {
  // `index` set means we are branching from within the body region: the only
  // successor is the parent op's results.
  if (index.hasValue()) {
    regions.push_back(RegionSuccessor(getResults()));
    return;
  }

  regions.push_back(RegionSuccessor(&bodyRegion()));
}

//===----------------------------------------------------------------------===//
// AssumeAlignmentOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(AssumeAlignmentOp op) {
  unsigned alignment = op.alignment();
  if (!llvm::isPowerOf2_32(alignment))
    return op.emitOpError("alignment must be power of 2");
  return success();
}

//===----------------------------------------------------------------------===//
// BufferCastOp
//===----------------------------------------------------------------------===//

/// buffer_cast(tensor_load(m)) folds to m when the types match exactly.
OpFoldResult BufferCastOp::fold(ArrayRef<Attribute>) {
  if (auto tensorLoad = tensor().getDefiningOp<TensorLoadOp>())
    if (tensorLoad.memref().getType() == getType())
      return tensorLoad.memref();
  return {};
}

namespace {
/// Replace tensor_cast + buffer_cast by buffer_cast + memref_cast.
struct BufferCast : public OpRewritePattern<BufferCastOp> {
  using OpRewritePattern<BufferCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BufferCastOp bufferCast,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand =
        bufferCast.getOperand().getDefiningOp<tensor::CastOp>();
    if (!tensorCastOperand)
      return failure();
    auto srcTensorType =
        tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
    if (!srcTensorType)
      return failure();
    // Buffer-cast the cast's *source* directly, then memref-cast the result
    // back to the originally expected memref type.
    auto memrefType = MemRefType::get(srcTensorType.getShape(),
                                      srcTensorType.getElementType());
    Value memref = rewriter.create<BufferCastOp>(
        bufferCast.getLoc(), memrefType, tensorCastOperand.getOperand());
    rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(),
                                        memref);
    return success();
  }
};

/// Canonicalize memref.tensor_load + memref.buffer_cast to memref.cast when
/// type mismatches prevent `BufferCastOp::fold` to kick in.
struct TensorLoadToMemRef : public OpRewritePattern<BufferCastOp> {
  using OpRewritePattern<BufferCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BufferCastOp bufferCast,
                                PatternRewriter &rewriter) const final {
    auto tensorLoad = bufferCast.tensor().getDefiningOp<TensorLoadOp>();
    // Bail unless we have a tensor_load + memref.buffer_cast with different
    // types. `BufferCastOp::fold` handles the same type case.
    if (!tensorLoad || tensorLoad.memref().getType() == bufferCast.getType())
      return failure();
    // If types are not cast-compatible, bail.
    if (!CastOp::areCastCompatible(tensorLoad.memref().getType(),
                                   bufferCast.getType()))
      return failure();
    rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(),
                                        tensorLoad.memref());
    return success();
  }
};

} // namespace

void BufferCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<BufferCast, TensorLoadToMemRef>(context);
}

//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//

/// Determines whether MemRef_CastOp casts to a more dynamic version of the
/// source memref. This is useful to to fold a memref.cast into a consuming op
/// and implement canonicalization patterns for ops in different dialects that
/// may consume the results of memref.cast operations. Such foldable memref.cast
/// operations are typically inserted as `view` and `subview` ops are
/// canonicalized, to preserve the type compatibility of their uses.
///
/// Returns true when all conditions are met:
/// 1. source and result are ranked memrefs with strided semantics and same
/// element type and rank.
/// 2. each of the source's size, offset or stride has more static information
/// than the corresponding result's size, offset or stride.
///
/// Example 1:
/// ```mlir
///   %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
///   %2 = consumer %1 ... : memref<?x?xf32> ...
/// ```
///
/// may fold into:
///
/// ```mlir
///   %2 = consumer %0 ... : memref<8x16xf32> ...
/// ```
///
/// Example 2:
/// ```
///   %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
///          to memref<?x?xf32>
///   consumer %1 : memref<?x?xf32> ...
/// ```
///
/// may fold into:
///
/// ```
///   consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
/// ```
bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
  MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>();
  MemRefType resultType = castOp.getType().dyn_cast<MemRefType>();

  // Requires ranked MemRefType.
  if (!sourceType || !resultType)
    return false;

  // Requires same elemental type.
  if (sourceType.getElementType() != resultType.getElementType())
    return false;

  // Requires same rank.
  if (sourceType.getRank() != resultType.getRank())
    return false;

  // Only fold casts between strided memref forms.
  int64_t sourceOffset, resultOffset;
  SmallVector<int64_t, 4> sourceStrides, resultStrides;
  if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) ||
      failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
    return false;

  // If cast is towards more static sizes along any dimension, don't fold.
  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
    auto ss = std::get<0>(it), st = std::get<1>(it);
    if (ss != st)
      if (MemRefType::isDynamic(ss) && !MemRefType::isDynamic(st))
        return false;
  }

  // If cast is towards more static offset along any dimension, don't fold.
  if (sourceOffset != resultOffset)
    if (MemRefType::isDynamicStrideOrOffset(sourceOffset) &&
        !MemRefType::isDynamicStrideOrOffset(resultOffset))
      return false;

  // If cast is towards more static strides along any dimension, don't fold.
  for (auto it : llvm::zip(sourceStrides, resultStrides)) {
    auto ss = std::get<0>(it), st = std::get<1>(it);
    if (ss != st)
      if (MemRefType::isDynamicStrideOrOffset(ss) &&
          !MemRefType::isDynamicStrideOrOffset(st))
        return false;
  }

  return true;
}

/// CastOpInterface hook: a memref.cast is valid between a single input and a
/// single output that are either two compatible ranked memrefs, or exactly one
/// ranked and one unranked memref with matching element type and memory space.
bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = a.dyn_cast<MemRefType>();
  auto bT = b.dyn_cast<MemRefType>();

  auto uaT = a.dyn_cast<UnrankedMemRefType>();
  auto ubT = b.dyn_cast<UnrankedMemRefType>();

  if (aT && bT) {
    if (aT.getElementType() != bT.getElementType())
      return false;
    if (aT.getAffineMaps() != bT.getAffineMaps()) {
      int64_t aOffset, bOffset;
      SmallVector<int64_t, 4> aStrides, bStrides;
      if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
          failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
          aStrides.size() != bStrides.size())
        return false;

      // Strides along a dimension/offset are compatible if the value in the
      // source memref is static and the value in the target memref is the
      // same. They are also compatible if either one is dynamic (see
      // description of MemRefCastOp for details).
      auto checkCompatible = [](int64_t a, int64_t b) {
        return (a == MemRefType::getDynamicStrideOrOffset() ||
                b == MemRefType::getDynamicStrideOrOffset() || a == b);
      };
      if (!checkCompatible(aOffset, bOffset))
        return false;
      for (auto aStride : enumerate(aStrides))
        if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
          return false;
    }
    if (aT.getMemorySpace() != bT.getMemorySpace())
      return false;

    // They must have the same rank, and any specified dimensions must match.
    if (aT.getRank() != bT.getRank())
      return false;

    for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
      int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
      if (aDim != -1 && bDim != -1 && aDim != bDim)
        return false;
    }
    return true;
  } else {
    if (!aT && !uaT)
      return false;
    if (!bT && !ubT)
      return false;
    // Unranked to unranked casting is unsupported
    if (uaT && ubT)
      return false;

    auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
    auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
    if (aEltType != bEltType)
      return false;

    auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
    auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
    if (aMemSpace != bMemSpace)
      return false;

    return true;
  }

  // NOTE(review): unreachable — both branches above return on every path.
  return false;
}

OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
  return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
}

//===----------------------------------------------------------------------===//
// CloneOp
//===----------------------------------------------------------------------===//

/// A clone reads its input and writes its output.
void CloneOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  effects.emplace_back(MemoryEffects::Read::get(), input(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Write::get(), output(),
                       SideEffects::DefaultResource::get());
}

namespace {
/// Replace a clone with a cast of its source when either the clone's result
/// or its source is deallocated in the clone's own block (with no other
/// deallocation in between), making one of the two buffers redundant.
struct SimplifyClones : public OpRewritePattern<CloneOp> {
  using OpRewritePattern<CloneOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CloneOp cloneOp,
                                PatternRewriter &rewriter) const override {
    // A clone with no uses can simply be removed.
    if (cloneOp.use_empty()) {
      rewriter.eraseOp(cloneOp);
      return success();
    }

    Value source = cloneOp.input();

    // This only finds dealloc operations for the immediate value. It should
    // also consider aliases. That would also make the safety check below
    // redundant.
    Operation *cloneDeallocOp = findDealloc(cloneOp.output());
    Operation *sourceDeallocOp = findDealloc(source);

    // If both are deallocated in the same block, their in-block lifetimes
    // might not fully overlap, so we cannot decide which one to drop.
    if (cloneDeallocOp && sourceDeallocOp &&
        cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
      return failure();

    // Pick whichever dealloc lives in the clone's own block as the one to
    // remove; the clone's dealloc takes precedence.
    Block *currentBlock = cloneOp->getBlock();
    Operation *redundantDealloc = nullptr;
    if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = cloneDeallocOp;
    } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = sourceDeallocOp;
    }

    if (!redundantDealloc)
      return failure();

    // Safety check that there are no other deallocations in between
    // cloneOp and redundantDealloc, as otherwise we might deallocate an alias
    // of source before the uses of the clone. With alias information, we could
    // restrict this to only fail if the dealloc's operand is an alias
    // of the source.
    for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
         pos = pos->getNextNode()) {
      auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
      if (!effectInterface)
        continue;
      if (effectInterface.hasEffect<MemoryEffects::Free>())
        return failure();
    }

    // Forward the source through a cast and drop the redundant dealloc.
    rewriter.replaceOpWithNewOp<memref::CastOp>(cloneOp, cloneOp.getType(),
                                                source);
    rewriter.eraseOp(redundantDealloc);
    return success();
  }
};

} // end anonymous namespace.

void CloneOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                          MLIRContext *context) {
  results.insert<SimplifyClones>(context);
}

OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) {
  return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
}

//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//

LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands,
                              SmallVectorImpl<OpFoldResult> &results) {
  /// dealloc(memrefcast) -> dealloc
  return foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// DimOp
//===----------------------------------------------------------------------===//

/// Builds a dim op from a statically-known index; materializes the index as a
/// ConstantIndexOp operand.
void DimOp::build(OpBuilder &builder, OperationState &result, Value memref,
                  int64_t index) {
  auto loc = result.location;
  Value indexValue = builder.create<ConstantIndexOp>(loc, index);
  build(builder, result, memref, indexValue);
}

void DimOp::build(OpBuilder &builder, OperationState &result, Value memref,
                  Value index) {
  auto indexTy = builder.getIndexType();
  build(builder, result, indexTy, memref, index);
}

/// Returns the index if it is defined by a ConstantOp, None otherwise.
Optional<int64_t> DimOp::getConstantIndex() {
  if (auto constantOp = index().getDefiningOp<ConstantOp>())
    return constantOp.getValue().cast<IntegerAttr>().getInt();
  return {};
}

static LogicalResult verify(DimOp op) {
  // Assume unknown index to be in range.
  Optional<int64_t> index = op.getConstantIndex();
  if (!index.hasValue())
    return success();

  // Check that constant index is not knowingly out of range.
  auto type = op.memrefOrTensor().getType();
  if (auto memrefType = type.dyn_cast<MemRefType>()) {
    if (index.getValue() >= memrefType.getRank())
      return op.emitOpError("index is out of range");
  } else if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
    if (index.getValue() >= tensorType.getRank())
      return op.emitOpError("index is out of range");
  } else if (type.isa<UnrankedMemRefType>() || type.isa<UnrankedTensorType>()) {
    // Assume index to be in range.
  } else {
    llvm_unreachable("expected operand with memref type");
  }
  return success();
}

OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
  auto index = operands[1].dyn_cast_or_null<IntegerAttr>();

  // All forms of folding require a known index.
  if (!index)
    return {};

  auto argTy = memrefOrTensor().getType();
  // Fold if the shape extent along the given index is known.
  if (auto shapedTy = argTy.dyn_cast<ShapedType>()) {
    // Folding for unranked types (UnrankedMemRefType) is not supported.
    if (!shapedTy.hasRank())
      return {};
    if (!shapedTy.isDynamicDim(index.getInt())) {
      Builder builder(getContext());
      return builder.getIndexAttr(shapedTy.getShape()[index.getInt()]);
    }
  }

  Operation *definingOp = memrefOrTensor().getDefiningOp();

  // dim(memref.tensor_load(memref)) -> dim(memref)
  if (auto tensorLoadOp = dyn_cast_or_null<TensorLoadOp>(definingOp)) {
    setOperand(0, tensorLoadOp.memref());
    return getResult();
  }

  // Fold dim to the operand of tensor.generate.
  if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
    auto resultType =
        fromElements.getResult().getType().cast<RankedTensorType>();
    // The case where the type encodes the size of the dimension is handled
    // above.
    assert(resultType.getShape()[index.getInt()] ==
           RankedTensorType::kDynamicSize);

    // Find the operand of the fromElements that corresponds to this index.
    // Dynamic extents are supplied only for dynamic dimensions, so count the
    // dynamic dimensions preceding `index`.
    auto dynExtents = fromElements.dynamicExtents().begin();
    for (auto dim : resultType.getShape().take_front(index.getInt()))
      if (dim == RankedTensorType::kDynamicSize)
        dynExtents++;

    return Value{*dynExtents};
  }

  // The size at the given index is now known to be a dynamic size.
  unsigned unsignedIndex = index.getValue().getZExtValue();

  if (auto subtensor = dyn_cast_or_null<mlir::SubTensorOp>(definingOp)) {
    assert(subtensor.isDynamicSize(unsignedIndex) &&
           "Expected dynamic subtensor size");
    return subtensor.getDynamicSize(unsignedIndex);
  }

  // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
  auto memrefType = argTy.dyn_cast<MemRefType>();
  if (!memrefType)
    return {};

  if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
    return *(alloc.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
    return *(alloca.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
    return *(view.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto sizeInterface =
          dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
    assert(sizeInterface.isDynamicSize(unsignedIndex) &&
           "Expected dynamic subview size");
    return sizeInterface.getDynamicSize(unsignedIndex);
  }

  // dim(memrefcast) -> dim
  if (succeeded(foldMemRefCast(*this)))
    return getResult();

  return {};
}

namespace {
/// Fold dim of a memref reshape operation to a load into the reshape's shape
/// operand.
struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dim,
                                PatternRewriter &rewriter) const override {
    auto reshape = dim.memrefOrTensor().getDefiningOp<ReshapeOp>();

    if (!reshape)
      return failure();

    // Place the load directly after the reshape to ensure that the shape memref
    // was not mutated.
    rewriter.setInsertionPointAfter(reshape);
    rewriter.replaceOpWithNewOp<LoadOp>(dim, reshape.shape(),
                                        llvm::makeArrayRef({dim.index()}));
    return success();
  }
};

/// Fold dim of a cast into a dim of the cast's source.
template <typename CastOpTy>
struct DimOfCastOp : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.memrefOrTensor().getDefiningOp<CastOpTy>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.index());
    return success();
  }
};

/// Helper method to get the `Value` that is the shape of the `resultIdx`-th
/// result at dimension `dimIndex` from the `ShapedTypeOpInterface`.
/// TODO(ravishankarm): This is better put as a interface utility method
/// somewhere, but that would imply the interface will depend on the `tensor`
/// dialect. Ideally maybe a utility method in the `tensor` dialect.
static Value getResultDimFromShapeInterface(OpBuilder &builder, OpResult result,
                                            int64_t dimIndex) {
  unsigned resultNumber = result.getResultNumber();
  auto shapedTypeOp = dyn_cast<InferShapedTypeOpInterface>(result.getOwner());
  Location loc = result.getOwner()->getLoc();
  if (!shapedTypeOp)
    return nullptr;

  // The interface exposes two methods, one that returns the shape of all the
  // results as `Value` and other that returns the shape as a list of
  // `SmallVector<Value>`. The former takes precedence over the latter. So first
  // check if the op implements the first interface method or the second, and
  // get the value to use appropriately.
  SmallVector<Value> reifiedResultShapes;
  if (succeeded(shapedTypeOp.reifyReturnTypeShapes(
          builder, result.getOwner()->getOperands(), reifiedResultShapes))) {
    if (reifiedResultShapes.size() <= resultNumber)
      return nullptr;
    Value resultShape = reifiedResultShapes[resultNumber];
    // The reified shape must be a ranked tensor of index values to be usable
    // as a tensor.extract source below.
    auto resultShapeType = resultShape.getType().dyn_cast<RankedTensorType>();
    if (!resultShapeType || !resultShapeType.getElementType().isa<IndexType>())
      return nullptr;
    return builder.create<tensor::ExtractOp>(
        loc, resultShape, builder.createOrFold<ConstantIndexOp>(loc, dimIndex));
  }

  SmallVector<SmallVector<Value>> reifiedResultShapesPerDim;
  if (failed(shapedTypeOp.reifyReturnTypeShapesPerResultDim(
          builder, reifiedResultShapesPerDim)))
    return nullptr;
  if (reifiedResultShapesPerDim.size() <= resultNumber ||
      reifiedResultShapesPerDim[resultNumber].size() !=
          static_cast<size_t>(result.getType().cast<ShapedType>().getRank()))
    return nullptr;
  OpFoldResult valueOrAttr = reifiedResultShapesPerDim[resultNumber][dimIndex];
  // Materialize a static extent as a constant; a dynamic one is already a
  // Value.
  if (auto attr = valueOrAttr.dyn_cast<Attribute>())
    return builder.createOrFold<ConstantIndexOp>(
        loc, attr.cast<IntegerAttr>().getInt());
  return valueOrAttr.get<Value>();
}

/// Fold dim of an operation that implements the InferShapedTypeOpInterface
struct DimOfShapedTypeOpInterface : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    OpResult dimValue = dimOp.memrefOrTensor().dyn_cast<OpResult>();
    if (!dimValue)
      return failure();
    auto shapedTypeOp =
        dyn_cast<InferShapedTypeOpInterface>(dimValue.getOwner());
    if (!shapedTypeOp)
      return failure();

    Optional<int64_t> dimIndex = dimOp.getConstantIndex();
    if (!dimIndex)
      return failure();
    Value replacement =
        getResultDimFromShapeInterface(rewriter, dimValue, *dimIndex);
    if (!replacement)
      return failure();
    rewriter.replaceOp(dimOp, replacement);
    return success();
  }
};
} // end anonymous namespace.

void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<DimOfMemRefReshape, DimOfCastOp<BufferCastOp>,
              DimOfCastOp<tensor::CastOp>, DimOfShapedTypeOpInterface>(context);
}

// ---------------------------------------------------------------------------
// DmaStartOp
// ---------------------------------------------------------------------------

/// Builds a dma_start op. The stride and elements-per-stride operands are
/// optional and only appended when `stride` is non-null.
void DmaStartOp::build(OpBuilder &builder, OperationState &result,
                       Value srcMemRef, ValueRange srcIndices, Value destMemRef,
                       ValueRange destIndices, Value numElements,
                       Value tagMemRef, ValueRange tagIndices, Value stride,
                       Value elementsPerStride) {
  result.addOperands(srcMemRef);
  result.addOperands(srcIndices);
  result.addOperands(destMemRef);
  result.addOperands(destIndices);
  result.addOperands({numElements, tagMemRef});
  result.addOperands(tagIndices);
  if (stride)
    result.addOperands({stride, elementsPerStride});
}

void DmaStartOp::print(OpAsmPrinter &p) {
  p << getOperationName() << " " << getSrcMemRef() << '[' << getSrcIndices()
    << "], " << getDstMemRef() << '[' << getDstIndices() << "], "
    << getNumElements() << ", " << getTagMemRef() << '[' << getTagIndices()
    << ']';
  if (isStrided())
    p << ", " << getStride() << ", " << getNumElementsPerStride();

  p.printOptionalAttrDict((*this)->getAttrs());
  p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
    << ", " << getTagMemRef().getType();
}

// Parse DmaStartOp.
// Ex:
//   %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
//                       %tag[%index], %stride, %num_elt_per_stride :
//                     : memref<3076 x f32, 0>,
//                       memref<1024 x f32, 2>,
//                       memref<1 x i32>
//
ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::OperandType srcMemRefInfo;
  SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos;
  OpAsmParser::OperandType dstMemRefInfo;
  SmallVector<OpAsmParser::OperandType, 4> dstIndexInfos;
  OpAsmParser::OperandType numElementsInfo;
  OpAsmParser::OperandType tagMemrefInfo;
  SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos;
  SmallVector<OpAsmParser::OperandType, 2> strideInfo;

  SmallVector<Type, 3> types;
  auto indexType = parser.getBuilder().getIndexType();

  // Parse and resolve the following list of operands:
  // *) source memref followed by its indices (in square brackets).
  // *) destination memref followed by its indices (in square brackets).
  // *) dma size in KiB.
  if (parser.parseOperand(srcMemRefInfo) ||
      parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
      parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(numElementsInfo) ||
      parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
      parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
    return failure();

  // Parse optional stride and elements per stride.
  if (parser.parseTrailingOperandList(strideInfo))
    return failure();

  // Strides come as a pair (stride, elements per stride) or not at all.
  bool isStrided = strideInfo.size() == 2;
  if (!strideInfo.empty() && !isStrided) {
    return parser.emitError(parser.getNameLoc(),
                            "expected two stride related operands");
  }

  if (parser.parseColonTypeList(types))
    return failure();
  // Exactly three types are expected: source, destination and tag memrefs.
  if (types.size() != 3)
    return parser.emitError(parser.getNameLoc(), "fewer/more types expected");

  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
      parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
      parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
      parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
      // size should be an index.
      parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
      parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
      // tag indices should be index.
      parser.resolveOperands(tagIndexInfos, indexType, result.operands))
    return failure();

  if (isStrided) {
    if (parser.resolveOperands(strideInfo, indexType, result.operands))
      return failure();
  }

  return success();
}

LogicalResult DmaStartOp::verify() {
  unsigned numOperands = getNumOperands();

  // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
  // the number of elements.
  if (numOperands < 4)
    return emitOpError("expected at least 4 operands");

  // Check types of operands. The order of these calls is important: the later
  // calls rely on some type properties to compute the operand position.
  // 1. Source memref.
  if (!getSrcMemRef().getType().isa<MemRefType>())
    return emitOpError("expected source to be of memref type");
  if (numOperands < getSrcMemRefRank() + 4)
    return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
                         << " operands";
  if (!getSrcIndices().empty() &&
      !llvm::all_of(getSrcIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected source indices to be of index type");

  // 2. Destination memref.
  if (!getDstMemRef().getType().isa<MemRefType>())
    return emitOpError("expected destination to be of memref type");
  unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
                         << " operands";
  if (!getDstIndices().empty() &&
      !llvm::all_of(getDstIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected destination indices to be of index type");

  // 3. Number of elements.
  if (!getNumElements().getType().isIndex())
    return emitOpError("expected num elements to be of index type");

  // 4. Tag memref.
  if (!getTagMemRef().getType().isa<MemRefType>())
    return emitOpError("expected tag to be of memref type");
  numExpectedOperands += getTagMemRefRank();
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
                         << " operands";
  if (!getTagIndices().empty() &&
      !llvm::all_of(getTagIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected tag indices to be of index type");

  // Optional stride-related operands must be either both present or both
  // absent.
  if (numOperands != numExpectedOperands &&
      numOperands != numExpectedOperands + 2)
    return emitOpError("incorrect number of operands");

  // 5. Strides.
  if (isStrided()) {
    if (!getStride().getType().isIndex() ||
        !getNumElementsPerStride().getType().isIndex())
      return emitOpError(
          "expected stride and num elements per stride to be of type index");
  }

  return success();
}

LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands,
                               SmallVectorImpl<OpFoldResult> &results) {
  /// dma_start(memrefcast) -> dma_start
  return foldMemRefCast(*this);
}

// ---------------------------------------------------------------------------
// DmaWaitOp
// ---------------------------------------------------------------------------

void DmaWaitOp::build(OpBuilder &builder, OperationState &result,
                      Value tagMemRef, ValueRange tagIndices,
                      Value numElements) {
  // Operand order: tag memref, tag indices, number of elements.
  result.addOperands(tagMemRef);
  result.addOperands(tagIndices);
  result.addOperands(numElements);
}

void DmaWaitOp::print(OpAsmPrinter &p) {
  p << getOperationName() << " " << getTagMemRef() << '[' << getTagIndices()
    << "], " << getNumElements();
  p.printOptionalAttrDict((*this)->getAttrs());
  p << " : " << getTagMemRef().getType();
}

// Parse DmaWaitOp.
// Eg:
//   dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 4>
//
ParseResult DmaWaitOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::OperandType tagMemrefInfo;
  SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos;
  Type type;
  auto indexType = parser.getBuilder().getIndexType();
  OpAsmParser::OperandType numElementsInfo;

  // Parse tag memref, its indices, and dma size.
  if (parser.parseOperand(tagMemrefInfo) ||
      parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(numElementsInfo) ||
      parser.parseColonType(type) ||
      parser.resolveOperand(tagMemrefInfo, type, result.operands) ||
      parser.resolveOperands(tagIndexInfos, indexType, result.operands) ||
      parser.resolveOperand(numElementsInfo, indexType, result.operands))
    return failure();

  return success();
}

LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
                              SmallVectorImpl<OpFoldResult> &results) {
  /// dma_wait(memrefcast) -> dma_wait
  return foldMemRefCast(*this);
}

LogicalResult DmaWaitOp::verify() {
  // Mandatory non-variadic operands are tag and the number of elements.
  if (getNumOperands() < 2)
    return emitOpError() << "expected at least 2 operands";

  // Check types of operands. The order of these calls is important: the later
  // calls rely on some type properties to compute the operand position.
  if (!getTagMemRef().getType().isa<MemRefType>())
    return emitOpError() << "expected tag to be of memref type";

  // One index operand per tag memref dimension.
  if (getNumOperands() != 2 + getTagMemRefRank())
    return emitOpError() << "expected " << 2 + getTagMemRefRank()
                         << " operands";

  if (!getTagIndices().empty() &&
      !llvm::all_of(getTagIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError() << "expected tag indices to be of index type";

  if (!getNumElements().getType().isIndex())
    return emitOpError()
           << "expected the number of elements to be of index type";

  return success();
}

//===----------------------------------------------------------------------===//
// GlobalOp
//===----------------------------------------------------------------------===//

/// Prints the type of a global and, for non-external globals, its initial
/// value ("uninitialized" when none was provided).
static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
                                                   TypeAttr type,
                                                   Attribute initialValue) {
  p << type;
  if (!op.isExternal()) {
    p << " = ";
    if (op.isUninitialized())
      p << "uninitialized";
    else
      p.printAttributeWithoutType(initialValue);
  }
}

/// Parses `type (= (uninitialized | attribute))?` for a memref.global.
static ParseResult
parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
                                       Attribute &initialValue) {
  Type type;
  if (parser.parseType(type))
    return failure();

  auto memrefType = type.dyn_cast<MemRefType>();
  if (!memrefType || !memrefType.hasStaticShape())
    return parser.emitError(parser.getNameLoc())
           << "type should be static shaped memref, but got " << type;
  typeAttr = TypeAttr::get(type);

  // No `=`: external global with no initializer.
  if (parser.parseOptionalEqual())
    return success();

  // `uninitialized` is encoded as a unit attribute.
  if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
    initialValue = UnitAttr::get(parser.getBuilder().getContext());
    return success();
  }

  // Otherwise parse the initial value as an elements attribute typed with the
  // tensor equivalent of the memref type.
  Type tensorType = getTensorTypeFromMemRefType(memrefType);
  if (parser.parseAttribute(initialValue, tensorType))
    return failure();
  if (!initialValue.isa<ElementsAttr>())
    return parser.emitError(parser.getNameLoc())
           << "initial value should be a unit or elements attribute";
  return success();
}

static LogicalResult verify(GlobalOp op) {
  auto memrefType = op.type().dyn_cast<MemRefType>();
  if (!memrefType || !memrefType.hasStaticShape())
    return op.emitOpError("type should be static shaped memref, but got ")
           << op.type();

  // Verify that the initial value, if present, is either a unit attribute or
  // an elements attribute.
  if (op.initial_value().hasValue()) {
    Attribute initValue = op.initial_value().getValue();
    if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>())
      return op.emitOpError("initial value should be a unit or elements "
                            "attribute, but got ")
             << initValue;

    // Check that the type of the initial value is compatible with the type of
    // the global variable.
    if (initValue.isa<ElementsAttr>()) {
      Type initType = initValue.getType();
      Type tensorType = getTensorTypeFromMemRefType(memrefType);
      if (initType != tensorType)
        return op.emitOpError("initial value expected to be of type ")
               << tensorType << ", but was of type " << initType;
    }
  }

  // TODO: verify visibility for declarations.
  return success();
}

//===----------------------------------------------------------------------===//
// GetGlobalOp
//===----------------------------------------------------------------------===//

LogicalResult
GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  // Verify that the result type is same as the type of the referenced
  // memref.global op.
  auto global =
      symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, nameAttr());
  if (!global)
    return emitOpError("'")
           << name() << "' does not reference a valid global memref";

  Type resultType = result().getType();
  if (global.type() != resultType)
    return emitOpError("result type ")
           << resultType << " does not match type " << global.type()
           << " of the global memref @" << name();
  return success();
}

//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(LoadOp op) {
  // Operands are: the memref plus one index per memref dimension.
  if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
    return op.emitOpError("incorrect number of indices for load");
  return success();
}

OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
  /// load(memrefcast) -> load
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  return OpFoldResult();
}

namespace {
/// Fold a load on a buffer_cast operation into an tensor.extract on the
/// corresponding tensor.
struct LoadOfBufferCast : public OpRewritePattern<LoadOp> {
  using OpRewritePattern<LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(LoadOp load,
                                PatternRewriter &rewriter) const override {
    auto buffercast = load.memref().getDefiningOp<BufferCastOp>();
    if (!buffercast)
      return failure();

    // Read the original tensor directly instead of going through the buffer.
    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, buffercast.tensor(),
                                                   load.indices());
    return success();
  }
};
} // end anonymous namespace.
void LoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<LoadOfBufferCast>(context);
}

//===----------------------------------------------------------------------===//
// PrefetchOp
//===----------------------------------------------------------------------===//

static void print(OpAsmPrinter &p, PrefetchOp op) {
  p << PrefetchOp::getOperationName() << " " << op.memref() << '[';
  p.printOperands(op.indices());
  p << ']' << ", " << (op.isWrite() ? "write" : "read");
  p << ", locality<" << op.localityHint();
  p << ">, " << (op.isDataCache() ? "data" : "instr");
  // The three attributes are printed inline above, so elide them from the
  // trailing attribute dictionary.
  p.printOptionalAttrDict(
      op->getAttrs(),
      /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
  p << " : " << op.getMemRefType();
}

static ParseResult parsePrefetchOp(OpAsmParser &parser,
                                   OperationState &result) {
  OpAsmParser::OperandType memrefInfo;
  SmallVector<OpAsmParser::OperandType, 4> indexInfo;
  IntegerAttr localityHint;
  MemRefType type;
  StringRef readOrWrite, cacheType;

  auto indexTy = parser.getBuilder().getIndexType();
  auto i32Type = parser.getBuilder().getIntegerType(32);
  // Syntax: memref[indices], (read|write), locality<i32>, (data|instr) : type
  if (parser.parseOperand(memrefInfo) ||
      parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
      parser.parseComma() || parser.parseKeyword("locality") ||
      parser.parseLess() ||
      parser.parseAttribute(localityHint, i32Type, "localityHint",
                            result.attributes) ||
      parser.parseGreater() || parser.parseComma() ||
      parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
      parser.resolveOperand(memrefInfo, type, result.operands) ||
      parser.resolveOperands(indexInfo, indexTy, result.operands))
    return failure();

  // The rw specifier keyword is encoded as the `isWrite` bool attribute.
  if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
    return parser.emitError(parser.getNameLoc(),
                            "rw specifier has to be 'read' or 'write'");
  result.addAttribute(
      PrefetchOp::getIsWriteAttrName(),
      parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));

  // The cache type keyword is encoded as the `isDataCache` bool attribute.
  if (!cacheType.equals("data") && !cacheType.equals("instr"))
    return parser.emitError(parser.getNameLoc(),
                            "cache type has to be 'data' or 'instr'");

  result.addAttribute(
      PrefetchOp::getIsDataCacheAttrName(),
      parser.getBuilder().getBoolAttr(cacheType.equals("data")));

  return success();
}

static LogicalResult verify(PrefetchOp op) {
  // Operands are: the memref plus one index per memref dimension.
  if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
    return op.emitOpError("too few indices");

  return success();
}

LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands,
                               SmallVectorImpl<OpFoldResult> &results) {
  // prefetch(memrefcast) -> prefetch
  return foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// ReinterpretCastOp
//===----------------------------------------------------------------------===//

/// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`,
/// `staticSizes` and `staticStrides` are automatically filled with
/// source-memref-rank sentinel values that encode dynamic entries.
1338 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, 1339 MemRefType resultType, Value source, 1340 OpFoldResult offset, ArrayRef<OpFoldResult> sizes, 1341 ArrayRef<OpFoldResult> strides, 1342 ArrayRef<NamedAttribute> attrs) { 1343 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 1344 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 1345 dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets, 1346 ShapedType::kDynamicStrideOrOffset); 1347 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, 1348 ShapedType::kDynamicSize); 1349 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, 1350 ShapedType::kDynamicStrideOrOffset); 1351 build(b, result, resultType, source, dynamicOffsets, dynamicSizes, 1352 dynamicStrides, b.getI64ArrayAttr(staticOffsets), 1353 b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides)); 1354 result.addAttributes(attrs); 1355 } 1356 1357 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, 1358 MemRefType resultType, Value source, 1359 int64_t offset, ArrayRef<int64_t> sizes, 1360 ArrayRef<int64_t> strides, 1361 ArrayRef<NamedAttribute> attrs) { 1362 SmallVector<OpFoldResult> sizeValues = 1363 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult { 1364 return b.getI64IntegerAttr(v); 1365 })); 1366 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1367 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult { 1368 return b.getI64IntegerAttr(v); 1369 })); 1370 build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues, 1371 strideValues, attrs); 1372 } 1373 1374 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, 1375 MemRefType resultType, Value source, Value offset, 1376 ValueRange sizes, ValueRange strides, 1377 ArrayRef<NamedAttribute> attrs) { 1378 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>( 1379 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; 
})); 1380 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1381 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; })); 1382 build(b, result, resultType, source, offset, sizeValues, strideValues, attrs); 1383 } 1384 1385 // TODO: ponder whether we want to allow missing trailing sizes/strides that are 1386 // completed automatically, like we have for subview and subtensor. 1387 static LogicalResult verify(ReinterpretCastOp op) { 1388 // The source and result memrefs should be in the same memory space. 1389 auto srcType = op.source().getType().cast<BaseMemRefType>(); 1390 auto resultType = op.getType().cast<MemRefType>(); 1391 if (srcType.getMemorySpace() != resultType.getMemorySpace()) 1392 return op.emitError("different memory spaces specified for source type ") 1393 << srcType << " and result memref type " << resultType; 1394 if (srcType.getElementType() != resultType.getElementType()) 1395 return op.emitError("different element types specified for source type ") 1396 << srcType << " and result memref type " << resultType; 1397 1398 // Match sizes in result memref type and in static_sizes attribute. 1399 for (auto &en : 1400 llvm::enumerate(llvm::zip(resultType.getShape(), 1401 extractFromI64ArrayAttr(op.static_sizes())))) { 1402 int64_t resultSize = std::get<0>(en.value()); 1403 int64_t expectedSize = std::get<1>(en.value()); 1404 if (resultSize != expectedSize) 1405 return op.emitError("expected result type with size = ") 1406 << expectedSize << " instead of " << resultSize 1407 << " in dim = " << en.index(); 1408 } 1409 1410 // Match offset and strides in static_offset and static_strides attributes if 1411 // result memref type has an affine map specified. 
1412 if (!resultType.getAffineMaps().empty()) { 1413 int64_t resultOffset; 1414 SmallVector<int64_t, 4> resultStrides; 1415 if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset))) 1416 return failure(); 1417 1418 // Match offset in result memref type and in static_offsets attribute. 1419 int64_t expectedOffset = 1420 extractFromI64ArrayAttr(op.static_offsets()).front(); 1421 if (resultOffset != expectedOffset) 1422 return op.emitError("expected result type with offset = ") 1423 << resultOffset << " instead of " << expectedOffset; 1424 1425 // Match strides in result memref type and in static_strides attribute. 1426 for (auto &en : llvm::enumerate(llvm::zip( 1427 resultStrides, extractFromI64ArrayAttr(op.static_strides())))) { 1428 int64_t resultStride = std::get<0>(en.value()); 1429 int64_t expectedStride = std::get<1>(en.value()); 1430 if (resultStride != expectedStride) 1431 return op.emitError("expected result type with stride = ") 1432 << expectedStride << " instead of " << resultStride 1433 << " in dim = " << en.index(); 1434 } 1435 } 1436 return success(); 1437 } 1438 1439 //===----------------------------------------------------------------------===// 1440 // ReshapeOp 1441 //===----------------------------------------------------------------------===// 1442 1443 static LogicalResult verify(ReshapeOp op) { 1444 Type operandType = op.source().getType(); 1445 Type resultType = op.result().getType(); 1446 1447 Type operandElementType = operandType.cast<ShapedType>().getElementType(); 1448 Type resultElementType = resultType.cast<ShapedType>().getElementType(); 1449 if (operandElementType != resultElementType) 1450 return op.emitOpError("element types of source and destination memref " 1451 "types should be the same"); 1452 1453 if (auto operandMemRefType = operandType.dyn_cast<MemRefType>()) 1454 if (!operandMemRefType.getAffineMaps().empty()) 1455 return op.emitOpError( 1456 "source memref type should have identity affine map"); 1457 1458 
int64_t shapeSize = op.shape().getType().cast<MemRefType>().getDimSize(0); 1459 auto resultMemRefType = resultType.dyn_cast<MemRefType>(); 1460 if (resultMemRefType) { 1461 if (!resultMemRefType.getAffineMaps().empty()) 1462 return op.emitOpError( 1463 "result memref type should have identity affine map"); 1464 if (shapeSize == ShapedType::kDynamicSize) 1465 return op.emitOpError("cannot use shape operand with dynamic length to " 1466 "reshape to statically-ranked memref type"); 1467 if (shapeSize != resultMemRefType.getRank()) 1468 return op.emitOpError( 1469 "length of shape operand differs from the result's memref rank"); 1470 } 1471 return success(); 1472 } 1473 1474 //===----------------------------------------------------------------------===// 1475 // StoreOp 1476 //===----------------------------------------------------------------------===// 1477 1478 static LogicalResult verify(StoreOp op) { 1479 if (op.getNumOperands() != 2 + op.getMemRefType().getRank()) 1480 return op.emitOpError("store index operand count not equal to memref rank"); 1481 1482 return success(); 1483 } 1484 1485 LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands, 1486 SmallVectorImpl<OpFoldResult> &results) { 1487 /// store(memrefcast) -> store 1488 return foldMemRefCast(*this, getValueToStore()); 1489 } 1490 1491 //===----------------------------------------------------------------------===// 1492 // SubViewOp 1493 //===----------------------------------------------------------------------===// 1494 1495 namespace { 1496 /// Helpers to write more idiomatic operations. 
1497 namespace saturated_arith { 1498 struct Wrapper { 1499 explicit Wrapper(int64_t v) : v(v) {} 1500 operator int64_t() { return v; } 1501 int64_t v; 1502 }; 1503 Wrapper operator+(Wrapper a, int64_t b) { 1504 if (ShapedType::isDynamicStrideOrOffset(a) || 1505 ShapedType::isDynamicStrideOrOffset(b)) 1506 return Wrapper(ShapedType::kDynamicStrideOrOffset); 1507 return Wrapper(a.v + b); 1508 } 1509 Wrapper operator*(Wrapper a, int64_t b) { 1510 if (ShapedType::isDynamicStrideOrOffset(a) || 1511 ShapedType::isDynamicStrideOrOffset(b)) 1512 return Wrapper(ShapedType::kDynamicStrideOrOffset); 1513 return Wrapper(a.v * b); 1514 } 1515 } // end namespace saturated_arith 1516 } // end namespace 1517 1518 /// A subview result type can be fully inferred from the source type and the 1519 /// static representation of offsets, sizes and strides. Special sentinels 1520 /// encode the dynamic case. 1521 Type SubViewOp::inferResultType(MemRefType sourceMemRefType, 1522 ArrayRef<int64_t> leadingStaticOffsets, 1523 ArrayRef<int64_t> leadingStaticSizes, 1524 ArrayRef<int64_t> leadingStaticStrides) { 1525 // A subview may specify only a leading subset of offset/sizes/strides in 1526 // which case we complete with offset=0, sizes from memref type and strides=1. 
1527 unsigned rank = sourceMemRefType.getRank(); 1528 assert(leadingStaticOffsets.size() <= rank && 1529 "unexpected leadingStaticOffsets overflow"); 1530 assert(leadingStaticSizes.size() <= rank && 1531 "unexpected leadingStaticSizes overflow"); 1532 assert(leadingStaticStrides.size() <= rank && 1533 "unexpected leadingStaticStrides overflow"); 1534 auto staticOffsets = llvm::to_vector<4>(leadingStaticOffsets); 1535 auto staticSizes = llvm::to_vector<4>(leadingStaticSizes); 1536 auto staticStrides = llvm::to_vector<4>(leadingStaticStrides); 1537 unsigned numTrailingOffsets = rank - staticOffsets.size(); 1538 unsigned numTrailingSizes = rank - staticSizes.size(); 1539 unsigned numTrailingStrides = rank - staticStrides.size(); 1540 staticOffsets.append(numTrailingOffsets, 0); 1541 llvm::append_range(staticSizes, 1542 sourceMemRefType.getShape().take_back(numTrailingSizes)); 1543 staticStrides.append(numTrailingStrides, 1); 1544 1545 // Extract source offset and strides. 1546 int64_t sourceOffset; 1547 SmallVector<int64_t, 4> sourceStrides; 1548 auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset); 1549 assert(succeeded(res) && "SubViewOp expected strided memref type"); 1550 (void)res; 1551 1552 // Compute target offset whose value is: 1553 // `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`. 1554 int64_t targetOffset = sourceOffset; 1555 for (auto it : llvm::zip(staticOffsets, sourceStrides)) { 1556 auto staticOffset = std::get<0>(it), targetStride = std::get<1>(it); 1557 using namespace saturated_arith; 1558 targetOffset = Wrapper(targetOffset) + Wrapper(staticOffset) * targetStride; 1559 } 1560 1561 // Compute target stride whose value is: 1562 // `sourceStrides_i * staticStrides_i`. 
1563 SmallVector<int64_t, 4> targetStrides; 1564 targetStrides.reserve(staticOffsets.size()); 1565 for (auto it : llvm::zip(sourceStrides, staticStrides)) { 1566 auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it); 1567 using namespace saturated_arith; 1568 targetStrides.push_back(Wrapper(sourceStride) * staticStride); 1569 } 1570 1571 // The type is now known. 1572 return MemRefType::get( 1573 staticSizes, sourceMemRefType.getElementType(), 1574 makeStridedLinearLayoutMap(targetStrides, targetOffset, 1575 sourceMemRefType.getContext()), 1576 sourceMemRefType.getMemorySpace()); 1577 } 1578 1579 Type SubViewOp::inferResultType(MemRefType sourceMemRefType, 1580 ArrayRef<OpFoldResult> leadingStaticOffsets, 1581 ArrayRef<OpFoldResult> leadingStaticSizes, 1582 ArrayRef<OpFoldResult> leadingStaticStrides) { 1583 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 1584 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 1585 dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets, 1586 staticOffsets, ShapedType::kDynamicStrideOrOffset); 1587 dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes, 1588 ShapedType::kDynamicSize); 1589 dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides, 1590 staticStrides, ShapedType::kDynamicStrideOrOffset); 1591 return SubViewOp::inferResultType(sourceMemRefType, staticOffsets, 1592 staticSizes, staticStrides) 1593 .cast<MemRefType>(); 1594 } 1595 1596 Type SubViewOp::inferRankReducedResultType( 1597 unsigned resultRank, MemRefType sourceRankedTensorType, 1598 ArrayRef<int64_t> leadingStaticOffsets, 1599 ArrayRef<int64_t> leadingStaticSizes, 1600 ArrayRef<int64_t> leadingStaticStrides) { 1601 auto inferredType = 1602 inferResultType(sourceRankedTensorType, leadingStaticOffsets, 1603 leadingStaticSizes, leadingStaticStrides) 1604 .cast<MemRefType>(); 1605 assert(inferredType.getRank() >= resultRank && "expected "); 1606 int rankDiff = inferredType.getRank() - 
resultRank; 1607 if (rankDiff > 0) { 1608 auto shape = inferredType.getShape(); 1609 llvm::SmallDenseSet<unsigned> dimsToProject; 1610 mlir::getPositionsOfShapeOne(rankDiff, shape, dimsToProject); 1611 SmallVector<int64_t> projectedShape; 1612 for (unsigned pos = 0, e = shape.size(); pos < e; ++pos) 1613 if (!dimsToProject.contains(pos)) 1614 projectedShape.push_back(shape[pos]); 1615 1616 AffineMap map; 1617 auto maps = inferredType.getAffineMaps(); 1618 if (!maps.empty() && maps.front()) 1619 map = getProjectedMap(maps.front(), dimsToProject); 1620 inferredType = 1621 MemRefType::get(projectedShape, inferredType.getElementType(), map, 1622 inferredType.getMemorySpace()); 1623 } 1624 return inferredType; 1625 } 1626 1627 Type SubViewOp::inferRankReducedResultType( 1628 unsigned resultRank, MemRefType sourceRankedTensorType, 1629 ArrayRef<OpFoldResult> leadingStaticOffsets, 1630 ArrayRef<OpFoldResult> leadingStaticSizes, 1631 ArrayRef<OpFoldResult> leadingStaticStrides) { 1632 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 1633 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 1634 dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets, 1635 staticOffsets, ShapedType::kDynamicStrideOrOffset); 1636 dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes, 1637 ShapedType::kDynamicSize); 1638 dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides, 1639 staticStrides, ShapedType::kDynamicStrideOrOffset); 1640 return SubViewOp::inferRankReducedResultType( 1641 resultRank, sourceRankedTensorType, staticOffsets, staticSizes, 1642 staticStrides); 1643 } 1644 // Build a SubViewOp with mixed static and dynamic entries and custom result 1645 // type. If the type passed is nullptr, it is inferred. 
1646 void SubViewOp::build(OpBuilder &b, OperationState &result, 1647 MemRefType resultType, Value source, 1648 ArrayRef<OpFoldResult> offsets, 1649 ArrayRef<OpFoldResult> sizes, 1650 ArrayRef<OpFoldResult> strides, 1651 ArrayRef<NamedAttribute> attrs) { 1652 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 1653 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 1654 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets, 1655 ShapedType::kDynamicStrideOrOffset); 1656 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, 1657 ShapedType::kDynamicSize); 1658 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, 1659 ShapedType::kDynamicStrideOrOffset); 1660 auto sourceMemRefType = source.getType().cast<MemRefType>(); 1661 // Structuring implementation this way avoids duplication between builders. 1662 if (!resultType) { 1663 resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets, 1664 staticSizes, staticStrides) 1665 .cast<MemRefType>(); 1666 } 1667 build(b, result, resultType, source, dynamicOffsets, dynamicSizes, 1668 dynamicStrides, b.getI64ArrayAttr(staticOffsets), 1669 b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides)); 1670 result.addAttributes(attrs); 1671 } 1672 1673 // Build a SubViewOp with mixed static and dynamic entries and inferred result 1674 // type. 1675 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source, 1676 ArrayRef<OpFoldResult> offsets, 1677 ArrayRef<OpFoldResult> sizes, 1678 ArrayRef<OpFoldResult> strides, 1679 ArrayRef<NamedAttribute> attrs) { 1680 build(b, result, MemRefType(), source, offsets, sizes, strides, attrs); 1681 } 1682 1683 // Build a SubViewOp with static entries and inferred result type. 
1684 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source, 1685 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes, 1686 ArrayRef<int64_t> strides, 1687 ArrayRef<NamedAttribute> attrs) { 1688 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>( 1689 llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult { 1690 return b.getI64IntegerAttr(v); 1691 })); 1692 SmallVector<OpFoldResult> sizeValues = 1693 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult { 1694 return b.getI64IntegerAttr(v); 1695 })); 1696 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1697 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult { 1698 return b.getI64IntegerAttr(v); 1699 })); 1700 build(b, result, source, offsetValues, sizeValues, strideValues, attrs); 1701 } 1702 1703 // Build a SubViewOp with dynamic entries and custom result type. If the 1704 // type passed is nullptr, it is inferred. 1705 void SubViewOp::build(OpBuilder &b, OperationState &result, 1706 MemRefType resultType, Value source, 1707 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes, 1708 ArrayRef<int64_t> strides, 1709 ArrayRef<NamedAttribute> attrs) { 1710 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>( 1711 llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult { 1712 return b.getI64IntegerAttr(v); 1713 })); 1714 SmallVector<OpFoldResult> sizeValues = 1715 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult { 1716 return b.getI64IntegerAttr(v); 1717 })); 1718 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1719 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult { 1720 return b.getI64IntegerAttr(v); 1721 })); 1722 build(b, result, resultType, source, offsetValues, sizeValues, strideValues, 1723 attrs); 1724 } 1725 1726 // Build a SubViewOp with dynamic entries and custom result type. If the type 1727 // passed is nullptr, it is inferred. 
1728 void SubViewOp::build(OpBuilder &b, OperationState &result, 1729 MemRefType resultType, Value source, ValueRange offsets, 1730 ValueRange sizes, ValueRange strides, 1731 ArrayRef<NamedAttribute> attrs) { 1732 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>( 1733 llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; })); 1734 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>( 1735 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; })); 1736 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1737 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; })); 1738 build(b, result, resultType, source, offsetValues, sizeValues, strideValues); 1739 } 1740 1741 // Build a SubViewOp with dynamic entries and inferred result type. 1742 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source, 1743 ValueRange offsets, ValueRange sizes, ValueRange strides, 1744 ArrayRef<NamedAttribute> attrs) { 1745 build(b, result, MemRefType(), source, offsets, sizes, strides, attrs); 1746 } 1747 1748 /// For ViewLikeOpInterface. 1749 Value SubViewOp::getViewSource() { return source(); } 1750 1751 enum SubViewVerificationResult { 1752 Success, 1753 RankTooLarge, 1754 SizeMismatch, 1755 ElemTypeMismatch, 1756 MemSpaceMismatch, 1757 AffineMapMismatch 1758 }; 1759 1760 /// Checks if `original` Type type can be rank reduced to `reduced` type. 1761 /// This function is slight variant of `is subsequence` algorithm where 1762 /// not matching dimension must be 1. 
static SubViewVerificationResult
isRankReducedType(Type originalType, Type candidateReducedType,
                  std::string *errMsg = nullptr) {
  // Identical types trivially verify.
  if (originalType == candidateReducedType)
    return SubViewVerificationResult::Success;
  if (!originalType.isa<MemRefType>())
    return SubViewVerificationResult::Success;
  // NOTE(review): a memref original with a non-memref candidate is accepted
  // here; presumably that combination is rejected elsewhere — confirm.
  if (originalType.isa<MemRefType>() && !candidateReducedType.isa<MemRefType>())
    return SubViewVerificationResult::Success;

  ShapedType originalShapedType = originalType.cast<ShapedType>();
  ShapedType candidateReducedShapedType =
      candidateReducedType.cast<ShapedType>();

  // Rank and size logic is valid for all ShapedTypes.
  ArrayRef<int64_t> originalShape = originalShapedType.getShape();
  ArrayRef<int64_t> candidateReducedShape =
      candidateReducedShapedType.getShape();
  unsigned originalRank = originalShape.size(),
           candidateReducedRank = candidateReducedShape.size();
  if (candidateReducedRank > originalRank)
    return SubViewVerificationResult::RankTooLarge;

  // Mask of original dimensions that were dropped (must all be size 1).
  auto optionalUnusedDimsMask =
      computeRankReductionMask(originalShape, candidateReducedShape);

  // Sizes cannot be matched in case empty vector is returned.
  if (!optionalUnusedDimsMask.hasValue())
    return SubViewVerificationResult::SizeMismatch;

  if (originalShapedType.getElementType() !=
      candidateReducedShapedType.getElementType())
    return SubViewVerificationResult::ElemTypeMismatch;

  // Strided layout logic is relevant for MemRefType only.
  MemRefType original = originalType.cast<MemRefType>();
  MemRefType candidateReduced = candidateReducedType.cast<MemRefType>();
  if (original.getMemorySpace() != candidateReduced.getMemorySpace())
    return SubViewVerificationResult::MemSpaceMismatch;

  // Project the original strided layout onto the kept dimensions and compare
  // it with the candidate's layout (defaulting to its canonical strided map).
  llvm::SmallDenseSet<unsigned> unusedDims = optionalUnusedDimsMask.getValue();
  auto inferredType =
      getProjectedMap(getStridedLinearLayoutMap(original), unusedDims);
  AffineMap candidateLayout;
  if (candidateReduced.getAffineMaps().empty())
    candidateLayout = getStridedLinearLayoutMap(candidateReduced);
  else
    candidateLayout = candidateReduced.getAffineMaps().front();
  assert(inferredType.getNumResults() == 1 &&
         candidateLayout.getNumResults() == 1);
  if (inferredType.getNumSymbols() != candidateLayout.getNumSymbols() ||
      inferredType.getNumDims() != candidateLayout.getNumDims()) {
    if (errMsg) {
      llvm::raw_string_ostream os(*errMsg);
      os << "inferred type: " << inferredType;
    }
    return SubViewVerificationResult::AffineMapMismatch;
  }
  // Check that the difference of the affine maps simplifies to 0.
  AffineExpr diffExpr =
      inferredType.getResult(0) - candidateLayout.getResult(0);
  diffExpr = simplifyAffineExpr(diffExpr, inferredType.getNumDims(),
                                inferredType.getNumSymbols());
  auto cst = diffExpr.dyn_cast<AffineConstantExpr>();
  if (!(cst && cst.getValue() == 0)) {
    if (errMsg) {
      llvm::raw_string_ostream os(*errMsg);
      os << "inferred type: " << inferredType;
    }
    return SubViewVerificationResult::AffineMapMismatch;
  }
  return SubViewVerificationResult::Success;
}

/// Translate a SubViewVerificationResult into a LogicalResult, emitting a
/// diagnostic on `op` (annotated with `expectedType` and `errMsg`) for every
/// non-success case.
template <typename OpTy>
static LogicalResult produceSubViewErrorMsg(SubViewVerificationResult result,
                                            OpTy op, Type expectedType,
                                            StringRef errMsg = "") {
  auto memrefType = expectedType.cast<ShapedType>();
  switch (result) {
  case SubViewVerificationResult::Success:
    return success();
  case SubViewVerificationResult::RankTooLarge:
    return op.emitError("expected result rank to be smaller or equal to ")
           << "the source rank. " << errMsg;
  case SubViewVerificationResult::SizeMismatch:
    return op.emitError("expected result type to be ")
           << expectedType
           << " or a rank-reduced version. (mismatch of result sizes) "
           << errMsg;
  case SubViewVerificationResult::ElemTypeMismatch:
    return op.emitError("expected result element type to be ")
           << memrefType.getElementType() << errMsg;
  case SubViewVerificationResult::MemSpaceMismatch:
    return op.emitError("expected result and source memory spaces to match.")
           << errMsg;
  case SubViewVerificationResult::AffineMapMismatch:
    return op.emitError("expected result type to be ")
           << expectedType
           << " or a rank-reduced version. (mismatch of result affine map) "
           << errMsg;
  }
  llvm_unreachable("unexpected subview verification result");
}

/// Verifier for SubViewOp.
static LogicalResult verify(SubViewOp op) {
  MemRefType baseType = op.getSourceType();
  MemRefType subViewType = op.getType();

  // The base memref and the view memref should be in the same memory space.
  if (baseType.getMemorySpace() != subViewType.getMemorySpace())
    return op.emitError("different memory spaces specified for base memref "
                        "type ")
           << baseType << " and subview memref type " << subViewType;

  // Verify that the base memref type has a strided layout map.
  if (!isStrided(baseType))
    return op.emitError("base type ") << baseType << " is not strided";

  // Verify result type against inferred type.
  auto expectedType = SubViewOp::inferResultType(
      baseType, extractFromI64ArrayAttr(op.static_offsets()),
      extractFromI64ArrayAttr(op.static_sizes()),
      extractFromI64ArrayAttr(op.static_strides()));

  // The declared result type may also be a rank-reduced version of the
  // inferred type; delegate comparison and diagnostics to the helpers above.
  std::string errMsg;
  auto result = isRankReducedType(expectedType, subViewType, &errMsg);
  return produceSubViewErrorMsg(result, op, expectedType, errMsg);
}

/// Print a Range as `range offset:size:stride`.
raw_ostream &mlir::operator<<(raw_ostream &os, Range &range) {
  return os << "range " << range.offset << ":" << range.size << ":"
            << range.stride;
}

/// Return the list of Range (i.e. offset, size, stride). Each Range
/// entry contains either the dynamic value or a ConstantIndexOp constructed
/// with `b` at location `loc`.
SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
                                              OpBuilder &b, Location loc) {
  std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
  assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
  assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
  SmallVector<Range, 8> res;
  unsigned rank = ranks[0];
  res.reserve(rank);
  for (unsigned idx = 0; idx < rank; ++idx) {
    // For each position, reuse the dynamic operand when present; otherwise
    // materialize the static value as a constant at `loc`.
    Value offset =
        op.isDynamicOffset(idx)
            ? op.getDynamicOffset(idx)
            : b.create<ConstantIndexOp>(loc, op.getStaticOffset(idx));
    Value size = op.isDynamicSize(idx)
                     ? op.getDynamicSize(idx)
                     : b.create<ConstantIndexOp>(loc, op.getStaticSize(idx));
    Value stride =
        op.isDynamicStride(idx)
            ? op.getDynamicStride(idx)
            : b.create<ConstantIndexOp>(loc, op.getStaticStride(idx));
    res.emplace_back(Range{offset, size, stride});
  }
  return res;
}

/// Infer the canonical type of the result of a subview operation. Returns a
/// type with rank `resultRank` that is either the rank of the rank-reduced
/// type, or the non-rank-reduced type.
static MemRefType
getCanonicalSubViewResultType(unsigned resultRank, MemRefType sourceType,
                              ArrayRef<OpFoldResult> mixedOffsets,
                              ArrayRef<OpFoldResult> mixedSizes,
                              ArrayRef<OpFoldResult> mixedStrides) {
  // Prefer the rank-reduced form; fall back to the full-rank inferred type
  // when the ranks do not line up.
  auto resultType =
      SubViewOp::inferRankReducedResultType(
          resultRank, sourceType, mixedOffsets, mixedSizes, mixedStrides)
          .cast<MemRefType>();
  if (resultType.getRank() != resultRank) {
    resultType = SubViewOp::inferResultType(sourceType, mixedOffsets,
                                            mixedSizes, mixedStrides)
                     .cast<MemRefType>();
  }
  return resultType;
}

namespace {
/// Pattern to rewrite a subview op with MemRefCast arguments.
/// This essentially pushes memref.cast past its consuming subview when
/// `canFoldIntoConsumerOp` is true.
///
/// Example:
/// ```
/// %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
/// %1 = memref.subview %0[0, 0][3, 4][1, 1] :
///   memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]>
/// ```
/// is rewritten into:
/// ```
/// %0 = memref.subview %V: memref<16x16xf32> to memref<3x4xf32, #[[map0]]>
/// %1 = memref.cast %0: memref<3x4xf32, offset:0, strides:[16, 1]> to
///   memref<3x4xf32, offset:?, strides:[?, 1]>
/// ```
class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
public:
  using OpRewritePattern<SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubViewOp subViewOp,
                                PatternRewriter &rewriter) const override {
    // Any constant operand, just return to let SubViewOpConstantFolder kick
    // in.
    if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto castOp = subViewOp.source().getDefiningOp<CastOp>();
    if (!castOp)
      return failure();

    if (!CastOp::canFoldIntoConsumerOp(castOp))
      return failure();

    /// Deduce the resultType of the SubViewOp using `inferSubViewResultType`
    /// on the cast source operand type and the SubViewOp static information.
    /// This is the resulting type if the MemRefCastOp were folded.
    auto resultType = getCanonicalSubViewResultType(
        subViewOp.getType().getRank(),
        castOp.source().getType().cast<MemRefType>(),
        subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
        subViewOp.getMixedStrides());
    Value newSubView = rewriter.create<SubViewOp>(
        subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(),
        subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),
        subViewOp.static_sizes(), subViewOp.static_strides());
    // Cast the new subview back to the original declared type so existing
    // users of the op remain valid.
    rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
                                        newSubView);
    return success();
  }
};
} // namespace

/// Return the canonical type of the result of a subview.
struct SubViewReturnTypeCanonicalizer {
  MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
                        ArrayRef<OpFoldResult> mixedSizes,
                        ArrayRef<OpFoldResult> mixedStrides) {
    return getCanonicalSubViewResultType(op.getType().getRank(),
                                         op.getSourceType(), mixedOffsets,
                                         mixedSizes, mixedStrides);
  }
};

/// A canonicalizer wrapper to replace SubViewOps.
2014 struct SubViewCanonicalizer { 2015 void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) { 2016 rewriter.replaceOpWithNewOp<CastOp>(op, newOp, op.getType()); 2017 } 2018 }; 2019 2020 void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results, 2021 MLIRContext *context) { 2022 results 2023 .add<OpWithOffsetSizesAndStridesConstantArgumentFolder< 2024 SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>, 2025 SubViewOpMemRefCastFolder>(context); 2026 } 2027 2028 OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) { 2029 auto resultShapedType = getResult().getType().cast<ShapedType>(); 2030 auto sourceShapedType = source().getType().cast<ShapedType>(); 2031 2032 if (resultShapedType.hasStaticShape() && 2033 resultShapedType == sourceShapedType) { 2034 return getViewSource(); 2035 } 2036 2037 return {}; 2038 } 2039 2040 //===----------------------------------------------------------------------===// 2041 // TensorLoadOp 2042 //===----------------------------------------------------------------------===// 2043 2044 OpFoldResult TensorLoadOp::fold(ArrayRef<Attribute>) { 2045 if (auto bufferCast = memref().getDefiningOp<BufferCastOp>()) 2046 // Approximate alias analysis by conservatively folding only when no there 2047 // is no interleaved operation. 2048 if (bufferCast->getBlock() == this->getOperation()->getBlock() && 2049 bufferCast->getNextNode() == this->getOperation()) 2050 return bufferCast.tensor(); 2051 return {}; 2052 } 2053 2054 //===----------------------------------------------------------------------===// 2055 // TransposeOp 2056 //===----------------------------------------------------------------------===// 2057 2058 /// Build a strided memref type by applying `permutationMap` tp `memRefType`. 
static MemRefType inferTransposeResultType(MemRefType memRefType,
                                           AffineMap permutationMap) {
  auto rank = memRefType.getRank();
  auto originalSizes = memRefType.getShape();
  // Compute permuted sizes.
  SmallVector<int64_t, 4> sizes(rank, 0);
  for (auto en : llvm::enumerate(permutationMap.getResults()))
    sizes[en.index()] =
        originalSizes[en.value().cast<AffineDimExpr>().getPosition()];

  // Compute permuted strides.
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto res = getStridesAndOffset(memRefType, strides, offset);
  assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank));
  (void)res;
  // Compose the source's strided layout with the permutation to obtain the
  // layout of the transposed memref.
  auto map =
      makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
  map = permutationMap ? map.compose(permutationMap) : map;
  return MemRefType::Builder(memRefType).setShape(sizes).setAffineMaps(map);
}

/// Build a TransposeOp of `in` with the given `permutation` attribute; the
/// result type is inferred from the input type and the permutation.
void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
                        AffineMapAttr permutation,
                        ArrayRef<NamedAttribute> attrs) {
  auto permutationMap = permutation.getValue();
  assert(permutationMap);

  auto memRefType = in.getType().cast<MemRefType>();
  // Compute result type.
  MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);

  build(b, result, resultType, in, attrs);
  result.addAttribute(TransposeOp::getPermutationAttrName(), permutation);
}

// transpose $in $permutation attr-dict : type($in) `to` type(results)
static void print(OpAsmPrinter &p, TransposeOp op) {
  p << "memref.transpose " << op.in() << " " << op.permutation();
  // The permutation is printed inline above, so elide it from the attr-dict.
  p.printOptionalAttrDict(op->getAttrs(),
                          {TransposeOp::getPermutationAttrName()});
  p << " : " << op.in().getType() << " to " << op.getType();
}

static ParseResult parseTransposeOp(OpAsmParser &parser,
                                    OperationState &result) {
  OpAsmParser::OperandType in;
  AffineMap permutation;
  MemRefType srcType, dstType;
  if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(srcType) ||
      parser.resolveOperand(in, srcType, result.operands) ||
      parser.parseKeywordType("to", dstType) ||
      parser.addTypeToList(dstType, result.types))
    return failure();

  result.addAttribute(TransposeOp::getPermutationAttrName(),
                      AffineMapAttr::get(permutation));
  return success();
}

static LogicalResult verify(TransposeOp op) {
  // The map must be a true permutation of the input dimensions.
  if (!op.permutation().isPermutation())
    return op.emitOpError("expected a permutation map");
  if (op.permutation().getNumDims() != op.getShapedType().getRank())
    return op.emitOpError(
        "expected a permutation map of same rank as the input");

  // The declared result type must match the type inferred by applying the
  // permutation to the source type.
  auto srcType = op.in().getType().cast<MemRefType>();
  auto dstType = op.getType().cast<MemRefType>();
  auto transposedType = inferTransposeResultType(srcType, op.permutation());
  if (dstType != transposedType)
    return op.emitOpError("output type ")
           << dstType << " does not match transposed input type " << srcType
           << ", " << transposedType;
  return success();
}

OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
  // Fold a memref.cast feeding the input into this op when possible.
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  return {};
}

//===----------------------------------------------------------------------===//
// ViewOp
//===----------------------------------------------------------------------===//

static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::OperandType srcInfo;
  SmallVector<OpAsmParser::OperandType, 1> offsetInfo;
  SmallVector<OpAsmParser::OperandType, 4> sizesInfo;
  auto indexType = parser.getBuilder().getIndexType();
  Type srcType, dstType;
  llvm::SMLoc offsetLoc;
  if (parser.parseOperand(srcInfo) || parser.getCurrentLocation(&offsetLoc) ||
      parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square))
    return failure();

  // Exactly one byte-shift operand is expected inside the first brackets.
  if (offsetInfo.size() != 1)
    return parser.emitError(offsetLoc) << "expects 1 offset operand";

  return failure(
      parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(srcType) ||
      parser.resolveOperand(srcInfo, srcType, result.operands) ||
      parser.resolveOperands(offsetInfo, indexType, result.operands) ||
      parser.resolveOperands(sizesInfo, indexType, result.operands) ||
      parser.parseKeywordType("to", dstType) ||
      parser.addTypeToList(dstType, result.types));
}

// view $source `[` $byte_shift `]` `[` $sizes `]` attr-dict
//   : type($source) `to` type(results)
static void print(OpAsmPrinter &p, ViewOp op) {
  p << op.getOperationName() << ' ' << op.getOperand(0) << '[';
  p.printOperand(op.byte_shift());
  p << "][" << op.sizes() << ']';
  p.printOptionalAttrDict(op->getAttrs());
  p << " : " << op.getOperand(0).getType() << " to " << op.getType();
}

static LogicalResult verify(ViewOp op) {
  auto baseType = op.getOperand(0).getType().cast<MemRefType>();
  auto viewType = op.getType();

  // The base memref should have identity layout map (or none).
  if (baseType.getAffineMaps().size() > 1 ||
      (baseType.getAffineMaps().size() == 1 &&
       !baseType.getAffineMaps()[0].isIdentity()))
    return op.emitError("unsupported map for base memref type ") << baseType;

  // The result memref should have identity layout map (or none).
  if (viewType.getAffineMaps().size() > 1 ||
      (viewType.getAffineMaps().size() == 1 &&
       !viewType.getAffineMaps()[0].isIdentity()))
    return op.emitError("unsupported map for result memref type ") << viewType;

  // The base memref and the view memref should be in the same memory space.
  if (baseType.getMemorySpace() != viewType.getMemorySpace())
    return op.emitError("different memory spaces specified for base memref "
                        "type ")
           << baseType << " and view memref type " << viewType;

  // Verify that we have the correct number of sizes for the result type.
  unsigned numDynamicDims = viewType.getNumDynamicDims();
  if (op.sizes().size() != numDynamicDims)
    return op.emitError("incorrect number of size operands for type ")
           << viewType;

  return success();
}

/// For ViewLikeOpInterface.
Value ViewOp::getViewSource() { return source(); }

namespace {

/// Fold constant-defined dynamic size operands of a ViewOp into the result
/// type, inserting a cast back to the original type.
struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
  using OpRewritePattern<ViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    // Return if none of the operands are constants.
    if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    // Get result memref type.
    auto memrefType = viewOp.getType();

    // Get offset from old memref view type 'memRefType'.
    int64_t oldOffset;
    SmallVector<int64_t, 4> oldStrides;
    if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
      return failure();
    assert(oldOffset == 0 && "Expected 0 offset");

    SmallVector<Value, 4> newOperands;

    // Offset cannot be folded into result type.

    // Fold any dynamic dim operands which are produced by a constant.
    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());

    // `dynamicDimPos` indexes into the op's size operands, which only exist
    // for the dynamic dimensions of the result type.
    unsigned dynamicDimPos = 0;
    unsigned rank = memrefType.getRank();
    for (unsigned dim = 0, e = rank; dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already static dimension, keep it.
      if (!ShapedType::isDynamic(dimSize)) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto *defOp = viewOp.sizes()[dynamicDimPos].getDefiningOp();
      if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
        // Dynamic shape dimension will be folded.
        newShapeConstants.push_back(constantIndexOp.getValue());
      } else {
        // Dynamic shape dimension not folded; copy operand from old memref.
        newShapeConstants.push_back(dimSize);
        newOperands.push_back(viewOp.sizes()[dynamicDimPos]);
      }
      dynamicDimPos++;
    }

    // Create new memref type with constant folded dims.
    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    // Nothing new, don't fold.
    if (newMemRefType == memrefType)
      return failure();

    // Create new ViewOp.
    auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType,
                                             viewOp.getOperand(0),
                                             viewOp.byte_shift(), newOperands);
    // Insert a cast so we have the same type as the old memref type.
    rewriter.replaceOpWithNewOp<CastOp>(viewOp, newViewOp, viewOp.getType());
    return success();
  }
};

/// Fold `view(cast(alloc))` into `view(alloc)`: the cast is redundant when
/// the view is re-created directly on the allocation.
struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
  using OpRewritePattern<ViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    Value memrefOperand = viewOp.getOperand(0);
    CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
    if (!memrefCastOp)
      return failure();
    Value allocOperand = memrefCastOp.getOperand();
    AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
    if (!allocOp)
      return failure();
    rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
                                        viewOp.byte_shift(), viewOp.sizes());
    return success();
  }
};

} // end anonymous namespace

void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"