1 //===----------------------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/MemRef/IR/MemRef.h" 10 #include "mlir/Dialect/StandardOps/IR/Ops.h" 11 #include "mlir/Dialect/StandardOps/Utils/Utils.h" 12 #include "mlir/Dialect/Tensor/IR/Tensor.h" 13 #include "mlir/IR/AffineMap.h" 14 #include "mlir/IR/Builders.h" 15 #include "mlir/IR/BuiltinTypes.h" 16 #include "mlir/IR/Matchers.h" 17 #include "mlir/IR/PatternMatch.h" 18 #include "mlir/IR/TypeUtilities.h" 19 #include "llvm/ADT/STLExtras.h" 20 21 using namespace mlir; 22 using namespace mlir::memref; 23 24 /// Materialize a single constant operation from a given attribute value with 25 /// the desired resultant type. 26 Operation *MemRefDialect::materializeConstant(OpBuilder &builder, 27 Attribute value, Type type, 28 Location loc) { 29 return builder.create<mlir::ConstantOp>(loc, type, value); 30 } 31 32 /// Extract int64_t values from the assumed ArrayAttr of IntegerAttr. 33 static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) { 34 return llvm::to_vector<4>( 35 llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t { 36 return a.cast<IntegerAttr>().getInt(); 37 })); 38 } 39 40 /// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if 41 /// it is a Value or into `staticVec` if it is an IntegerAttr. 42 /// In the case of a Value, a copy of the `sentinel` value is also pushed to 43 /// `staticVec`. This is useful to extract mixed static and dynamic entries that 44 /// come from an AttrSizedOperandSegments trait. 45 static void dispatchIndexOpFoldResult(OpFoldResult ofr, 46 SmallVectorImpl<Value> &dynamicVec, 47 SmallVectorImpl<int64_t> &staticVec, 48 int64_t sentinel) { 49 if (auto v = ofr.dyn_cast<Value>()) { 50 dynamicVec.push_back(v); 51 staticVec.push_back(sentinel); 52 return; 53 } 54 APInt apInt = ofr.dyn_cast<Attribute>().cast<IntegerAttr>().getValue(); 55 staticVec.push_back(apInt.getSExtValue()); 56 } 57 58 static void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs, 59 SmallVectorImpl<Value> &dynamicVec, 60 SmallVectorImpl<int64_t> &staticVec, 61 int64_t sentinel) { 62 for (auto ofr : ofrs) 63 dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec, sentinel); 64 } 65 66 //===----------------------------------------------------------------------===// 67 // Common canonicalization pattern support logic 68 //===----------------------------------------------------------------------===// 69 70 /// This is a common class used for patterns of the form 71 /// "someop(memrefcast) -> someop". It folds the source of any memref.cast 72 /// into the root operation directly. 
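/// An illustrative sketch of the fold (SSA names are hypothetical):
///
/// ```mlir
///   %0 = memref.cast %src : memref<8xf32> to memref<?xf32>
///   memref.dealloc %0 : memref<?xf32>
/// ```
///
/// becomes, after the cast operand is replaced in place:
///
/// ```mlir
///   memref.dealloc %src : memref<8xf32>
/// ```
///
/// Casts whose source is an unranked memref are deliberately skipped.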
73 static LogicalResult foldMemRefCast(Operation *op) { 74 bool folded = false; 75 for (OpOperand &operand : op->getOpOperands()) { 76 auto cast = operand.get().getDefiningOp<CastOp>(); 77 if (cast && !cast.getOperand().getType().isa<UnrankedMemRefType>()) { 78 operand.set(cast.getOperand()); 79 folded = true; 80 } 81 } 82 return success(folded); 83 } 84 85 //===----------------------------------------------------------------------===// 86 // Helpers for GlobalOp 87 //===----------------------------------------------------------------------===// 88 89 static Type getTensorTypeFromMemRefType(Type type) { 90 if (auto memref = type.dyn_cast<MemRefType>()) 91 return RankedTensorType::get(memref.getShape(), memref.getElementType()); 92 if (auto memref = type.dyn_cast<UnrankedMemRefType>()) 93 return UnrankedTensorType::get(memref.getElementType()); 94 return NoneType::get(type.getContext()); 95 } 96 97 //===----------------------------------------------------------------------===// 98 // AllocOp / AllocaOp 99 //===----------------------------------------------------------------------===// 100 101 template <typename AllocLikeOp> 102 static LogicalResult verifyAllocLikeOp(AllocLikeOp op) { 103 static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value, 104 "applies to only alloc or alloca"); 105 auto memRefType = op.getResult().getType().template dyn_cast<MemRefType>(); 106 if (!memRefType) 107 return op.emitOpError("result must be a memref"); 108 109 if (static_cast<int64_t>(op.dynamicSizes().size()) != 110 memRefType.getNumDynamicDims()) 111 return op.emitOpError("dimension operand count does not equal memref " 112 "dynamic dimension count"); 113 114 unsigned numSymbols = 0; 115 if (!memRefType.getAffineMaps().empty()) 116 numSymbols = memRefType.getAffineMaps().front().getNumSymbols(); 117 if (op.symbolOperands().size() != numSymbols) 118 return op.emitOpError( 119 "symbol operand count does not equal memref symbol count"); 120 121 return success(); 122 } 123 124 static LogicalResult verify(AllocOp op) { return verifyAllocLikeOp(op); } 125 126 static LogicalResult verify(AllocaOp op) { 127 // An alloca op needs to have an ancestor with an allocation scope trait. 128 if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>()) 129 return op.emitOpError( 130 "requires an ancestor op with AutomaticAllocationScope trait"); 131 132 return verifyAllocLikeOp(op); 133 } 134 135 namespace { 136 /// Fold constant dimensions into an alloc like operation. 137 template <typename AllocLikeOp> 138 struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> { 139 using OpRewritePattern<AllocLikeOp>::OpRewritePattern; 140 141 LogicalResult matchAndRewrite(AllocLikeOp alloc, 142 PatternRewriter &rewriter) const override { 143 // Check to see if any dimensions operands are constants. If so, we can 144 // substitute and drop them. 145 if (llvm::none_of(alloc.getOperands(), [](Value operand) { 146 return matchPattern(operand, matchConstantIndex()); 147 })) 148 return failure(); 149 150 auto memrefType = alloc.getType(); 151 152 // Ok, we have one or more constant operands. Collect the non-constant ones 153 // and keep track of the resultant memref type to build. 
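    // An illustrative sketch of the overall rewrite (SSA names are
    // hypothetical, shown for memref.alloc; memref.alloca is analogous),
    // where %c8 is a constant index with value 8:
    //
    //   %a = memref.alloc(%c8, %n) : memref<?x?xf32>
    //
    // becomes
    //
    //   %0 = memref.alloc(%n) : memref<8x?xf32>
    //   %a = memref.cast %0 : memref<8x?xf32> to memref<?x?xf32>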
154 SmallVector<int64_t, 4> newShapeConstants; 155 newShapeConstants.reserve(memrefType.getRank()); 156 SmallVector<Value, 4> newOperands; 157 158 unsigned dynamicDimPos = 0; 159 for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) { 160 int64_t dimSize = memrefType.getDimSize(dim); 161 // If this is already static dimension, keep it. 162 if (dimSize != -1) { 163 newShapeConstants.push_back(dimSize); 164 continue; 165 } 166 auto *defOp = alloc.getOperand(dynamicDimPos).getDefiningOp(); 167 if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) { 168 // Dynamic shape dimension will be folded. 169 newShapeConstants.push_back(constantIndexOp.getValue()); 170 } else { 171 // Dynamic shape dimension not folded; copy operand from old memref. 172 newShapeConstants.push_back(-1); 173 newOperands.push_back(alloc.getOperand(dynamicDimPos)); 174 } 175 dynamicDimPos++; 176 } 177 178 // Create new memref type (which will have fewer dynamic dimensions). 179 MemRefType newMemRefType = 180 MemRefType::Builder(memrefType).setShape(newShapeConstants); 181 assert(static_cast<int64_t>(newOperands.size()) == 182 newMemRefType.getNumDynamicDims()); 183 184 // Create and insert the alloc op for the new memref. 185 auto newAlloc = rewriter.create<AllocLikeOp>(alloc.getLoc(), newMemRefType, 186 newOperands, IntegerAttr()); 187 // Insert a cast so we have the same type as the old alloc. 188 auto resultCast = 189 rewriter.create<CastOp>(alloc.getLoc(), newAlloc, alloc.getType()); 190 191 rewriter.replaceOp(alloc, {resultCast}); 192 return success(); 193 } 194 }; 195 196 /// Fold alloc operations with no uses. Alloc has side effects on the heap, 197 /// but can still be deleted if it has zero uses. 198 struct SimplifyDeadAlloc : public OpRewritePattern<AllocOp> { 199 using OpRewritePattern<AllocOp>::OpRewritePattern; 200 201 LogicalResult matchAndRewrite(AllocOp alloc, 202 PatternRewriter &rewriter) const override { 203 if (alloc.use_empty()) { 204 rewriter.eraseOp(alloc); 205 return success(); 206 } 207 return failure(); 208 } 209 }; 210 } // end anonymous namespace. 211 212 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results, 213 MLIRContext *context) { 214 results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc>(context); 215 } 216 217 void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results, 218 MLIRContext *context) { 219 results.add<SimplifyAllocConst<AllocaOp>>(context); 220 } 221 222 //===----------------------------------------------------------------------===// 223 // AssumeAlignmentOp 224 //===----------------------------------------------------------------------===// 225 226 static LogicalResult verify(AssumeAlignmentOp op) { 227 unsigned alignment = op.alignment(); 228 if (!llvm::isPowerOf2_32(alignment)) 229 return op.emitOpError("alignment must be power of 2"); 230 return success(); 231 } 232 233 //===----------------------------------------------------------------------===// 234 // BufferCastOp 235 //===----------------------------------------------------------------------===// 236 237 OpFoldResult BufferCastOp::fold(ArrayRef<Attribute>) { 238 if (auto tensorLoad = tensor().getDefiningOp<TensorLoadOp>()) 239 if (tensorLoad.memref().getType() == getType()) 240 return tensorLoad.memref(); 241 return {}; 242 } 243 244 namespace { 245 /// Replace tensor_cast + buffer_cast by buffer_cast + memref_cast. 
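/// An illustrative sketch of this rewrite (SSA names are hypothetical):
///
/// ```mlir
///   %t = tensor.cast %src : tensor<4xf32> to tensor<?xf32>
///   %m = memref.buffer_cast %t : memref<?xf32>
/// ```
///
/// becomes
///
/// ```mlir
///   %0 = memref.buffer_cast %src : memref<4xf32>
///   %m = memref.cast %0 : memref<4xf32> to memref<?xf32>
/// ```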
246 struct BufferCast : public OpRewritePattern<BufferCastOp> {
247   using OpRewritePattern<BufferCastOp>::OpRewritePattern;
248
249   LogicalResult matchAndRewrite(BufferCastOp bufferCast,
250                                 PatternRewriter &rewriter) const final {
251     auto tensorCastOperand =
252         bufferCast.getOperand().getDefiningOp<tensor::CastOp>();
253     if (!tensorCastOperand)
254       return failure();
255     auto srcTensorType =
256         tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
257     if (!srcTensorType)
258       return failure();
259     auto memrefType = MemRefType::get(srcTensorType.getShape(),
260                                       srcTensorType.getElementType());
261     Value memref = rewriter.create<BufferCastOp>(
262         bufferCast.getLoc(), memrefType, tensorCastOperand.getOperand());
263     rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(),
264                                         memref);
265     return success();
266   }
267 };
268
269 /// Canonicalize memref.tensor_load + memref.buffer_cast to memref.cast when
270 /// type mismatches prevent `BufferCastOp::fold` from kicking in.
271 struct TensorLoadToMemRef : public OpRewritePattern<BufferCastOp> {
272   using OpRewritePattern<BufferCastOp>::OpRewritePattern;
273
274   LogicalResult matchAndRewrite(BufferCastOp bufferCast,
275                                 PatternRewriter &rewriter) const final {
276     auto tensorLoad = bufferCast.tensor().getDefiningOp<TensorLoadOp>();
277     // Bail unless we have a tensor_load + memref.buffer_cast with different
278     // types. `BufferCastOp::fold` handles the same type case.
279     if (!tensorLoad || tensorLoad.memref().getType() == bufferCast.getType())
280       return failure();
281     // If types are not cast-compatible, bail.
282     if (!CastOp::areCastCompatible(tensorLoad.memref().getType(),
283                                    bufferCast.getType()))
284       return failure();
285     rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(),
286                                         tensorLoad.memref());
287     return success();
288   }
289 };
290
291 } // namespace
292
293 void BufferCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
294                                                MLIRContext *context) {
295   results.add<BufferCast, TensorLoadToMemRef>(context);
296 }
297
298 //===----------------------------------------------------------------------===//
299 // CastOp
300 //===----------------------------------------------------------------------===//
301
302 /// Determines whether MemRef_CastOp casts to a more dynamic version of the
303 /// source memref. This is useful to fold a memref.cast into a consuming op
304 /// and implement canonicalization patterns for ops in different dialects that
305 /// may consume the results of memref.cast operations. Such foldable memref.cast
306 /// operations are typically inserted as `view` and `subview` ops are
307 /// canonicalized, to preserve the type compatibility of their uses.
308 ///
309 /// Returns true when all conditions are met:
310 /// 1. source and result are ranked memrefs with strided semantics and same
311 ///    element type and rank.
312 /// 2. each of the source's size, offset or stride has more static information
313 ///    than the corresponding result's size, offset or stride.
314 ///
315 /// Example 1:
316 /// ```mlir
317 ///   %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
318 ///   %2 = consumer %1 ... : memref<?x?xf32> ...
319 /// ```
320 ///
321 /// may fold into:
322 ///
323 /// ```mlir
324 ///   %2 = consumer %0 ... : memref<8x16xf32> ...
325 /// ```
326 ///
327 /// Example 2:
328 /// ```
329 ///   %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
330 ///          to memref<?x?xf32>
331 ///   consumer %1 : memref<?x?xf32> ...
332 /// ``` 333 /// 334 /// may fold into: 335 /// 336 /// ``` 337 /// consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>> 338 /// ``` 339 bool CastOp::canFoldIntoConsumerOp(CastOp castOp) { 340 MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>(); 341 MemRefType resultType = castOp.getType().dyn_cast<MemRefType>(); 342 343 // Requires ranked MemRefType. 344 if (!sourceType || !resultType) 345 return false; 346 347 // Requires same elemental type. 348 if (sourceType.getElementType() != resultType.getElementType()) 349 return false; 350 351 // Requires same rank. 352 if (sourceType.getRank() != resultType.getRank()) 353 return false; 354 355 // Only fold casts between strided memref forms. 356 int64_t sourceOffset, resultOffset; 357 SmallVector<int64_t, 4> sourceStrides, resultStrides; 358 if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) || 359 failed(getStridesAndOffset(resultType, resultStrides, resultOffset))) 360 return false; 361 362 // If cast is towards more static sizes along any dimension, don't fold. 363 for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) { 364 auto ss = std::get<0>(it), st = std::get<1>(it); 365 if (ss != st) 366 if (MemRefType::isDynamic(ss) && !MemRefType::isDynamic(st)) 367 return false; 368 } 369 370 // If cast is towards more static offset along any dimension, don't fold. 371 if (sourceOffset != resultOffset) 372 if (MemRefType::isDynamicStrideOrOffset(sourceOffset) && 373 !MemRefType::isDynamicStrideOrOffset(resultOffset)) 374 return false; 375 376 // If cast is towards more static strides along any dimension, don't fold. 377 for (auto it : llvm::zip(sourceStrides, resultStrides)) { 378 auto ss = std::get<0>(it), st = std::get<1>(it); 379 if (ss != st) 380 if (MemRefType::isDynamicStrideOrOffset(ss) && 381 !MemRefType::isDynamicStrideOrOffset(st)) 382 return false; 383 } 384 385 return true; 386 } 387 388 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 389 if (inputs.size() != 1 || outputs.size() != 1) 390 return false; 391 Type a = inputs.front(), b = outputs.front(); 392 auto aT = a.dyn_cast<MemRefType>(); 393 auto bT = b.dyn_cast<MemRefType>(); 394 395 auto uaT = a.dyn_cast<UnrankedMemRefType>(); 396 auto ubT = b.dyn_cast<UnrankedMemRefType>(); 397 398 if (aT && bT) { 399 if (aT.getElementType() != bT.getElementType()) 400 return false; 401 if (aT.getAffineMaps() != bT.getAffineMaps()) { 402 int64_t aOffset, bOffset; 403 SmallVector<int64_t, 4> aStrides, bStrides; 404 if (failed(getStridesAndOffset(aT, aStrides, aOffset)) || 405 failed(getStridesAndOffset(bT, bStrides, bOffset)) || 406 aStrides.size() != bStrides.size()) 407 return false; 408 409 // Strides along a dimension/offset are compatible if the value in the 410 // source memref is static and the value in the target memref is the 411 // same. They are also compatible if either one is dynamic (see 412 // description of MemRefCastOp for details). 413 auto checkCompatible = [](int64_t a, int64_t b) { 414 return (a == MemRefType::getDynamicStrideOrOffset() || 415 b == MemRefType::getDynamicStrideOrOffset() || a == b); 416 }; 417 if (!checkCompatible(aOffset, bOffset)) 418 return false; 419 for (auto aStride : enumerate(aStrides)) 420 if (!checkCompatible(aStride.value(), bStrides[aStride.index()])) 421 return false; 422 } 423 if (aT.getMemorySpaceAsInt() != bT.getMemorySpaceAsInt()) 424 return false; 425 426 // They must have the same rank, and any specified dimensions must match. 
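    // For example, shape-wise memref<4x?xf32> and memref<?x8xf32> are
    // compatible, while memref<4xf32> and memref<8xf32> are not: both sizes
    // are static and differ.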
427 if (aT.getRank() != bT.getRank()) 428 return false; 429 430 for (unsigned i = 0, e = aT.getRank(); i != e; ++i) { 431 int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i); 432 if (aDim != -1 && bDim != -1 && aDim != bDim) 433 return false; 434 } 435 return true; 436 } else { 437 if (!aT && !uaT) 438 return false; 439 if (!bT && !ubT) 440 return false; 441 // Unranked to unranked casting is unsupported 442 if (uaT && ubT) 443 return false; 444 445 auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType(); 446 auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType(); 447 if (aEltType != bEltType) 448 return false; 449 450 auto aMemSpace = 451 (aT) ? aT.getMemorySpaceAsInt() : uaT.getMemorySpaceAsInt(); 452 auto bMemSpace = 453 (bT) ? bT.getMemorySpaceAsInt() : ubT.getMemorySpaceAsInt(); 454 if (aMemSpace != bMemSpace) 455 return false; 456 457 return true; 458 } 459 460 return false; 461 } 462 463 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) { 464 return succeeded(foldMemRefCast(*this)) ? getResult() : Value(); 465 } 466 467 //===----------------------------------------------------------------------===// 468 // DeallocOp 469 //===----------------------------------------------------------------------===// 470 namespace { 471 /// Fold Dealloc operations that are deallocating an AllocOp that is only used 472 /// by other Dealloc operations. 473 struct SimplifyDeadDealloc : public OpRewritePattern<DeallocOp> { 474 using OpRewritePattern<DeallocOp>::OpRewritePattern; 475 476 LogicalResult matchAndRewrite(DeallocOp dealloc, 477 PatternRewriter &rewriter) const override { 478 // Check that the memref operand's defining operation is an AllocOp. 479 Value memref = dealloc.memref(); 480 if (!isa_and_nonnull<AllocOp>(memref.getDefiningOp())) 481 return failure(); 482 483 // Check that all of the uses of the AllocOp are other DeallocOps. 484 for (auto *user : memref.getUsers()) 485 if (!isa<DeallocOp>(user)) 486 return failure(); 487 488 // Erase the dealloc operation. 489 rewriter.eraseOp(dealloc); 490 return success(); 491 } 492 }; 493 } // end anonymous namespace. 
494 495 static LogicalResult verify(DeallocOp op) { 496 if (!op.memref().getType().isa<MemRefType>()) 497 return op.emitOpError("operand must be a memref"); 498 return success(); 499 } 500 501 void DeallocOp::getCanonicalizationPatterns(RewritePatternSet &results, 502 MLIRContext *context) { 503 results.add<SimplifyDeadDealloc>(context); 504 } 505 506 LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands, 507 SmallVectorImpl<OpFoldResult> &results) { 508 /// dealloc(memrefcast) -> dealloc 509 return foldMemRefCast(*this); 510 } 511 512 //===----------------------------------------------------------------------===// 513 // DimOp 514 //===----------------------------------------------------------------------===// 515 516 void DimOp::build(OpBuilder &builder, OperationState &result, Value memref, 517 int64_t index) { 518 auto loc = result.location; 519 Value indexValue = builder.create<ConstantIndexOp>(loc, index); 520 build(builder, result, memref, indexValue); 521 } 522 523 void DimOp::build(OpBuilder &builder, OperationState &result, Value memref, 524 Value index) { 525 auto indexTy = builder.getIndexType(); 526 build(builder, result, indexTy, memref, index); 527 } 528 529 Optional<int64_t> DimOp::getConstantIndex() { 530 if (auto constantOp = index().getDefiningOp<ConstantOp>()) 531 return constantOp.getValue().cast<IntegerAttr>().getInt(); 532 return {}; 533 } 534 535 static LogicalResult verify(DimOp op) { 536 // Assume unknown index to be in range. 537 Optional<int64_t> index = op.getConstantIndex(); 538 if (!index.hasValue()) 539 return success(); 540 541 // Check that constant index is not knowingly out of range. 542 auto type = op.memrefOrTensor().getType(); 543 if (auto memrefType = type.dyn_cast<MemRefType>()) { 544 if (index.getValue() >= memrefType.getRank()) 545 return op.emitOpError("index is out of range"); 546 } else if (auto tensorType = type.dyn_cast<RankedTensorType>()) { 547 if (index.getValue() >= tensorType.getRank()) 548 return op.emitOpError("index is out of range"); 549 } else if (type.isa<UnrankedMemRefType>() || type.isa<UnrankedTensorType>()) { 550 // Assume index to be in range. 551 } else { 552 llvm_unreachable("expected operand with memref type"); 553 } 554 return success(); 555 } 556 557 OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) { 558 auto index = operands[1].dyn_cast_or_null<IntegerAttr>(); 559 560 // All forms of folding require a known index. 561 if (!index) 562 return {}; 563 564 auto argTy = memrefOrTensor().getType(); 565 // Fold if the shape extent along the given index is known. 566 if (auto shapedTy = argTy.dyn_cast<ShapedType>()) { 567 // Folding for unranked types (UnrankedMemRefType) is not supported. 568 if (!shapedTy.hasRank()) 569 return {}; 570 if (!shapedTy.isDynamicDim(index.getInt())) { 571 Builder builder(getContext()); 572 return builder.getIndexAttr(shapedTy.getShape()[index.getInt()]); 573 } 574 } 575 576 Operation *definingOp = memrefOrTensor().getDefiningOp(); 577 578 // dim(memref.tensor_load(memref)) -> dim(memref) 579 if (auto tensorLoadOp = dyn_cast_or_null<TensorLoadOp>(definingOp)) { 580 setOperand(0, tensorLoadOp.memref()); 581 return getResult(); 582 } 583 584 // Fold dim to the operand of tensor.generate. 585 if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) { 586 auto resultType = 587 fromElements.getResult().getType().cast<RankedTensorType>(); 588 // The case where the type encodes the size of the dimension is handled 589 // above. 
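    // For example (a hypothetical case): for a `tensor.generate %n, %m` of
    // type tensor<?x4x?xf32>, `dim` at constant index 2 folds to %m, the
    // second dynamic extent.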
590 assert(resultType.getShape()[index.getInt()] == 591 RankedTensorType::kDynamicSize); 592 593 // Find the operand of the fromElements that corresponds to this index. 594 auto dynExtents = fromElements.dynamicExtents().begin(); 595 for (auto dim : resultType.getShape().take_front(index.getInt())) 596 if (dim == RankedTensorType::kDynamicSize) 597 dynExtents++; 598 599 return Value{*dynExtents}; 600 } 601 602 // The size at the given index is now known to be a dynamic size. 603 unsigned unsignedIndex = index.getValue().getZExtValue(); 604 605 if (auto subtensor = dyn_cast_or_null<mlir::SubTensorOp>(definingOp)) { 606 assert(subtensor.isDynamicSize(unsignedIndex) && 607 "Expected dynamic subtensor size"); 608 return subtensor.getDynamicSize(unsignedIndex); 609 } 610 611 // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`. 612 auto memrefType = argTy.dyn_cast<MemRefType>(); 613 if (!memrefType) 614 return {}; 615 616 if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp)) 617 return *(alloc.getDynamicSizes().begin() + 618 memrefType.getDynamicDimIndex(unsignedIndex)); 619 620 if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp)) 621 return *(alloca.getDynamicSizes().begin() + 622 memrefType.getDynamicDimIndex(unsignedIndex)); 623 624 if (auto view = dyn_cast_or_null<ViewOp>(definingOp)) 625 return *(view.getDynamicSizes().begin() + 626 memrefType.getDynamicDimIndex(unsignedIndex)); 627 628 if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) { 629 assert(subview.isDynamicSize(unsignedIndex) && 630 "Expected dynamic subview size"); 631 return subview.getDynamicSize(unsignedIndex); 632 } 633 634 // dim(memrefcast) -> dim 635 if (succeeded(foldMemRefCast(*this))) 636 return getResult(); 637 638 return {}; 639 } 640 641 namespace { 642 /// Fold dim of a memref reshape operation to a load into the reshape's shape 643 /// operand. 644 struct DimOfMemRefReshape : public OpRewritePattern<DimOp> { 645 using OpRewritePattern<DimOp>::OpRewritePattern; 646 647 LogicalResult matchAndRewrite(DimOp dim, 648 PatternRewriter &rewriter) const override { 649 auto reshape = dim.memrefOrTensor().getDefiningOp<ReshapeOp>(); 650 651 if (!reshape) 652 return failure(); 653 654 // Place the load directly after the reshape to ensure that the shape memref 655 // was not mutated. 656 rewriter.setInsertionPointAfter(reshape); 657 rewriter.replaceOpWithNewOp<LoadOp>(dim, reshape.shape(), 658 llvm::makeArrayRef({dim.index()})); 659 return success(); 660 } 661 }; 662 663 /// Fold dim of a dim of a cast into the dim of the source of the tensor cast. 664 template <typename CastOpTy> 665 struct DimOfCastOp : public OpRewritePattern<DimOp> { 666 using OpRewritePattern<DimOp>::OpRewritePattern; 667 668 LogicalResult matchAndRewrite(DimOp dimOp, 669 PatternRewriter &rewriter) const override { 670 auto castOp = dimOp.memrefOrTensor().getDefiningOp<CastOpTy>(); 671 if (!castOp) 672 return failure(); 673 Value newSource = castOp.getOperand(); 674 rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.index()); 675 return success(); 676 } 677 }; 678 } // end anonymous namespace. 
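/// An illustrative sketch of the `DimOfMemRefReshape` pattern registered
/// below (SSA names are hypothetical):
///
/// ```mlir
///   %r = memref.reshape %src(%shape)
///            : (memref<*xf32>, memref<2xindex>) -> memref<?x?xf32>
///   %d = memref.dim %r, %c1 : memref<?x?xf32>
/// ```
///
/// is rewritten to a load of the requested extent from the shape operand:
///
/// ```mlir
///   %d = memref.load %shape[%c1] : memref<2xindex>
/// ```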
679 680 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results, 681 MLIRContext *context) { 682 results.add<DimOfMemRefReshape, DimOfCastOp<BufferCastOp>, 683 DimOfCastOp<tensor::CastOp>>(context); 684 } 685 686 // --------------------------------------------------------------------------- 687 // DmaStartOp 688 // --------------------------------------------------------------------------- 689 690 void DmaStartOp::build(OpBuilder &builder, OperationState &result, 691 Value srcMemRef, ValueRange srcIndices, Value destMemRef, 692 ValueRange destIndices, Value numElements, 693 Value tagMemRef, ValueRange tagIndices, Value stride, 694 Value elementsPerStride) { 695 result.addOperands(srcMemRef); 696 result.addOperands(srcIndices); 697 result.addOperands(destMemRef); 698 result.addOperands(destIndices); 699 result.addOperands({numElements, tagMemRef}); 700 result.addOperands(tagIndices); 701 if (stride) 702 result.addOperands({stride, elementsPerStride}); 703 } 704 705 void DmaStartOp::print(OpAsmPrinter &p) { 706 p << getOperationName() << " " << getSrcMemRef() << '[' << getSrcIndices() 707 << "], " << getDstMemRef() << '[' << getDstIndices() << "], " 708 << getNumElements() << ", " << getTagMemRef() << '[' << getTagIndices() 709 << ']'; 710 if (isStrided()) 711 p << ", " << getStride() << ", " << getNumElementsPerStride(); 712 713 p.printOptionalAttrDict((*this)->getAttrs()); 714 p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType() 715 << ", " << getTagMemRef().getType(); 716 } 717 718 // Parse DmaStartOp. 719 // Ex: 720 // %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size, 721 // %tag[%index], %stride, %num_elt_per_stride : 722 // : memref<3076 x f32, 0>, 723 // memref<1024 x f32, 2>, 724 // memref<1 x i32> 725 // 726 ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) { 727 OpAsmParser::OperandType srcMemRefInfo; 728 SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos; 729 OpAsmParser::OperandType dstMemRefInfo; 730 SmallVector<OpAsmParser::OperandType, 4> dstIndexInfos; 731 OpAsmParser::OperandType numElementsInfo; 732 OpAsmParser::OperandType tagMemrefInfo; 733 SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos; 734 SmallVector<OpAsmParser::OperandType, 2> strideInfo; 735 736 SmallVector<Type, 3> types; 737 auto indexType = parser.getBuilder().getIndexType(); 738 739 // Parse and resolve the following list of operands: 740 // *) source memref followed by its indices (in square brackets). 741 // *) destination memref followed by its indices (in square brackets). 742 // *) dma size in KiB. 743 if (parser.parseOperand(srcMemRefInfo) || 744 parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) || 745 parser.parseComma() || parser.parseOperand(dstMemRefInfo) || 746 parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) || 747 parser.parseComma() || parser.parseOperand(numElementsInfo) || 748 parser.parseComma() || parser.parseOperand(tagMemrefInfo) || 749 parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square)) 750 return failure(); 751 752 // Parse optional stride and elements per stride. 
753 if (parser.parseTrailingOperandList(strideInfo)) 754 return failure(); 755 756 bool isStrided = strideInfo.size() == 2; 757 if (!strideInfo.empty() && !isStrided) { 758 return parser.emitError(parser.getNameLoc(), 759 "expected two stride related operands"); 760 } 761 762 if (parser.parseColonTypeList(types)) 763 return failure(); 764 if (types.size() != 3) 765 return parser.emitError(parser.getNameLoc(), "fewer/more types expected"); 766 767 if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) || 768 parser.resolveOperands(srcIndexInfos, indexType, result.operands) || 769 parser.resolveOperand(dstMemRefInfo, types[1], result.operands) || 770 parser.resolveOperands(dstIndexInfos, indexType, result.operands) || 771 // size should be an index. 772 parser.resolveOperand(numElementsInfo, indexType, result.operands) || 773 parser.resolveOperand(tagMemrefInfo, types[2], result.operands) || 774 // tag indices should be index. 775 parser.resolveOperands(tagIndexInfos, indexType, result.operands)) 776 return failure(); 777 778 if (isStrided) { 779 if (parser.resolveOperands(strideInfo, indexType, result.operands)) 780 return failure(); 781 } 782 783 return success(); 784 } 785 786 LogicalResult DmaStartOp::verify() { 787 unsigned numOperands = getNumOperands(); 788 789 // Mandatory non-variadic operands are: src memref, dst memref, tag memref and 790 // the number of elements. 791 if (numOperands < 4) 792 return emitOpError("expected at least 4 operands"); 793 794 // Check types of operands. The order of these calls is important: the later 795 // calls rely on some type properties to compute the operand position. 796 // 1. Source memref. 797 if (!getSrcMemRef().getType().isa<MemRefType>()) 798 return emitOpError("expected source to be of memref type"); 799 if (numOperands < getSrcMemRefRank() + 4) 800 return emitOpError() << "expected at least " << getSrcMemRefRank() + 4 801 << " operands"; 802 if (!getSrcIndices().empty() && 803 !llvm::all_of(getSrcIndices().getTypes(), 804 [](Type t) { return t.isIndex(); })) 805 return emitOpError("expected source indices to be of index type"); 806 807 // 2. Destination memref. 808 if (!getDstMemRef().getType().isa<MemRefType>()) 809 return emitOpError("expected destination to be of memref type"); 810 unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4; 811 if (numOperands < numExpectedOperands) 812 return emitOpError() << "expected at least " << numExpectedOperands 813 << " operands"; 814 if (!getDstIndices().empty() && 815 !llvm::all_of(getDstIndices().getTypes(), 816 [](Type t) { return t.isIndex(); })) 817 return emitOpError("expected destination indices to be of index type"); 818 819 // 3. Number of elements. 820 if (!getNumElements().getType().isIndex()) 821 return emitOpError("expected num elements to be of index type"); 822 823 // 4. Tag memref. 824 if (!getTagMemRef().getType().isa<MemRefType>()) 825 return emitOpError("expected tag to be of memref type"); 826 numExpectedOperands += getTagMemRefRank(); 827 if (numOperands < numExpectedOperands) 828 return emitOpError() << "expected at least " << numExpectedOperands 829 << " operands"; 830 if (!getTagIndices().empty() && 831 !llvm::all_of(getTagIndices().getTypes(), 832 [](Type t) { return t.isIndex(); })) 833 return emitOpError("expected tag indices to be of index type"); 834 835 // DMAs from different memory spaces supported. 
836 if (getSrcMemorySpace() == getDstMemorySpace()) 837 return emitOpError("DMA should be between different memory spaces"); 838 839 // Optional stride-related operands must be either both present or both 840 // absent. 841 if (numOperands != numExpectedOperands && 842 numOperands != numExpectedOperands + 2) 843 return emitOpError("incorrect number of operands"); 844 845 // 5. Strides. 846 if (isStrided()) { 847 if (!getStride().getType().isIndex() || 848 !getNumElementsPerStride().getType().isIndex()) 849 return emitOpError( 850 "expected stride and num elements per stride to be of type index"); 851 } 852 853 return success(); 854 } 855 856 LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands, 857 SmallVectorImpl<OpFoldResult> &results) { 858 /// dma_start(memrefcast) -> dma_start 859 return foldMemRefCast(*this); 860 } 861 862 // --------------------------------------------------------------------------- 863 // DmaWaitOp 864 // --------------------------------------------------------------------------- 865 866 void DmaWaitOp::build(OpBuilder &builder, OperationState &result, 867 Value tagMemRef, ValueRange tagIndices, 868 Value numElements) { 869 result.addOperands(tagMemRef); 870 result.addOperands(tagIndices); 871 result.addOperands(numElements); 872 } 873 874 void DmaWaitOp::print(OpAsmPrinter &p) { 875 p << getOperationName() << " " << getTagMemRef() << '[' << getTagIndices() 876 << "], " << getNumElements(); 877 p.printOptionalAttrDict((*this)->getAttrs()); 878 p << " : " << getTagMemRef().getType(); 879 } 880 881 // Parse DmaWaitOp. 882 // Eg: 883 // dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 4> 884 // 885 ParseResult DmaWaitOp::parse(OpAsmParser &parser, OperationState &result) { 886 OpAsmParser::OperandType tagMemrefInfo; 887 SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos; 888 Type type; 889 auto indexType = parser.getBuilder().getIndexType(); 890 OpAsmParser::OperandType numElementsInfo; 891 892 // Parse tag memref, its indices, and dma size. 893 if (parser.parseOperand(tagMemrefInfo) || 894 parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square) || 895 parser.parseComma() || parser.parseOperand(numElementsInfo) || 896 parser.parseColonType(type) || 897 parser.resolveOperand(tagMemrefInfo, type, result.operands) || 898 parser.resolveOperands(tagIndexInfos, indexType, result.operands) || 899 parser.resolveOperand(numElementsInfo, indexType, result.operands)) 900 return failure(); 901 902 return success(); 903 } 904 905 LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands, 906 SmallVectorImpl<OpFoldResult> &results) { 907 /// dma_wait(memrefcast) -> dma_wait 908 return foldMemRefCast(*this); 909 } 910 911 LogicalResult DmaWaitOp::verify() { 912 // Mandatory non-variadic operands are tag and the number of elements. 913 if (getNumOperands() < 2) 914 return emitOpError() << "expected at least 2 operands"; 915 916 // Check types of operands. The order of these calls is important: the later 917 // calls rely on some type properties to compute the operand position. 
918 if (!getTagMemRef().getType().isa<MemRefType>()) 919 return emitOpError() << "expected tag to be of memref type"; 920 921 if (getNumOperands() != 2 + getTagMemRefRank()) 922 return emitOpError() << "expected " << 2 + getTagMemRefRank() 923 << " operands"; 924 925 if (!getTagIndices().empty() && 926 !llvm::all_of(getTagIndices().getTypes(), 927 [](Type t) { return t.isIndex(); })) 928 return emitOpError() << "expected tag indices to be of index type"; 929 930 if (!getNumElements().getType().isIndex()) 931 return emitOpError() 932 << "expected the number of elements to be of index type"; 933 934 return success(); 935 } 936 937 //===----------------------------------------------------------------------===// 938 // GlobalOp 939 //===----------------------------------------------------------------------===// 940 941 static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op, 942 TypeAttr type, 943 Attribute initialValue) { 944 p << type; 945 if (!op.isExternal()) { 946 p << " = "; 947 if (op.isUninitialized()) 948 p << "uninitialized"; 949 else 950 p.printAttributeWithoutType(initialValue); 951 } 952 } 953 954 static ParseResult 955 parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, 956 Attribute &initialValue) { 957 Type type; 958 if (parser.parseType(type)) 959 return failure(); 960 961 auto memrefType = type.dyn_cast<MemRefType>(); 962 if (!memrefType || !memrefType.hasStaticShape()) 963 return parser.emitError(parser.getNameLoc()) 964 << "type should be static shaped memref, but got " << type; 965 typeAttr = TypeAttr::get(type); 966 967 if (parser.parseOptionalEqual()) 968 return success(); 969 970 if (succeeded(parser.parseOptionalKeyword("uninitialized"))) { 971 initialValue = UnitAttr::get(parser.getBuilder().getContext()); 972 return success(); 973 } 974 975 Type tensorType = getTensorTypeFromMemRefType(memrefType); 976 if (parser.parseAttribute(initialValue, tensorType)) 977 return failure(); 978 if (!initialValue.isa<ElementsAttr>()) 979 return parser.emitError(parser.getNameLoc()) 980 << "initial value should be a unit or elements attribute"; 981 return success(); 982 } 983 984 static LogicalResult verify(GlobalOp op) { 985 auto memrefType = op.type().dyn_cast<MemRefType>(); 986 if (!memrefType || !memrefType.hasStaticShape()) 987 return op.emitOpError("type should be static shaped memref, but got ") 988 << op.type(); 989 990 // Verify that the initial value, if present, is either a unit attribute or 991 // an elements attribute. 992 if (op.initial_value().hasValue()) { 993 Attribute initValue = op.initial_value().getValue(); 994 if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>()) 995 return op.emitOpError("initial value should be a unit or elements " 996 "attribute, but got ") 997 << initValue; 998 999 // Check that the type of the initial value is compatible with the type of 1000 // the global variable. 1001 if (initValue.isa<ElementsAttr>()) { 1002 Type initType = initValue.getType(); 1003 Type tensorType = getTensorTypeFromMemRefType(memrefType); 1004 if (initType != tensorType) 1005 return op.emitOpError("initial value expected to be of type ") 1006 << tensorType << ", but was of type " << initType; 1007 } 1008 } 1009 1010 // TODO: verify visibility for declarations. 
1011   return success();
1012 }
1013
1014 //===----------------------------------------------------------------------===//
1015 // GetGlobalOp
1016 //===----------------------------------------------------------------------===//
1017
1018 LogicalResult
1019 GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1020   // Verify that the result type is the same as the type of the referenced
1021   // memref.global op.
1022   auto global =
1023       symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, nameAttr());
1024   if (!global)
1025     return emitOpError("'")
1026            << name() << "' does not reference a valid global memref";
1027
1028   Type resultType = result().getType();
1029   if (global.type() != resultType)
1030     return emitOpError("result type ")
1031            << resultType << " does not match type " << global.type()
1032            << " of the global memref @" << name();
1033   return success();
1034 }
1035
1036 //===----------------------------------------------------------------------===//
1037 // LoadOp
1038 //===----------------------------------------------------------------------===//
1039
1040 static LogicalResult verify(LoadOp op) {
1041   if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
1042     return op.emitOpError("incorrect number of indices for load");
1043   return success();
1044 }
1045
1046 OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
1047   /// load(memrefcast) -> load
1048   if (succeeded(foldMemRefCast(*this)))
1049     return getResult();
1050   return OpFoldResult();
1051 }
1052
1053 namespace {
1054 /// Fold a load on a buffer_cast operation into a tensor.extract on the
1055 /// corresponding tensor.
1056 struct LoadOfBufferCast : public OpRewritePattern<LoadOp> {
1057   using OpRewritePattern<LoadOp>::OpRewritePattern;
1058
1059   LogicalResult matchAndRewrite(LoadOp load,
1060                                 PatternRewriter &rewriter) const override {
1061     auto buffercast = load.memref().getDefiningOp<BufferCastOp>();
1062     if (!buffercast)
1063       return failure();
1064
1065     rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, buffercast.tensor(),
1066                                                    load.indices());
1067     return success();
1068   }
1069 };
1070 } // end anonymous namespace.
1071
1072 void LoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
1073                                          MLIRContext *context) {
1074   results.add<LoadOfBufferCast>(context);
1075 }
1076
1077 //===----------------------------------------------------------------------===//
1078 // PrefetchOp
1079 //===----------------------------------------------------------------------===//
1080
1081 static void print(OpAsmPrinter &p, PrefetchOp op) {
1082   p << PrefetchOp::getOperationName() << " " << op.memref() << '[';
1083   p.printOperands(op.indices());
1084   p << ']' << ", " << (op.isWrite() ? "write" : "read");
1085   p << ", locality<" << op.localityHint();
1086   p << ">, " << (op.isDataCache() ?
"data" : "instr"); 1087 p.printOptionalAttrDict( 1088 op->getAttrs(), 1089 /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"}); 1090 p << " : " << op.getMemRefType(); 1091 } 1092 1093 static ParseResult parsePrefetchOp(OpAsmParser &parser, 1094 OperationState &result) { 1095 OpAsmParser::OperandType memrefInfo; 1096 SmallVector<OpAsmParser::OperandType, 4> indexInfo; 1097 IntegerAttr localityHint; 1098 MemRefType type; 1099 StringRef readOrWrite, cacheType; 1100 1101 auto indexTy = parser.getBuilder().getIndexType(); 1102 auto i32Type = parser.getBuilder().getIntegerType(32); 1103 if (parser.parseOperand(memrefInfo) || 1104 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || 1105 parser.parseComma() || parser.parseKeyword(&readOrWrite) || 1106 parser.parseComma() || parser.parseKeyword("locality") || 1107 parser.parseLess() || 1108 parser.parseAttribute(localityHint, i32Type, "localityHint", 1109 result.attributes) || 1110 parser.parseGreater() || parser.parseComma() || 1111 parser.parseKeyword(&cacheType) || parser.parseColonType(type) || 1112 parser.resolveOperand(memrefInfo, type, result.operands) || 1113 parser.resolveOperands(indexInfo, indexTy, result.operands)) 1114 return failure(); 1115 1116 if (!readOrWrite.equals("read") && !readOrWrite.equals("write")) 1117 return parser.emitError(parser.getNameLoc(), 1118 "rw specifier has to be 'read' or 'write'"); 1119 result.addAttribute( 1120 PrefetchOp::getIsWriteAttrName(), 1121 parser.getBuilder().getBoolAttr(readOrWrite.equals("write"))); 1122 1123 if (!cacheType.equals("data") && !cacheType.equals("instr")) 1124 return parser.emitError(parser.getNameLoc(), 1125 "cache type has to be 'data' or 'instr'"); 1126 1127 result.addAttribute( 1128 PrefetchOp::getIsDataCacheAttrName(), 1129 parser.getBuilder().getBoolAttr(cacheType.equals("data"))); 1130 1131 return success(); 1132 } 1133 1134 static LogicalResult verify(PrefetchOp op) { 1135 if (op.getNumOperands() != 1 + op.getMemRefType().getRank()) 1136 return op.emitOpError("too few indices"); 1137 1138 return success(); 1139 } 1140 1141 LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands, 1142 SmallVectorImpl<OpFoldResult> &results) { 1143 // prefetch(memrefcast) -> prefetch 1144 return foldMemRefCast(*this); 1145 } 1146 1147 //===----------------------------------------------------------------------===// 1148 // ReinterpretCastOp 1149 //===----------------------------------------------------------------------===// 1150 1151 /// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`, 1152 /// `staticSizes` and `staticStrides` are automatically filled with 1153 /// source-memref-rank sentinel values that encode dynamic entries. 
1154 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, 1155 MemRefType resultType, Value source, 1156 OpFoldResult offset, ArrayRef<OpFoldResult> sizes, 1157 ArrayRef<OpFoldResult> strides, 1158 ArrayRef<NamedAttribute> attrs) { 1159 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 1160 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 1161 dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets, 1162 ShapedType::kDynamicStrideOrOffset); 1163 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, 1164 ShapedType::kDynamicSize); 1165 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, 1166 ShapedType::kDynamicStrideOrOffset); 1167 build(b, result, resultType, source, dynamicOffsets, dynamicSizes, 1168 dynamicStrides, b.getI64ArrayAttr(staticOffsets), 1169 b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides)); 1170 result.addAttributes(attrs); 1171 } 1172 1173 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, 1174 MemRefType resultType, Value source, 1175 int64_t offset, ArrayRef<int64_t> sizes, 1176 ArrayRef<int64_t> strides, 1177 ArrayRef<NamedAttribute> attrs) { 1178 SmallVector<OpFoldResult> sizeValues = 1179 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult { 1180 return b.getI64IntegerAttr(v); 1181 })); 1182 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1183 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult { 1184 return b.getI64IntegerAttr(v); 1185 })); 1186 build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues, 1187 strideValues, attrs); 1188 } 1189 1190 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, 1191 MemRefType resultType, Value source, Value offset, 1192 ValueRange sizes, ValueRange strides, 1193 ArrayRef<NamedAttribute> attrs) { 1194 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>( 1195 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; })); 1196 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 1197 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; })); 1198 build(b, result, resultType, source, offset, sizeValues, strideValues, attrs); 1199 } 1200 1201 // TODO: ponder whether we want to allow missing trailing sizes/strides that are 1202 // completed automatically, like we have for subview and subtensor. 1203 static LogicalResult verify(ReinterpretCastOp op) { 1204 // The source and result memrefs should be in the same memory space. 1205 auto srcType = op.source().getType().cast<BaseMemRefType>(); 1206 auto resultType = op.getType().cast<MemRefType>(); 1207 if (srcType.getMemorySpaceAsInt() != resultType.getMemorySpaceAsInt()) 1208 return op.emitError("different memory spaces specified for source type ") 1209 << srcType << " and result memref type " << resultType; 1210 if (srcType.getElementType() != resultType.getElementType()) 1211 return op.emitError("different element types specified for source type ") 1212 << srcType << " and result memref type " << resultType; 1213 1214 // Match sizes in result memref type and in static_sizes attribute. 
1215 for (auto &en : 1216 llvm::enumerate(llvm::zip(resultType.getShape(), 1217 extractFromI64ArrayAttr(op.static_sizes())))) { 1218 int64_t resultSize = std::get<0>(en.value()); 1219 int64_t expectedSize = std::get<1>(en.value()); 1220 if (resultSize != expectedSize) 1221 return op.emitError("expected result type with size = ") 1222 << expectedSize << " instead of " << resultSize 1223 << " in dim = " << en.index(); 1224 } 1225 1226 // Match offset and strides in static_offset and static_strides attributes if 1227 // result memref type has an affine map specified. 1228 if (!resultType.getAffineMaps().empty()) { 1229 int64_t resultOffset; 1230 SmallVector<int64_t, 4> resultStrides; 1231 if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset))) 1232 return failure(); 1233 1234 // Match offset in result memref type and in static_offsets attribute. 1235 int64_t expectedOffset = 1236 extractFromI64ArrayAttr(op.static_offsets()).front(); 1237 if (resultOffset != expectedOffset) 1238 return op.emitError("expected result type with offset = ") 1239 << resultOffset << " instead of " << expectedOffset; 1240 1241 // Match strides in result memref type and in static_strides attribute. 1242 for (auto &en : llvm::enumerate(llvm::zip( 1243 resultStrides, extractFromI64ArrayAttr(op.static_strides())))) { 1244 int64_t resultStride = std::get<0>(en.value()); 1245 int64_t expectedStride = std::get<1>(en.value()); 1246 if (resultStride != expectedStride) 1247 return op.emitError("expected result type with stride = ") 1248 << expectedStride << " instead of " << resultStride 1249 << " in dim = " << en.index(); 1250 } 1251 } 1252 return success(); 1253 } 1254 1255 //===----------------------------------------------------------------------===// 1256 // ReshapeOp 1257 //===----------------------------------------------------------------------===// 1258 1259 static LogicalResult verify(ReshapeOp op) { 1260 Type operandType = op.source().getType(); 1261 Type resultType = op.result().getType(); 1262 1263 Type operandElementType = operandType.cast<ShapedType>().getElementType(); 1264 Type resultElementType = resultType.cast<ShapedType>().getElementType(); 1265 if (operandElementType != resultElementType) 1266 return op.emitOpError("element types of source and destination memref " 1267 "types should be the same"); 1268 1269 if (auto operandMemRefType = operandType.dyn_cast<MemRefType>()) 1270 if (!operandMemRefType.getAffineMaps().empty()) 1271 return op.emitOpError( 1272 "source memref type should have identity affine map"); 1273 1274 int64_t shapeSize = op.shape().getType().cast<MemRefType>().getDimSize(0); 1275 auto resultMemRefType = resultType.dyn_cast<MemRefType>(); 1276 if (resultMemRefType) { 1277 if (!resultMemRefType.getAffineMaps().empty()) 1278 return op.emitOpError( 1279 "result memref type should have identity affine map"); 1280 if (shapeSize == ShapedType::kDynamicSize) 1281 return op.emitOpError("cannot use shape operand with dynamic length to " 1282 "reshape to statically-ranked memref type"); 1283 if (shapeSize != resultMemRefType.getRank()) 1284 return op.emitOpError( 1285 "length of shape operand differs from the result's memref rank"); 1286 } 1287 return success(); 1288 } 1289 1290 //===----------------------------------------------------------------------===// 1291 // StoreOp 1292 //===----------------------------------------------------------------------===// 1293 1294 static LogicalResult verify(StoreOp op) { 1295 if (op.getNumOperands() != 2 + op.getMemRefType().getRank()) 1296 
return op.emitOpError("store index operand count not equal to memref rank"); 1297 1298 return success(); 1299 } 1300 1301 LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands, 1302 SmallVectorImpl<OpFoldResult> &results) { 1303 /// store(memrefcast) -> store 1304 return foldMemRefCast(*this); 1305 } 1306 1307 //===----------------------------------------------------------------------===// 1308 // SubViewOp 1309 //===----------------------------------------------------------------------===// 1310 1311 namespace { 1312 /// Helpers to write more idiomatic operations. 1313 namespace saturated_arith { 1314 struct Wrapper { 1315 explicit Wrapper(int64_t v) : v(v) {} 1316 operator int64_t() { return v; } 1317 int64_t v; 1318 }; 1319 Wrapper operator+(Wrapper a, int64_t b) { 1320 if (ShapedType::isDynamicStrideOrOffset(a) || 1321 ShapedType::isDynamicStrideOrOffset(b)) 1322 return Wrapper(ShapedType::kDynamicStrideOrOffset); 1323 return Wrapper(a.v + b); 1324 } 1325 Wrapper operator*(Wrapper a, int64_t b) { 1326 if (ShapedType::isDynamicStrideOrOffset(a) || 1327 ShapedType::isDynamicStrideOrOffset(b)) 1328 return Wrapper(ShapedType::kDynamicStrideOrOffset); 1329 return Wrapper(a.v * b); 1330 } 1331 } // end namespace saturated_arith 1332 } // end namespace 1333 1334 /// A subview result type can be fully inferred from the source type and the 1335 /// static representation of offsets, sizes and strides. Special sentinels 1336 /// encode the dynamic case. 1337 Type SubViewOp::inferResultType(MemRefType sourceMemRefType, 1338 ArrayRef<int64_t> leadingStaticOffsets, 1339 ArrayRef<int64_t> leadingStaticSizes, 1340 ArrayRef<int64_t> leadingStaticStrides) { 1341 // A subview may specify only a leading subset of offset/sizes/strides in 1342 // which case we complete with offset=0, sizes from memref type and strides=1. 1343 unsigned rank = sourceMemRefType.getRank(); 1344 assert(leadingStaticOffsets.size() <= rank && 1345 "unexpected leadingStaticOffsets overflow"); 1346 assert(leadingStaticSizes.size() <= rank && 1347 "unexpected leadingStaticSizes overflow"); 1348 assert(leadingStaticStrides.size() <= rank && 1349 "unexpected leadingStaticStrides overflow"); 1350 auto staticOffsets = llvm::to_vector<4>(leadingStaticOffsets); 1351 auto staticSizes = llvm::to_vector<4>(leadingStaticSizes); 1352 auto staticStrides = llvm::to_vector<4>(leadingStaticStrides); 1353 unsigned numTrailingOffsets = rank - staticOffsets.size(); 1354 unsigned numTrailingSizes = rank - staticSizes.size(); 1355 unsigned numTrailingStrides = rank - staticStrides.size(); 1356 staticOffsets.append(numTrailingOffsets, 0); 1357 llvm::append_range(staticSizes, 1358 sourceMemRefType.getShape().take_back(numTrailingSizes)); 1359 staticStrides.append(numTrailingStrides, 1); 1360 1361 // Extract source offset and strides. 1362 int64_t sourceOffset; 1363 SmallVector<int64_t, 4> sourceStrides; 1364 auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset); 1365 assert(succeeded(res) && "SubViewOp expected strided memref type"); 1366 (void)res; 1367 1368 // Compute target offset whose value is: 1369 // `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`. 
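  // For example (hypothetical numbers): with sourceOffset = 4,
  // sourceStrides = [16, 1] and staticOffsets = [2, 3], the target offset is
  // 4 + 2 * 16 + 3 * 1 = 39. Any dynamic entry saturates the result to the
  // dynamic sentinel, per the helpers above.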
1370 int64_t targetOffset = sourceOffset; 1371 for (auto it : llvm::zip(staticOffsets, sourceStrides)) { 1372 auto staticOffset = std::get<0>(it), targetStride = std::get<1>(it); 1373 using namespace saturated_arith; 1374 targetOffset = Wrapper(targetOffset) + Wrapper(staticOffset) * targetStride; 1375 } 1376 1377 // Compute target stride whose value is: 1378 // `sourceStrides_i * staticStrides_i`. 1379 SmallVector<int64_t, 4> targetStrides; 1380 targetStrides.reserve(staticOffsets.size()); 1381 for (auto it : llvm::zip(sourceStrides, staticStrides)) { 1382 auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it); 1383 using namespace saturated_arith; 1384 targetStrides.push_back(Wrapper(sourceStride) * staticStride); 1385 } 1386 1387 // The type is now known. 1388 return MemRefType::get( 1389 staticSizes, sourceMemRefType.getElementType(), 1390 makeStridedLinearLayoutMap(targetStrides, targetOffset, 1391 sourceMemRefType.getContext()), 1392 sourceMemRefType.getMemorySpaceAsInt()); 1393 } 1394 1395 Type SubViewOp::inferResultType(MemRefType sourceMemRefType, 1396 ArrayRef<OpFoldResult> leadingStaticOffsets, 1397 ArrayRef<OpFoldResult> leadingStaticSizes, 1398 ArrayRef<OpFoldResult> leadingStaticStrides) { 1399 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 1400 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 1401 dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets, 1402 staticOffsets, ShapedType::kDynamicStrideOrOffset); 1403 dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes, 1404 ShapedType::kDynamicSize); 1405 dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides, 1406 staticStrides, ShapedType::kDynamicStrideOrOffset); 1407 return SubViewOp::inferResultType(sourceMemRefType, staticOffsets, 1408 staticSizes, staticStrides) 1409 .cast<MemRefType>(); 1410 } 1411 1412 Type SubViewOp::inferRankReducedResultType( 1413 unsigned resultRank, MemRefType sourceRankedTensorType, 1414 ArrayRef<int64_t> leadingStaticOffsets, 1415 ArrayRef<int64_t> leadingStaticSizes, 1416 ArrayRef<int64_t> leadingStaticStrides) { 1417 auto inferredType = 1418 inferResultType(sourceRankedTensorType, leadingStaticOffsets, 1419 leadingStaticSizes, leadingStaticStrides) 1420 .cast<MemRefType>(); 1421 assert(inferredType.getRank() >= resultRank && "expected "); 1422 int rankDiff = inferredType.getRank() - resultRank; 1423 if (rankDiff > 0) { 1424 auto shape = inferredType.getShape(); 1425 llvm::SmallDenseSet<unsigned> dimsToProject; 1426 mlir::getPositionsOfShapeOne(rankDiff, shape, dimsToProject); 1427 SmallVector<int64_t> projectedShape; 1428 for (unsigned pos = 0, e = shape.size(); pos < e; ++pos) 1429 if (!dimsToProject.contains(pos)) 1430 projectedShape.push_back(shape[pos]); 1431 1432 AffineMap map; 1433 auto maps = inferredType.getAffineMaps(); 1434 if (!maps.empty() && maps.front()) 1435 map = getProjectedMap(maps.front(), dimsToProject); 1436 inferredType = 1437 MemRefType::get(projectedShape, inferredType.getElementType(), map, 1438 inferredType.getMemorySpaceAsInt()); 1439 } 1440 return inferredType; 1441 } 1442 1443 Type SubViewOp::inferRankReducedResultType( 1444 unsigned resultRank, MemRefType sourceRankedTensorType, 1445 ArrayRef<OpFoldResult> leadingStaticOffsets, 1446 ArrayRef<OpFoldResult> leadingStaticSizes, 1447 ArrayRef<OpFoldResult> leadingStaticStrides) { 1448 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 1449 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 1450 
  dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
                             staticOffsets, ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
                             staticStrides, ShapedType::kDynamicStrideOrOffset);
  return SubViewOp::inferRankReducedResultType(
      resultRank, sourceRankedTensorType, staticOffsets, staticSizes,
      staticStrides);
}

// Build a SubViewOp with mixed static and dynamic entries and custom result
// type. If the type passed is nullptr, it is inferred.
void SubViewOp::build(OpBuilder &b, OperationState &result,
                      MemRefType resultType, Value source,
                      ArrayRef<OpFoldResult> offsets,
                      ArrayRef<OpFoldResult> sizes,
                      ArrayRef<OpFoldResult> strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
                             ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
                             ShapedType::kDynamicStrideOrOffset);
  auto sourceMemRefType = source.getType().cast<MemRefType>();
  // Structuring implementation this way avoids duplication between builders.
  if (!resultType) {
    resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
                                            staticSizes, staticStrides)
                     .cast<MemRefType>();
  }
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getI64ArrayAttr(staticOffsets),
        b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
  result.addAttributes(attrs);
}

// Build a SubViewOp with mixed static and dynamic entries and inferred result
// type.
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ArrayRef<OpFoldResult> offsets,
                      ArrayRef<OpFoldResult> sizes,
                      ArrayRef<OpFoldResult> strides,
                      ArrayRef<NamedAttribute> attrs) {
  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
}

// Build a SubViewOp with static entries and inferred result type.
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
                      ArrayRef<int64_t> strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> sizeValues =
      llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
}
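
// Purely illustrative usage sketch (hypothetical caller, not part of this
// file): a rewrite pattern can invoke the static-entry builder above and let
// the result type be inferred, assuming `rewriter`, `loc` and `source` exist
// in the caller and that the `attrs` parameter has its declared default:
//   SmallVector<int64_t> offsets{0, 0}, sizes{3, 4}, strides{1, 1};
//   Value sv = rewriter.create<memref::SubViewOp>(loc, source, offsets,
//                                                 sizes, strides);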

// Build a SubViewOp with static entries and custom result type. If the type
// passed is nullptr, it is inferred.
void SubViewOp::build(OpBuilder &b, OperationState &result,
                      MemRefType resultType, Value source,
                      ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
                      ArrayRef<int64_t> strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> sizeValues =
      llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
        attrs);
}

// Build a SubViewOp with dynamic entries and custom result type. If the type
// passed is nullptr, it is inferred.
void SubViewOp::build(OpBuilder &b, OperationState &result,
                      MemRefType resultType, Value source, ValueRange offsets,
                      ValueRange sizes, ValueRange strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
      llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
        attrs);
}

// Build a SubViewOp with dynamic entries and inferred result type.
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ValueRange offsets, ValueRange sizes, ValueRange strides,
                      ArrayRef<NamedAttribute> attrs) {
  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
}

/// For ViewLikeOpInterface.
Value SubViewOp::getViewSource() { return source(); }

enum SubViewVerificationResult {
  Success,
  RankTooLarge,
  SizeMismatch,
  ElemTypeMismatch,
  MemSpaceMismatch,
  AffineMapMismatch
};

/// Checks whether the `original` type can be rank-reduced to the
/// `candidateReduced` type. This function is a slight variant of the
/// "is a subsequence" algorithm, where every non-matching dimension must have
/// size 1.
static SubViewVerificationResult
isRankReducedType(Type originalType, Type candidateReducedType,
                  std::string *errMsg = nullptr) {
  if (originalType == candidateReducedType)
    return SubViewVerificationResult::Success;
  if (!originalType.isa<MemRefType>())
    return SubViewVerificationResult::Success;
  if (originalType.isa<MemRefType>() &&
      !candidateReducedType.isa<MemRefType>())
    return SubViewVerificationResult::Success;

  ShapedType originalShapedType = originalType.cast<ShapedType>();
  ShapedType candidateReducedShapedType =
      candidateReducedType.cast<ShapedType>();

  // Rank and size logic is valid for all ShapedTypes.
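  // For intuition (illustrative values only): memref<1x16x4xf32> can be
  // rank-reduced to memref<16x4xf32> because the dropped dimension has size 1,
  // whereas memref<2x16x4xf32> cannot be reduced to memref<16x4xf32>.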
  ArrayRef<int64_t> originalShape = originalShapedType.getShape();
  ArrayRef<int64_t> candidateReducedShape =
      candidateReducedShapedType.getShape();
  unsigned originalRank = originalShape.size(),
           candidateReducedRank = candidateReducedShape.size();
  if (candidateReducedRank > originalRank)
    return SubViewVerificationResult::RankTooLarge;

  auto optionalUnusedDimsMask =
      computeRankReductionMask(originalShape, candidateReducedShape);

  // Sizes cannot be matched in case empty vector is returned.
  if (!optionalUnusedDimsMask.hasValue())
    return SubViewVerificationResult::SizeMismatch;

  if (originalShapedType.getElementType() !=
      candidateReducedShapedType.getElementType())
    return SubViewVerificationResult::ElemTypeMismatch;

  // Strided layout logic is relevant for MemRefType only.
  MemRefType original = originalType.cast<MemRefType>();
  MemRefType candidateReduced = candidateReducedType.cast<MemRefType>();
  if (original.getMemorySpaceAsInt() != candidateReduced.getMemorySpaceAsInt())
    return SubViewVerificationResult::MemSpaceMismatch;

  llvm::SmallDenseSet<unsigned> unusedDims = optionalUnusedDimsMask.getValue();
  auto inferredType =
      getProjectedMap(getStridedLinearLayoutMap(original), unusedDims);
  AffineMap candidateLayout;
  if (candidateReduced.getAffineMaps().empty())
    candidateLayout = getStridedLinearLayoutMap(candidateReduced);
  else
    candidateLayout = candidateReduced.getAffineMaps().front();
  assert(inferredType.getNumResults() == 1 &&
         candidateLayout.getNumResults() == 1);
  if (inferredType.getNumSymbols() != candidateLayout.getNumSymbols() ||
      inferredType.getNumDims() != candidateLayout.getNumDims()) {
    if (errMsg) {
      llvm::raw_string_ostream os(*errMsg);
      os << "inferred type: " << inferredType;
    }
    return SubViewVerificationResult::AffineMapMismatch;
  }
  // Check that the difference of the affine maps simplifies to 0.
  AffineExpr diffExpr =
      inferredType.getResult(0) - candidateLayout.getResult(0);
  diffExpr = simplifyAffineExpr(diffExpr, inferredType.getNumDims(),
                                inferredType.getNumSymbols());
  auto cst = diffExpr.dyn_cast<AffineConstantExpr>();
  if (!(cst && cst.getValue() == 0)) {
    if (errMsg) {
      llvm::raw_string_ostream os(*errMsg);
      os << "inferred type: " << inferredType;
    }
    return SubViewVerificationResult::AffineMapMismatch;
  }
  return SubViewVerificationResult::Success;
}

template <typename OpTy>
static LogicalResult produceSubViewErrorMsg(SubViewVerificationResult result,
                                            OpTy op, Type expectedType,
                                            StringRef errMsg = "") {
  auto memrefType = expectedType.cast<ShapedType>();
  switch (result) {
  case SubViewVerificationResult::Success:
    return success();
  case SubViewVerificationResult::RankTooLarge:
    return op.emitError("expected result rank to be smaller or equal to ")
           << "the source rank. " << errMsg;
  case SubViewVerificationResult::SizeMismatch:
    return op.emitError("expected result type to be ")
           << expectedType
           << " or a rank-reduced version. (mismatch of result sizes) "
           << errMsg;
  case SubViewVerificationResult::ElemTypeMismatch:
    return op.emitError("expected result element type to be ")
           << memrefType.getElementType() << errMsg;
  case SubViewVerificationResult::MemSpaceMismatch:
    return op.emitError("expected result and source memory spaces to match.")
           << errMsg;
  case SubViewVerificationResult::AffineMapMismatch:
    return op.emitError("expected result type to be ")
           << expectedType
           << " or a rank-reduced version. (mismatch of result affine map) "
           << errMsg;
  }
  llvm_unreachable("unexpected subview verification result");
}

/// Verifier for SubViewOp.
static LogicalResult verify(SubViewOp op) {
  MemRefType baseType = op.getSourceType();
  MemRefType subViewType = op.getType();

  // The base memref and the view memref should be in the same memory space.
  if (baseType.getMemorySpaceAsInt() != subViewType.getMemorySpaceAsInt())
    return op.emitError("different memory spaces specified for base memref "
                        "type ")
           << baseType << " and subview memref type " << subViewType;

  // Verify that the base memref type has a strided layout map.
  if (!isStrided(baseType))
    return op.emitError("base type ") << baseType << " is not strided";

  // Verify result type against inferred type.
  auto expectedType = SubViewOp::inferResultType(
      baseType, extractFromI64ArrayAttr(op.static_offsets()),
      extractFromI64ArrayAttr(op.static_sizes()),
      extractFromI64ArrayAttr(op.static_strides()));

  std::string errMsg;
  auto result = isRankReducedType(expectedType, subViewType, &errMsg);
  return produceSubViewErrorMsg(result, op, expectedType, errMsg);
}

raw_ostream &mlir::operator<<(raw_ostream &os, Range &range) {
  return os << "range " << range.offset << ":" << range.size << ":"
            << range.stride;
}

/// Return the list of Range (i.e. offset, size, stride). Each Range
/// entry contains either the dynamic value or a ConstantIndexOp constructed
/// with `b` at location `loc`.
SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
                                              OpBuilder &b, Location loc) {
  std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
  assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
  assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
  SmallVector<Range, 8> res;
  unsigned rank = ranks[0];
  res.reserve(rank);
  for (unsigned idx = 0; idx < rank; ++idx) {
    Value offset =
        op.isDynamicOffset(idx)
            ? op.getDynamicOffset(idx)
            : b.create<ConstantIndexOp>(loc, op.getStaticOffset(idx));
    Value size = op.isDynamicSize(idx)
                     ? op.getDynamicSize(idx)
                     : b.create<ConstantIndexOp>(loc, op.getStaticSize(idx));
    Value stride =
        op.isDynamicStride(idx)
            ? op.getDynamicStride(idx)
            : b.create<ConstantIndexOp>(loc, op.getStaticStride(idx));
    res.emplace_back(Range{offset, size, stride});
  }
  return res;
}

namespace {
/// Pattern to rewrite a subview op with MemRefCast arguments.
/// This essentially pushes memref.cast past its consuming subview when
/// `canFoldIntoConsumerOp` is true.
///
/// Example:
/// ```
/// %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
/// %1 = memref.subview %0[0, 0][3, 4][1, 1] :
///   memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]>
/// ```
/// is rewritten into:
/// ```
/// %0 = memref.subview %V: memref<16x16xf32> to memref<3x4xf32, #[[map0]]>
/// %1 = memref.cast %0: memref<3x4xf32, offset:0, strides:[16, 1]> to
///   memref<3x4xf32, offset:?, strides:[?, 1]>
/// ```
class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
public:
  using OpRewritePattern<SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubViewOp subViewOp,
                                PatternRewriter &rewriter) const override {
    // If any operand is a constant, bail out and let the
    // SubViewOpConstantFolder-based patterns kick in first.
    if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto castOp = subViewOp.source().getDefiningOp<CastOp>();
    if (!castOp)
      return failure();

    if (!CastOp::canFoldIntoConsumerOp(castOp))
      return failure();

    /// Deduce the result type of the SubViewOp by applying
    /// `inferRankReducedResultType` to the cast source operand type and the
    /// SubViewOp static information. This is the resulting type if the
    /// MemRefCastOp were folded.
    auto resultType = SubViewOp::inferRankReducedResultType(
        subViewOp.getType().getRank(),
        castOp.source().getType().cast<MemRefType>(),
        subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
        subViewOp.getMixedStrides());
    Value newSubView = rewriter.create<SubViewOp>(
        subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(),
        subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),
        subViewOp.static_sizes(), subViewOp.static_strides());
    rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
                                        newSubView);
    return success();
  }
};
} // namespace

/// A canonicalizer wrapper to replace SubViewOps.
struct SubViewCanonicalizer {
  void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
    rewriter.replaceOpWithNewOp<CastOp>(op, newOp, op.getType());
  }
};

void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
                  SubViewOp, SubViewCanonicalizer>,
              SubViewOpMemRefCastFolder>(context);
}

OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
  auto resultShapedType = getResult().getType().cast<ShapedType>();
  auto sourceShapedType = source().getType().cast<ShapedType>();

  if (resultShapedType.hasStaticShape() &&
      resultShapedType == sourceShapedType) {
    return getViewSource();
  }

  return {};
}

//===----------------------------------------------------------------------===//
// TensorLoadOp
//===----------------------------------------------------------------------===//

OpFoldResult TensorLoadOp::fold(ArrayRef<Attribute>) {
  if (auto bufferCast = memref().getDefiningOp<BufferCastOp>())
    // Approximate alias analysis by conservatively folding only when there is
    // no interleaved operation.
    if (bufferCast->getBlock() == this->getOperation()->getBlock() &&
        bufferCast->getNextNode() == this->getOperation())
      return bufferCast.tensor();
  return {};
}

//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//

/// Build a strided memref type by applying `permutationMap` to `memRefType`.
static MemRefType inferTransposeResultType(MemRefType memRefType,
                                           AffineMap permutationMap) {
  auto rank = memRefType.getRank();
  auto originalSizes = memRefType.getShape();
  // Compute permuted sizes.
  SmallVector<int64_t, 4> sizes(rank, 0);
  for (auto en : llvm::enumerate(permutationMap.getResults()))
    sizes[en.index()] =
        originalSizes[en.value().cast<AffineDimExpr>().getPosition()];

  // Compute permuted strides.
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto res = getStridesAndOffset(memRefType, strides, offset);
  assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank));
  (void)res;
  auto map =
      makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
  map = permutationMap ? map.compose(permutationMap) : map;
  return MemRefType::Builder(memRefType).setShape(sizes).setAffineMaps(map);
}

void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
                        AffineMapAttr permutation,
                        ArrayRef<NamedAttribute> attrs) {
  auto permutationMap = permutation.getValue();
  assert(permutationMap);

  auto memRefType = in.getType().cast<MemRefType>();
  // Compute result type.
  MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);

  build(b, result, resultType, in, attrs);
  result.addAttribute(TransposeOp::getPermutationAttrName(), permutation);
}

// transpose $in $permutation attr-dict : type($in) `to` type(results)
static void print(OpAsmPrinter &p, TransposeOp op) {
  p << "memref.transpose " << op.in() << " " << op.permutation();
  p.printOptionalAttrDict(op->getAttrs(),
                          {TransposeOp::getPermutationAttrName()});
  p << " : " << op.in().getType() << " to " << op.getType();
}

static ParseResult parseTransposeOp(OpAsmParser &parser,
                                    OperationState &result) {
  OpAsmParser::OperandType in;
  AffineMap permutation;
  MemRefType srcType, dstType;
  if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(srcType) ||
      parser.resolveOperand(in, srcType, result.operands) ||
      parser.parseKeywordType("to", dstType) ||
      parser.addTypeToList(dstType, result.types))
    return failure();

  result.addAttribute(TransposeOp::getPermutationAttrName(),
                      AffineMapAttr::get(permutation));
  return success();
}
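
// For example (an illustrative case, not taken from a test): transposing
// memref<3x4xf32> (default layout, strides [4, 1]) with the permutation map
// (d0, d1) -> (d1, d0) infers the result type
// memref<4x3xf32, affine_map<(d0, d1) -> (d0 + d1 * 4)>>; the verifier below
// rejects any other result type.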

static LogicalResult verify(TransposeOp op) {
  if (!op.permutation().isPermutation())
    return op.emitOpError("expected a permutation map");
  if (op.permutation().getNumDims() != op.getShapedType().getRank())
    return op.emitOpError(
        "expected a permutation map of same rank as the input");

  auto srcType = op.in().getType().cast<MemRefType>();
  auto dstType = op.getType().cast<MemRefType>();
  auto transposedType = inferTransposeResultType(srcType, op.permutation());
  if (dstType != transposedType)
    return op.emitOpError("output type ")
           << dstType << " does not match transposed input type " << srcType
           << ", " << transposedType;
  return success();
}

OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  return {};
}

//===----------------------------------------------------------------------===//
// ViewOp
//===----------------------------------------------------------------------===//

static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::OperandType srcInfo;
  SmallVector<OpAsmParser::OperandType, 1> offsetInfo;
  SmallVector<OpAsmParser::OperandType, 4> sizesInfo;
  auto indexType = parser.getBuilder().getIndexType();
  Type srcType, dstType;
  llvm::SMLoc offsetLoc;
  if (parser.parseOperand(srcInfo) || parser.getCurrentLocation(&offsetLoc) ||
      parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square))
    return failure();

  if (offsetInfo.size() != 1)
    return parser.emitError(offsetLoc) << "expects 1 offset operand";

  return failure(
      parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(srcType) ||
      parser.resolveOperand(srcInfo, srcType, result.operands) ||
      parser.resolveOperands(offsetInfo, indexType, result.operands) ||
      parser.resolveOperands(sizesInfo, indexType, result.operands) ||
      parser.parseKeywordType("to", dstType) ||
      parser.addTypeToList(dstType, result.types));
}

static void print(OpAsmPrinter &p, ViewOp op) {
  p << op.getOperationName() << ' ' << op.getOperand(0) << '[';
  p.printOperand(op.byte_shift());
  p << "][" << op.sizes() << ']';
  p.printOptionalAttrDict(op->getAttrs());
  p << " : " << op.getOperand(0).getType() << " to " << op.getType();
}

static LogicalResult verify(ViewOp op) {
  auto baseType = op.getOperand(0).getType().cast<MemRefType>();
  auto viewType = op.getType();

  // The base memref should have identity layout map (or none).
  if (baseType.getAffineMaps().size() > 1 ||
      (baseType.getAffineMaps().size() == 1 &&
       !baseType.getAffineMaps()[0].isIdentity()))
    return op.emitError("unsupported map for base memref type ") << baseType;

  // The result memref should have identity layout map (or none).
  if (viewType.getAffineMaps().size() > 1 ||
      (viewType.getAffineMaps().size() == 1 &&
       !viewType.getAffineMaps()[0].isIdentity()))
    return op.emitError("unsupported map for result memref type ") << viewType;

  // The base memref and the view memref should be in the same memory space.
  if (baseType.getMemorySpaceAsInt() != viewType.getMemorySpaceAsInt())
    return op.emitError("different memory spaces specified for base memref "
                        "type ")
           << baseType << " and view memref type " << viewType;

  // Verify that we have the correct number of sizes for the result type.
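  // E.g. (illustrative): a result type of memref<?x4x?xf32> has two dynamic
  // dimensions and therefore requires exactly two size operands.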
  unsigned numDynamicDims = viewType.getNumDynamicDims();
  if (op.sizes().size() != numDynamicDims)
    return op.emitError("incorrect number of size operands for type ")
           << viewType;

  return success();
}

Value ViewOp::getViewSource() { return source(); }

namespace {

/// Fold constant size operands into the result type of a ViewOp and insert a
/// cast back to the original result type.
struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
  using OpRewritePattern<ViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    // Return if none of the operands are constants.
    if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    // Get result memref type.
    auto memrefType = viewOp.getType();

    // Get offset from old memref view type `memrefType`.
    int64_t oldOffset;
    SmallVector<int64_t, 4> oldStrides;
    if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
      return failure();
    assert(oldOffset == 0 && "Expected 0 offset");

    SmallVector<Value, 4> newOperands;

    // Offset cannot be folded into result type.

    // Fold any dynamic dim operands which are produced by a constant.
    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());

    unsigned dynamicDimPos = 0;
    unsigned rank = memrefType.getRank();
    for (unsigned dim = 0, e = rank; dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already a static dimension, keep it.
      if (!ShapedType::isDynamic(dimSize)) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto *defOp = viewOp.sizes()[dynamicDimPos].getDefiningOp();
      if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
        // Dynamic shape dimension will be folded.
        newShapeConstants.push_back(constantIndexOp.getValue());
      } else {
        // Dynamic shape dimension not folded; copy operand from old memref.
        newShapeConstants.push_back(dimSize);
        newOperands.push_back(viewOp.sizes()[dynamicDimPos]);
      }
      dynamicDimPos++;
    }

    // Create new memref type with constant folded dims.
    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    // Nothing new, don't fold.
    if (newMemRefType == memrefType)
      return failure();

    // Create new ViewOp.
    auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType,
                                             viewOp.getOperand(0),
                                             viewOp.byte_shift(), newOperands);
    // Insert a cast so we have the same type as the old memref type.
    rewriter.replaceOpWithNewOp<CastOp>(viewOp, newViewOp, viewOp.getType());
    return success();
  }
};

/// Rewrite a ViewOp whose source is a memref.cast of an AllocOp so that the
/// view is taken directly on the allocated memref.
struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
  using OpRewritePattern<ViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    Value memrefOperand = viewOp.getOperand(0);
    CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
    if (!memrefCastOp)
      return failure();
    Value allocOperand = memrefCastOp.getOperand();
    AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
    if (!allocOp)
      return failure();
    rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
                                        viewOp.byte_shift(), viewOp.sizes());
    return success();
  }
};

} // end anonymous namespace

void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"