//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;
using namespace mlir::memref;

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  return builder.create<mlir::ConstantOp>(loc, type, value);
}

/// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
  return llvm::to_vector<4>(
      llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
        return a.cast<IntegerAttr>().getInt();
      }));
}

/// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if
/// it is a Value or into `staticVec` if it is an IntegerAttr.
/// In the case of a Value, a copy of the `sentinel` value is also pushed to
/// `staticVec`. This is useful to extract mixed static and dynamic entries
/// that come from an AttrSizedOperandSegments trait.
static void dispatchIndexOpFoldResult(OpFoldResult ofr,
                                      SmallVectorImpl<Value> &dynamicVec,
                                      SmallVectorImpl<int64_t> &staticVec,
                                      int64_t sentinel) {
  if (auto v = ofr.dyn_cast<Value>()) {
    dynamicVec.push_back(v);
    staticVec.push_back(sentinel);
    return;
  }
  APInt apInt = ofr.dyn_cast<Attribute>().cast<IntegerAttr>().getValue();
  staticVec.push_back(apInt.getSExtValue());
}

static void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
                                       SmallVectorImpl<Value> &dynamicVec,
                                       SmallVectorImpl<int64_t> &staticVec,
                                       int64_t sentinel) {
  for (auto ofr : ofrs)
    dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec, sentinel);
}

//===----------------------------------------------------------------------===//
// Common canonicalization pattern support logic
//===----------------------------------------------------------------------===//

/// This is a common class used for patterns of the form
/// "someop(memrefcast) -> someop". It folds the source of any memref.cast
/// into the root operation directly.
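/// For example (hypothetical IR, assuming `someop` tolerates the more static
/// operand type):
///
/// ```mlir
/// %0 = memref.cast %arg : memref<8xf32> to memref<?xf32>
/// "someop"(%0) : (memref<?xf32>) -> ()
/// ```
///
/// folds to:
///
/// ```mlir
/// "someop"(%arg) : (memref<8xf32>) -> ()
/// ```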
static LogicalResult foldMemRefCast(Operation *op) {
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto cast = operand.get().getDefiningOp<CastOp>();
    if (cast && !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
      operand.set(cast.getOperand());
      folded = true;
    }
  }
  return success(folded);
}

//===----------------------------------------------------------------------===//
// Helpers for GlobalOp
//===----------------------------------------------------------------------===//

static Type getTensorTypeFromMemRefType(Type type) {
  if (auto memref = type.dyn_cast<MemRefType>())
    return RankedTensorType::get(memref.getShape(), memref.getElementType());
  if (auto memref = type.dyn_cast<UnrankedMemRefType>())
    return UnrankedTensorType::get(memref.getElementType());
  return NoneType::get(type.getContext());
}

//===----------------------------------------------------------------------===//
// AllocOp / AllocaOp
//===----------------------------------------------------------------------===//

template <typename AllocLikeOp>
static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
  static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
                "applies to only alloc or alloca");
  auto memRefType = op.getResult().getType().template dyn_cast<MemRefType>();
  if (!memRefType)
    return op.emitOpError("result must be a memref");

  if (static_cast<int64_t>(op.dynamicSizes().size()) !=
      memRefType.getNumDynamicDims())
    return op.emitOpError("dimension operand count does not equal memref "
                          "dynamic dimension count");

  unsigned numSymbols = 0;
  if (!memRefType.getAffineMaps().empty())
    numSymbols = memRefType.getAffineMaps().front().getNumSymbols();
  if (op.symbolOperands().size() != numSymbols)
    return op.emitOpError(
        "symbol operand count does not equal memref symbol count");

  return success();
}

static LogicalResult verify(AllocOp op) { return verifyAllocLikeOp(op); }

static LogicalResult verify(AllocaOp op) {
  // An alloca op needs to have an ancestor with an allocation scope trait.
  if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
    return op.emitOpError(
        "requires an ancestor op with AutomaticAllocationScope trait");

  return verifyAllocLikeOp(op);
}

namespace {
/// Fold constant dimensions into an alloc-like operation.
template <typename AllocLikeOp>
struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
  using OpRewritePattern<AllocLikeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocLikeOp alloc,
                                PatternRewriter &rewriter) const override {
    // Check to see if any dimension operands are constants. If so, we can
    // substitute and drop them.
    if (llvm::none_of(alloc.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto memrefType = alloc.getType();

    // Ok, we have one or more constant operands. Collect the non-constant ones
    // and keep track of the resultant memref type to build.
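    // For example (hypothetical IR), with %c10 a constant index, the alloc
    //   %0 = memref.alloc(%c10, %n) : memref<?x?xf32>
    // is rewritten below into
    //   %1 = memref.alloc(%n) : memref<10x?xf32>
    //   %0 = memref.cast %1 : memref<10x?xf32> to memref<?x?xf32>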
    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());
    SmallVector<Value, 4> newOperands;

    unsigned dynamicDimPos = 0;
    for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already a static dimension, keep it.
      if (dimSize != -1) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto *defOp = alloc.getOperand(dynamicDimPos).getDefiningOp();
      if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
        // Dynamic shape dimension will be folded.
        newShapeConstants.push_back(constantIndexOp.getValue());
      } else {
        // Dynamic shape dimension not folded; copy operand from old memref.
        newShapeConstants.push_back(-1);
        newOperands.push_back(alloc.getOperand(dynamicDimPos));
      }
      dynamicDimPos++;
    }

    // Create new memref type (which will have fewer dynamic dimensions).
    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    assert(static_cast<int64_t>(newOperands.size()) ==
           newMemRefType.getNumDynamicDims());

    // Create and insert the alloc op for the new memref.
    auto newAlloc = rewriter.create<AllocLikeOp>(alloc.getLoc(), newMemRefType,
                                                 newOperands, IntegerAttr());
    // Insert a cast so we have the same type as the old alloc.
    auto resultCast =
        rewriter.create<CastOp>(alloc.getLoc(), newAlloc, alloc.getType());

    rewriter.replaceOp(alloc, {resultCast});
    return success();
  }
};

/// Fold alloc operations with no uses. Alloc has side effects on the heap,
/// but can still be deleted if it has zero uses.
struct SimplifyDeadAlloc : public OpRewritePattern<AllocOp> {
  using OpRewritePattern<AllocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocOp alloc,
                                PatternRewriter &rewriter) const override {
    if (alloc.use_empty()) {
      rewriter.eraseOp(alloc);
      return success();
    }
    return failure();
  }
};
} // end anonymous namespace.

void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc>(context);
}

void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<SimplifyAllocConst<AllocaOp>>(context);
}

//===----------------------------------------------------------------------===//
// AssumeAlignmentOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(AssumeAlignmentOp op) {
  unsigned alignment = op.alignment();
  if (!llvm::isPowerOf2_32(alignment))
    return op.emitOpError("alignment must be power of 2");
  return success();
}

//===----------------------------------------------------------------------===//
// BufferCastOp
//===----------------------------------------------------------------------===//

OpFoldResult BufferCastOp::fold(ArrayRef<Attribute>) {
  if (auto tensorLoad = tensor().getDefiningOp<TensorLoadOp>())
    if (tensorLoad.memref().getType() == getType())
      return tensorLoad.memref();
  return {};
}

namespace {
/// Replace tensor_cast + buffer_cast by buffer_cast + memref_cast.
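/// For example (hypothetical IR):
///
/// ```mlir
/// %0 = tensor.cast %t : tensor<4xf32> to tensor<?xf32>
/// %1 = memref.buffer_cast %0 : memref<?xf32>
/// ```
///
/// is rewritten into:
///
/// ```mlir
/// %0 = memref.buffer_cast %t : memref<4xf32>
/// %1 = memref.cast %0 : memref<4xf32> to memref<?xf32>
/// ```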
struct BufferCast : public OpRewritePattern<BufferCastOp> {
  using OpRewritePattern<BufferCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BufferCastOp bufferCast,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand =
        bufferCast.getOperand().getDefiningOp<tensor::CastOp>();
    if (!tensorCastOperand)
      return failure();
    auto srcTensorType =
        tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
    if (!srcTensorType)
      return failure();
    auto memrefType = MemRefType::get(srcTensorType.getShape(),
                                      srcTensorType.getElementType());
    Value memref = rewriter.create<BufferCastOp>(
        bufferCast.getLoc(), memrefType, tensorCastOperand.getOperand());
    rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(),
                                        memref);
    return success();
  }
};

/// Canonicalize memref.tensor_load + memref.buffer_cast to memref.cast when
/// type mismatches prevent `BufferCastOp::fold` from kicking in.
struct TensorLoadToMemRef : public OpRewritePattern<BufferCastOp> {
  using OpRewritePattern<BufferCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BufferCastOp bufferCast,
                                PatternRewriter &rewriter) const final {
    auto tensorLoad = bufferCast.tensor().getDefiningOp<TensorLoadOp>();
    // Bail unless we have a tensor_load + memref.buffer_cast with different
    // types. `BufferCastOp::fold` handles the same type case.
    if (!tensorLoad || tensorLoad.memref().getType() == bufferCast.getType())
      return failure();
    // If types are not cast-compatible, bail.
    if (!CastOp::areCastCompatible(tensorLoad.memref().getType(),
                                   bufferCast.getType()))
      return failure();
    rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(),
                                        tensorLoad.memref());
    return success();
  }
};

} // namespace

void BufferCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<BufferCast, TensorLoadToMemRef>(context);
}

//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//

/// Determines whether MemRef_CastOp casts to a more dynamic version of the
/// source memref. This is useful to fold a memref.cast into a consuming op
/// and implement canonicalization patterns for ops in different dialects that
/// may consume the results of memref.cast operations. Such foldable
/// memref.cast operations are typically inserted as `view` and `subview` ops
/// are canonicalized, to preserve the type compatibility of their uses.
///
/// Returns true when all conditions are met:
/// 1. source and result are ranked memrefs with strided semantics and same
///    element type and rank.
/// 2. each of the source's size, offset or stride has more static information
///    than the corresponding result's size, offset or stride.
///
/// Example 1:
/// ```mlir
///   %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
///   %2 = consumer %1 ... : memref<?x?xf32> ...
/// ```
///
/// may fold into:
///
/// ```mlir
///   %2 = consumer %0 ... : memref<8x16xf32> ...
/// ```
///
/// Example 2:
/// ```
///   %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
///          to memref<?x?xf32>
///   consumer %1 : memref<?x?xf32> ...
333 /// ``` 334 /// 335 /// may fold into: 336 /// 337 /// ``` 338 /// consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>> 339 /// ``` 340 bool CastOp::canFoldIntoConsumerOp(CastOp castOp) { 341 MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>(); 342 MemRefType resultType = castOp.getType().dyn_cast<MemRefType>(); 343 344 // Requires ranked MemRefType. 345 if (!sourceType || !resultType) 346 return false; 347 348 // Requires same elemental type. 349 if (sourceType.getElementType() != resultType.getElementType()) 350 return false; 351 352 // Requires same rank. 353 if (sourceType.getRank() != resultType.getRank()) 354 return false; 355 356 // Only fold casts between strided memref forms. 357 int64_t sourceOffset, resultOffset; 358 SmallVector<int64_t, 4> sourceStrides, resultStrides; 359 if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) || 360 failed(getStridesAndOffset(resultType, resultStrides, resultOffset))) 361 return false; 362 363 // If cast is towards more static sizes along any dimension, don't fold. 364 for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) { 365 auto ss = std::get<0>(it), st = std::get<1>(it); 366 if (ss != st) 367 if (MemRefType::isDynamic(ss) && !MemRefType::isDynamic(st)) 368 return false; 369 } 370 371 // If cast is towards more static offset along any dimension, don't fold. 372 if (sourceOffset != resultOffset) 373 if (MemRefType::isDynamicStrideOrOffset(sourceOffset) && 374 !MemRefType::isDynamicStrideOrOffset(resultOffset)) 375 return false; 376 377 // If cast is towards more static strides along any dimension, don't fold. 378 for (auto it : llvm::zip(sourceStrides, resultStrides)) { 379 auto ss = std::get<0>(it), st = std::get<1>(it); 380 if (ss != st) 381 if (MemRefType::isDynamicStrideOrOffset(ss) && 382 !MemRefType::isDynamicStrideOrOffset(st)) 383 return false; 384 } 385 386 return true; 387 } 388 389 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 390 if (inputs.size() != 1 || outputs.size() != 1) 391 return false; 392 Type a = inputs.front(), b = outputs.front(); 393 auto aT = a.dyn_cast<MemRefType>(); 394 auto bT = b.dyn_cast<MemRefType>(); 395 396 auto uaT = a.dyn_cast<UnrankedMemRefType>(); 397 auto ubT = b.dyn_cast<UnrankedMemRefType>(); 398 399 if (aT && bT) { 400 if (aT.getElementType() != bT.getElementType()) 401 return false; 402 if (aT.getAffineMaps() != bT.getAffineMaps()) { 403 int64_t aOffset, bOffset; 404 SmallVector<int64_t, 4> aStrides, bStrides; 405 if (failed(getStridesAndOffset(aT, aStrides, aOffset)) || 406 failed(getStridesAndOffset(bT, bStrides, bOffset)) || 407 aStrides.size() != bStrides.size()) 408 return false; 409 410 // Strides along a dimension/offset are compatible if the value in the 411 // source memref is static and the value in the target memref is the 412 // same. They are also compatible if either one is dynamic (see 413 // description of MemRefCastOp for details). 414 auto checkCompatible = [](int64_t a, int64_t b) { 415 return (a == MemRefType::getDynamicStrideOrOffset() || 416 b == MemRefType::getDynamicStrideOrOffset() || a == b); 417 }; 418 if (!checkCompatible(aOffset, bOffset)) 419 return false; 420 for (auto aStride : enumerate(aStrides)) 421 if (!checkCompatible(aStride.value(), bStrides[aStride.index()])) 422 return false; 423 } 424 if (aT.getMemorySpace() != bT.getMemorySpace()) 425 return false; 426 427 // They must have the same rank, and any specified dimensions must match. 
    if (aT.getRank() != bT.getRank())
      return false;

    for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
      int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
      if (aDim != -1 && bDim != -1 && aDim != bDim)
        return false;
    }
    return true;
  } else {
    if (!aT && !uaT)
      return false;
    if (!bT && !ubT)
      return false;
    // Unranked to unranked casting is unsupported.
    if (uaT && ubT)
      return false;

    auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
    auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
    if (aEltType != bEltType)
      return false;

    auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
    auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
    if (aMemSpace != bMemSpace)
      return false;

    return true;
  }

  return false;
}

OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
  return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
}

//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//
namespace {
/// Fold Dealloc operations that are deallocating an AllocOp that is only used
/// by other Dealloc operations.
struct SimplifyDeadDealloc : public OpRewritePattern<DeallocOp> {
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeallocOp dealloc,
                                PatternRewriter &rewriter) const override {
    // Check that the memref operand's defining operation is an AllocOp.
    Value memref = dealloc.memref();
    if (!isa_and_nonnull<AllocOp>(memref.getDefiningOp()))
      return failure();

    // Check that all of the uses of the AllocOp are other DeallocOps.
    for (auto *user : memref.getUsers())
      if (!isa<DeallocOp>(user))
        return failure();

    // Erase the dealloc operation.
    rewriter.eraseOp(dealloc);
    return success();
  }
};
} // end anonymous namespace.
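
// For example (hypothetical IR), SimplifyDeadDealloc erases the dealloc in
//   %0 = memref.alloc() : memref<4xf32>
//   memref.dealloc %0 : memref<4xf32>
// since the alloc is only used by deallocs; the now-unused alloc is in turn
// removed by SimplifyDeadAlloc.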

static LogicalResult verify(DeallocOp op) {
  if (!op.memref().getType().isa<MemRefType>())
    return op.emitOpError("operand must be a memref");
  return success();
}

void DeallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<SimplifyDeadDealloc>(context);
}

LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands,
                              SmallVectorImpl<OpFoldResult> &results) {
  /// dealloc(memrefcast) -> dealloc
  return foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// DimOp
//===----------------------------------------------------------------------===//

void DimOp::build(OpBuilder &builder, OperationState &result, Value memref,
                  int64_t index) {
  auto loc = result.location;
  Value indexValue = builder.create<ConstantIndexOp>(loc, index);
  build(builder, result, memref, indexValue);
}

void DimOp::build(OpBuilder &builder, OperationState &result, Value memref,
                  Value index) {
  auto indexTy = builder.getIndexType();
  build(builder, result, indexTy, memref, index);
}

Optional<int64_t> DimOp::getConstantIndex() {
  if (auto constantOp = index().getDefiningOp<ConstantOp>())
    return constantOp.getValue().cast<IntegerAttr>().getInt();
  return {};
}

static LogicalResult verify(DimOp op) {
  // Assume unknown index to be in range.
  Optional<int64_t> index = op.getConstantIndex();
  if (!index.hasValue())
    return success();

  // Check that constant index is not knowingly out of range.
  auto type = op.memrefOrTensor().getType();
  if (auto memrefType = type.dyn_cast<MemRefType>()) {
    if (index.getValue() >= memrefType.getRank())
      return op.emitOpError("index is out of range");
  } else if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
    if (index.getValue() >= tensorType.getRank())
      return op.emitOpError("index is out of range");
  } else if (type.isa<UnrankedMemRefType>() || type.isa<UnrankedTensorType>()) {
    // Assume index to be in range.
  } else {
    llvm_unreachable("expected operand with memref or tensor type");
  }
  return success();
}

OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
  auto index = operands[1].dyn_cast_or_null<IntegerAttr>();

  // All forms of folding require a known index.
  if (!index)
    return {};

  auto argTy = memrefOrTensor().getType();
  // Fold if the shape extent along the given index is known.
  if (auto shapedTy = argTy.dyn_cast<ShapedType>()) {
    // Folding for unranked types (UnrankedMemRefType) is not supported.
    if (!shapedTy.hasRank())
      return {};
    if (!shapedTy.isDynamicDim(index.getInt())) {
      Builder builder(getContext());
      return builder.getIndexAttr(shapedTy.getShape()[index.getInt()]);
    }
  }

  Operation *definingOp = memrefOrTensor().getDefiningOp();

  // dim(memref.tensor_load(memref)) -> dim(memref)
  if (auto tensorLoadOp = dyn_cast_or_null<TensorLoadOp>(definingOp)) {
    setOperand(0, tensorLoadOp.memref());
    return getResult();
  }

  // Fold dim to the operand of tensor.generate.
  if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
    auto resultType =
        fromElements.getResult().getType().cast<RankedTensorType>();
    // The case where the type encodes the size of the dimension is handled
    // above.
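    // For example (hypothetical IR), with
    //   %g = tensor.generate %n, %m { ... } : tensor<?x4x?xf32>
    // `dim %g, %c0` folds to %n and `dim %g, %c2` folds to %m.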
    assert(resultType.getShape()[index.getInt()] ==
           RankedTensorType::kDynamicSize);

    // Find the operand of the fromElements that corresponds to this index.
    auto dynExtents = fromElements.dynamicExtents().begin();
    for (auto dim : resultType.getShape().take_front(index.getInt()))
      if (dim == RankedTensorType::kDynamicSize)
        dynExtents++;

    return Value{*dynExtents};
  }

  // The size at the given index is now known to be a dynamic size.
  unsigned unsignedIndex = index.getValue().getZExtValue();

  if (auto subtensor = dyn_cast_or_null<mlir::SubTensorOp>(definingOp)) {
    assert(subtensor.isDynamicSize(unsignedIndex) &&
           "Expected dynamic subtensor size");
    return subtensor.getDynamicSize(unsignedIndex);
  }

  // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
  auto memrefType = argTy.dyn_cast<MemRefType>();
  if (!memrefType)
    return {};

  if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
    return *(alloc.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
    return *(alloca.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
    return *(view.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
    assert(subview.isDynamicSize(unsignedIndex) &&
           "Expected dynamic subview size");
    return subview.getDynamicSize(unsignedIndex);
  }

  // dim(memrefcast) -> dim
  if (succeeded(foldMemRefCast(*this)))
    return getResult();

  return {};
}

namespace {
/// Fold dim of a memref reshape operation to a load into the reshape's shape
/// operand.
struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dim,
                                PatternRewriter &rewriter) const override {
    auto reshape = dim.memrefOrTensor().getDefiningOp<ReshapeOp>();

    if (!reshape)
      return failure();

    // Place the load directly after the reshape to ensure that the shape
    // memref was not mutated.
    rewriter.setInsertionPointAfter(reshape);
    rewriter.replaceOpWithNewOp<LoadOp>(dim, reshape.shape(),
                                        llvm::makeArrayRef({dim.index()}));
    return success();
  }
};

/// Fold dim of a cast into the dim of the source of the cast.
template <typename CastOpTy>
struct DimOfCastOp : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.memrefOrTensor().getDefiningOp<CastOpTy>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.index());
    return success();
  }
};

/// Helper method to get the `Value` that is the shape of `result` at
/// dimension `dimIndex`, using the `InferShapedTypeOpInterface`.
/// TODO(ravishankarm): This is better put as an interface utility method
/// somewhere, but that would imply the interface will depend on the `tensor`
/// dialect. Ideally maybe a utility method in the `tensor` dialect.
static Value getResultDimFromShapeInterface(OpBuilder &builder, OpResult result,
                                            int64_t dimIndex) {
  unsigned resultNumber = result.getResultNumber();
  auto shapedTypeOp = dyn_cast<InferShapedTypeOpInterface>(result.getOwner());
  Location loc = result.getOwner()->getLoc();
  if (!shapedTypeOp)
    return nullptr;

  // The interface exposes two methods, one that returns the shape of all the
  // results as `Value`s and another that returns the shape of each result as
  // a `SmallVector<Value>`. The former takes precedence over the latter. So
  // first check if the op implements the first interface method or the
  // second, and get the value to use appropriately.
  SmallVector<Value> reifiedResultShapes;
  if (succeeded(
          shapedTypeOp.reifyReturnTypeShapes(builder, reifiedResultShapes))) {
    if (reifiedResultShapes.size() <= resultNumber)
      return nullptr;
    Value resultShape = reifiedResultShapes[resultNumber];
    auto resultShapeType = resultShape.getType().dyn_cast<RankedTensorType>();
    if (!resultShapeType || !resultShapeType.getElementType().isa<IndexType>())
      return nullptr;
    return builder.create<tensor::ExtractOp>(
        loc, resultShape, builder.createOrFold<ConstantIndexOp>(loc, dimIndex));
  }

  SmallVector<SmallVector<Value>> reifiedResultShapesPerDim;
  if (failed(shapedTypeOp.reifyReturnTypeShapesPerResultDim(
          builder, reifiedResultShapesPerDim)))
    return nullptr;
  if (reifiedResultShapesPerDim.size() <= resultNumber ||
      reifiedResultShapesPerDim[resultNumber].size() !=
          static_cast<size_t>(result.getType().cast<ShapedType>().getRank()))
    return nullptr;
  OpFoldResult valueOrAttr = reifiedResultShapesPerDim[resultNumber][dimIndex];
  if (auto attr = valueOrAttr.dyn_cast<Attribute>())
    return builder.createOrFold<ConstantIndexOp>(
        loc, attr.cast<IntegerAttr>().getInt());
  return valueOrAttr.get<Value>();
}

/// Fold dim of an operation that implements the InferShapedTypeOpInterface.
struct DimOfShapedTypeOpInterface : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    OpResult dimValue = dimOp.memrefOrTensor().dyn_cast<OpResult>();
    if (!dimValue)
      return failure();
    auto shapedTypeOp =
        dyn_cast<InferShapedTypeOpInterface>(dimValue.getOwner());
    if (!shapedTypeOp)
      return failure();

    Optional<int64_t> dimIndex = dimOp.getConstantIndex();
    if (!dimIndex)
      return failure();
    Value replacement =
        getResultDimFromShapeInterface(rewriter, dimValue, *dimIndex);
    if (!replacement)
      return failure();
    rewriter.replaceOp(dimOp, replacement);
    return success();
  }
};
} // end anonymous namespace.
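
// For example (hypothetical IR), the DimOfCastOp pattern rewrites
//   %0 = tensor.cast %t : tensor<4x?xf32> to tensor<?x?xf32>
//   %d = memref.dim %0, %c1 : tensor<?x?xf32>
// into
//   %d = memref.dim %t, %c1 : tensor<4x?xf32>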

void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<DimOfMemRefReshape, DimOfCastOp<BufferCastOp>,
              DimOfCastOp<tensor::CastOp>, DimOfShapedTypeOpInterface>(
      context);
}

// ---------------------------------------------------------------------------
// DmaStartOp
// ---------------------------------------------------------------------------

void DmaStartOp::build(OpBuilder &builder, OperationState &result,
                       Value srcMemRef, ValueRange srcIndices, Value destMemRef,
                       ValueRange destIndices, Value numElements,
                       Value tagMemRef, ValueRange tagIndices, Value stride,
                       Value elementsPerStride) {
  result.addOperands(srcMemRef);
  result.addOperands(srcIndices);
  result.addOperands(destMemRef);
  result.addOperands(destIndices);
  result.addOperands({numElements, tagMemRef});
  result.addOperands(tagIndices);
  if (stride)
    result.addOperands({stride, elementsPerStride});
}

void DmaStartOp::print(OpAsmPrinter &p) {
  p << getOperationName() << " " << getSrcMemRef() << '[' << getSrcIndices()
    << "], " << getDstMemRef() << '[' << getDstIndices() << "], "
    << getNumElements() << ", " << getTagMemRef() << '[' << getTagIndices()
    << ']';
  if (isStrided())
    p << ", " << getStride() << ", " << getNumElementsPerStride();

  p.printOptionalAttrDict((*this)->getAttrs());
  p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
    << ", " << getTagMemRef().getType();
}

// Parse DmaStartOp.
// Ex:
//   %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
//                       %tag[%index], %stride, %num_elt_per_stride
//                     : memref<3076 x f32, 0>,
//                       memref<1024 x f32, 2>,
//                       memref<1 x i32>
//
ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::OperandType srcMemRefInfo;
  SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos;
  OpAsmParser::OperandType dstMemRefInfo;
  SmallVector<OpAsmParser::OperandType, 4> dstIndexInfos;
  OpAsmParser::OperandType numElementsInfo;
  OpAsmParser::OperandType tagMemrefInfo;
  SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos;
  SmallVector<OpAsmParser::OperandType, 2> strideInfo;

  SmallVector<Type, 3> types;
  auto indexType = parser.getBuilder().getIndexType();

  // Parse and resolve the following list of operands:
  // *) source memref followed by its indices (in square brackets).
  // *) destination memref followed by its indices (in square brackets).
  // *) dma size (the number of elements to transfer).
  if (parser.parseOperand(srcMemRefInfo) ||
      parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
      parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(numElementsInfo) ||
      parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
      parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
    return failure();

  // Parse optional stride and elements per stride.
  if (parser.parseTrailingOperandList(strideInfo))
    return failure();

  bool isStrided = strideInfo.size() == 2;
  if (!strideInfo.empty() && !isStrided) {
    return parser.emitError(parser.getNameLoc(),
                            "expected two stride related operands");
  }

  if (parser.parseColonTypeList(types))
    return failure();
  if (types.size() != 3)
    return parser.emitError(parser.getNameLoc(), "expected 3 types");

  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
      parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
      parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
      parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
      // size should be an index.
      parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
      parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
      // tag indices should be index.
      parser.resolveOperands(tagIndexInfos, indexType, result.operands))
    return failure();

  if (isStrided) {
    if (parser.resolveOperands(strideInfo, indexType, result.operands))
      return failure();
  }

  return success();
}

LogicalResult DmaStartOp::verify() {
  unsigned numOperands = getNumOperands();

  // Mandatory non-variadic operands are: src memref, dst memref, tag memref
  // and the number of elements.
  if (numOperands < 4)
    return emitOpError("expected at least 4 operands");

  // Check types of operands. The order of these calls is important: the later
  // calls rely on some type properties to compute the operand position.
  // 1. Source memref.
  if (!getSrcMemRef().getType().isa<MemRefType>())
    return emitOpError("expected source to be of memref type");
  if (numOperands < getSrcMemRefRank() + 4)
    return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
                         << " operands";
  if (!getSrcIndices().empty() &&
      !llvm::all_of(getSrcIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected source indices to be of index type");

  // 2. Destination memref.
  if (!getDstMemRef().getType().isa<MemRefType>())
    return emitOpError("expected destination to be of memref type");
  unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
                         << " operands";
  if (!getDstIndices().empty() &&
      !llvm::all_of(getDstIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected destination indices to be of index type");

  // 3. Number of elements.
  if (!getNumElements().getType().isIndex())
    return emitOpError("expected num elements to be of index type");

  // 4. Tag memref.
  if (!getTagMemRef().getType().isa<MemRefType>())
    return emitOpError("expected tag to be of memref type");
  numExpectedOperands += getTagMemRefRank();
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
                         << " operands";
  if (!getTagIndices().empty() &&
      !llvm::all_of(getTagIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected tag indices to be of index type");

  // DMAs are only supported between different memory spaces.
  if (getSrcMemorySpace() == getDstMemorySpace())
    return emitOpError("DMA should be between different memory spaces");

  // Optional stride-related operands must be either both present or both
  // absent.
  if (numOperands != numExpectedOperands &&
      numOperands != numExpectedOperands + 2)
    return emitOpError("incorrect number of operands");

  // 5. Strides.
  if (isStrided()) {
    if (!getStride().getType().isIndex() ||
        !getNumElementsPerStride().getType().isIndex())
      return emitOpError(
          "expected stride and num elements per stride to be of type index");
  }

  return success();
}

LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands,
                               SmallVectorImpl<OpFoldResult> &results) {
  /// dma_start(memrefcast) -> dma_start
  return foldMemRefCast(*this);
}

// ---------------------------------------------------------------------------
// DmaWaitOp
// ---------------------------------------------------------------------------

void DmaWaitOp::build(OpBuilder &builder, OperationState &result,
                      Value tagMemRef, ValueRange tagIndices,
                      Value numElements) {
  result.addOperands(tagMemRef);
  result.addOperands(tagIndices);
  result.addOperands(numElements);
}

void DmaWaitOp::print(OpAsmPrinter &p) {
  p << getOperationName() << " " << getTagMemRef() << '[' << getTagIndices()
    << "], " << getNumElements();
  p.printOptionalAttrDict((*this)->getAttrs());
  p << " : " << getTagMemRef().getType();
}

// Parse DmaWaitOp.
// E.g.:
//   dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 4>
//
ParseResult DmaWaitOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::OperandType tagMemrefInfo;
  SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos;
  Type type;
  auto indexType = parser.getBuilder().getIndexType();
  OpAsmParser::OperandType numElementsInfo;

  // Parse tag memref, its indices, and dma size.
  if (parser.parseOperand(tagMemrefInfo) ||
      parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(numElementsInfo) ||
      parser.parseColonType(type) ||
      parser.resolveOperand(tagMemrefInfo, type, result.operands) ||
      parser.resolveOperands(tagIndexInfos, indexType, result.operands) ||
      parser.resolveOperand(numElementsInfo, indexType, result.operands))
    return failure();

  return success();
}

LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
                              SmallVectorImpl<OpFoldResult> &results) {
  /// dma_wait(memrefcast) -> dma_wait
  return foldMemRefCast(*this);
}

LogicalResult DmaWaitOp::verify() {
  // Mandatory non-variadic operands are tag and the number of elements.
  if (getNumOperands() < 2)
    return emitOpError() << "expected at least 2 operands";

  // Check types of operands. The order of these calls is important: the later
  // calls rely on some type properties to compute the operand position.
  if (!getTagMemRef().getType().isa<MemRefType>())
    return emitOpError() << "expected tag to be of memref type";

  if (getNumOperands() != 2 + getTagMemRefRank())
    return emitOpError() << "expected " << 2 + getTagMemRefRank()
                         << " operands";

  if (!getTagIndices().empty() &&
      !llvm::all_of(getTagIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError() << "expected tag indices to be of index type";

  if (!getNumElements().getType().isIndex())
    return emitOpError()
           << "expected the number of elements to be of index type";

  return success();
}

//===----------------------------------------------------------------------===//
// GlobalOp
//===----------------------------------------------------------------------===//

static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
                                                   TypeAttr type,
                                                   Attribute initialValue) {
  p << type;
  if (!op.isExternal()) {
    p << " = ";
    if (op.isUninitialized())
      p << "uninitialized";
    else
      p.printAttributeWithoutType(initialValue);
  }
}

static ParseResult
parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
                                       Attribute &initialValue) {
  Type type;
  if (parser.parseType(type))
    return failure();

  auto memrefType = type.dyn_cast<MemRefType>();
  if (!memrefType || !memrefType.hasStaticShape())
    return parser.emitError(parser.getNameLoc())
           << "type should be static shaped memref, but got " << type;
  typeAttr = TypeAttr::get(type);

  if (parser.parseOptionalEqual())
    return success();

  if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
    initialValue = UnitAttr::get(parser.getBuilder().getContext());
    return success();
  }

  Type tensorType = getTensorTypeFromMemRefType(memrefType);
  if (parser.parseAttribute(initialValue, tensorType))
    return failure();
  if (!initialValue.isa<ElementsAttr>())
    return parser.emitError(parser.getNameLoc())
           << "initial value should be a unit or elements attribute";
  return success();
}

static LogicalResult verify(GlobalOp op) {
  auto memrefType = op.type().dyn_cast<MemRefType>();
  if (!memrefType || !memrefType.hasStaticShape())
    return op.emitOpError("type should be static shaped memref, but got ")
           << op.type();

  // Verify that the initial value, if present, is either a unit attribute or
  // an elements attribute.
  if (op.initial_value().hasValue()) {
    Attribute initValue = op.initial_value().getValue();
    if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>())
      return op.emitOpError("initial value should be a unit or elements "
                            "attribute, but got ")
             << initValue;

    // Check that the type of the initial value is compatible with the type of
    // the global variable.
    if (initValue.isa<ElementsAttr>()) {
      Type initType = initValue.getType();
      Type tensorType = getTensorTypeFromMemRefType(memrefType);
      if (initType != tensorType)
        return op.emitOpError("initial value expected to be of type ")
               << tensorType << ", but was of type " << initType;
    }
  }

  // TODO: verify visibility for declarations.
  return success();
}

//===----------------------------------------------------------------------===//
// GetGlobalOp
//===----------------------------------------------------------------------===//

LogicalResult
GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  // Verify that the result type is same as the type of the referenced
  // memref.global op.
  auto global =
      symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, nameAttr());
  if (!global)
    return emitOpError("'")
           << name() << "' does not reference a valid global memref";

  Type resultType = result().getType();
  if (global.type() != resultType)
    return emitOpError("result type ")
           << resultType << " does not match type " << global.type()
           << " of the global memref @" << name();
  return success();
}

//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(LoadOp op) {
  if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
    return op.emitOpError("incorrect number of indices for load");
  return success();
}

OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
  /// load(memrefcast) -> load
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  return OpFoldResult();
}

namespace {
/// Fold a load on a buffer_cast operation into a tensor.extract on the
/// corresponding tensor.
struct LoadOfBufferCast : public OpRewritePattern<LoadOp> {
  using OpRewritePattern<LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(LoadOp load,
                                PatternRewriter &rewriter) const override {
    auto buffercast = load.memref().getDefiningOp<BufferCastOp>();
    if (!buffercast)
      return failure();

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, buffercast.tensor(),
                                                   load.indices());
    return success();
  }
};
} // end anonymous namespace.

void LoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<LoadOfBufferCast>(context);
}

//===----------------------------------------------------------------------===//
// PrefetchOp
//===----------------------------------------------------------------------===//

static void print(OpAsmPrinter &p, PrefetchOp op) {
  p << PrefetchOp::getOperationName() << " " << op.memref() << '[';
  p.printOperands(op.indices());
  p << ']' << ", " << (op.isWrite() ? "write" : "read");
  p << ", locality<" << op.localityHint();
  p << ">, " << (op.isDataCache() ? "data" : "instr");
"data" : "instr"); 1158 p.printOptionalAttrDict( 1159 op->getAttrs(), 1160 /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"}); 1161 p << " : " << op.getMemRefType(); 1162 } 1163 1164 static ParseResult parsePrefetchOp(OpAsmParser &parser, 1165 OperationState &result) { 1166 OpAsmParser::OperandType memrefInfo; 1167 SmallVector<OpAsmParser::OperandType, 4> indexInfo; 1168 IntegerAttr localityHint; 1169 MemRefType type; 1170 StringRef readOrWrite, cacheType; 1171 1172 auto indexTy = parser.getBuilder().getIndexType(); 1173 auto i32Type = parser.getBuilder().getIntegerType(32); 1174 if (parser.parseOperand(memrefInfo) || 1175 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || 1176 parser.parseComma() || parser.parseKeyword(&readOrWrite) || 1177 parser.parseComma() || parser.parseKeyword("locality") || 1178 parser.parseLess() || 1179 parser.parseAttribute(localityHint, i32Type, "localityHint", 1180 result.attributes) || 1181 parser.parseGreater() || parser.parseComma() || 1182 parser.parseKeyword(&cacheType) || parser.parseColonType(type) || 1183 parser.resolveOperand(memrefInfo, type, result.operands) || 1184 parser.resolveOperands(indexInfo, indexTy, result.operands)) 1185 return failure(); 1186 1187 if (!readOrWrite.equals("read") && !readOrWrite.equals("write")) 1188 return parser.emitError(parser.getNameLoc(), 1189 "rw specifier has to be 'read' or 'write'"); 1190 result.addAttribute( 1191 PrefetchOp::getIsWriteAttrName(), 1192 parser.getBuilder().getBoolAttr(readOrWrite.equals("write"))); 1193 1194 if (!cacheType.equals("data") && !cacheType.equals("instr")) 1195 return parser.emitError(parser.getNameLoc(), 1196 "cache type has to be 'data' or 'instr'"); 1197 1198 result.addAttribute( 1199 PrefetchOp::getIsDataCacheAttrName(), 1200 parser.getBuilder().getBoolAttr(cacheType.equals("data"))); 1201 1202 return success(); 1203 } 1204 1205 static LogicalResult verify(PrefetchOp op) { 1206 if (op.getNumOperands() != 1 + op.getMemRefType().getRank()) 1207 return op.emitOpError("too few indices"); 1208 1209 return success(); 1210 } 1211 1212 LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands, 1213 SmallVectorImpl<OpFoldResult> &results) { 1214 // prefetch(memrefcast) -> prefetch 1215 return foldMemRefCast(*this); 1216 } 1217 1218 //===----------------------------------------------------------------------===// 1219 // ReinterpretCastOp 1220 //===----------------------------------------------------------------------===// 1221 1222 /// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`, 1223 /// `staticSizes` and `staticStrides` are automatically filled with 1224 /// source-memref-rank sentinel values that encode dynamic entries. 
void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                              MemRefType resultType, Value source,
                              OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
                              ArrayRef<OpFoldResult> strides,
                              ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets,
                             ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
                             ShapedType::kDynamicStrideOrOffset);
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getI64ArrayAttr(staticOffsets),
        b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
  result.addAttributes(attrs);
}

void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                              MemRefType resultType, Value source,
                              int64_t offset, ArrayRef<int64_t> sizes,
                              ArrayRef<int64_t> strides,
                              ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> sizeValues =
      llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
        strideValues, attrs);
}

void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                              MemRefType resultType, Value source, Value offset,
                              ValueRange sizes, ValueRange strides,
                              ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
      llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
}

// TODO: ponder whether we want to allow missing trailing sizes/strides that
// are completed automatically, like we have for subview and subtensor.
static LogicalResult verify(ReinterpretCastOp op) {
  // The source and result memrefs should be in the same memory space.
  auto srcType = op.source().getType().cast<BaseMemRefType>();
  auto resultType = op.getType().cast<MemRefType>();
  if (srcType.getMemorySpace() != resultType.getMemorySpace())
    return op.emitError("different memory spaces specified for source type ")
           << srcType << " and result memref type " << resultType;
  if (srcType.getElementType() != resultType.getElementType())
    return op.emitError("different element types specified for source type ")
           << srcType << " and result memref type " << resultType;

  // Match sizes in result memref type and in static_sizes attribute.
  for (auto &en :
       llvm::enumerate(llvm::zip(resultType.getShape(),
                                 extractFromI64ArrayAttr(op.static_sizes())))) {
    int64_t resultSize = std::get<0>(en.value());
    int64_t expectedSize = std::get<1>(en.value());
    if (resultSize != expectedSize)
      return op.emitError("expected result type with size = ")
             << expectedSize << " instead of " << resultSize
             << " in dim = " << en.index();
  }

  // Match offset and strides in static_offsets and static_strides attributes
  // if the result memref type has an affine map specified.
  if (!resultType.getAffineMaps().empty()) {
    int64_t resultOffset;
    SmallVector<int64_t, 4> resultStrides;
    if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
      return failure();

    // Match offset in result memref type and in static_offsets attribute.
    int64_t expectedOffset =
        extractFromI64ArrayAttr(op.static_offsets()).front();
    if (resultOffset != expectedOffset)
      return op.emitError("expected result type with offset = ")
             << expectedOffset << " instead of " << resultOffset;

    // Match strides in result memref type and in static_strides attribute.
    for (auto &en : llvm::enumerate(llvm::zip(
             resultStrides, extractFromI64ArrayAttr(op.static_strides())))) {
      int64_t resultStride = std::get<0>(en.value());
      int64_t expectedStride = std::get<1>(en.value());
      if (resultStride != expectedStride)
        return op.emitError("expected result type with stride = ")
               << expectedStride << " instead of " << resultStride
               << " in dim = " << en.index();
    }
  }
  return success();
}

//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(ReshapeOp op) {
  Type operandType = op.source().getType();
  Type resultType = op.result().getType();

  Type operandElementType = operandType.cast<ShapedType>().getElementType();
  Type resultElementType = resultType.cast<ShapedType>().getElementType();
  if (operandElementType != resultElementType)
    return op.emitOpError("element types of source and destination memref "
                          "types should be the same");

  if (auto operandMemRefType = operandType.dyn_cast<MemRefType>())
    if (!operandMemRefType.getAffineMaps().empty())
      return op.emitOpError(
          "source memref type should have identity affine map");

  int64_t shapeSize = op.shape().getType().cast<MemRefType>().getDimSize(0);
  auto resultMemRefType = resultType.dyn_cast<MemRefType>();
  if (resultMemRefType) {
    if (!resultMemRefType.getAffineMaps().empty())
      return op.emitOpError(
          "result memref type should have identity affine map");
    if (shapeSize == ShapedType::kDynamicSize)
      return op.emitOpError("cannot use shape operand with dynamic length to "
                            "reshape to statically-ranked memref type");
    if (shapeSize != resultMemRefType.getRank())
      return op.emitOpError(
          "length of shape operand differs from the result's memref rank");
  }
  return success();
}

//===----------------------------------------------------------------------===//
// StoreOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(StoreOp op) {
  if (op.getNumOperands() != 2 + op.getMemRefType().getRank())
return op.emitOpError("store index operand count not equal to memref rank"); 1368 1369 return success(); 1370 } 1371 1372 LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands, 1373 SmallVectorImpl<OpFoldResult> &results) { 1374 /// store(memrefcast) -> store 1375 return foldMemRefCast(*this); 1376 } 1377 1378 //===----------------------------------------------------------------------===// 1379 // SubViewOp 1380 //===----------------------------------------------------------------------===// 1381 1382 namespace { 1383 /// Helpers to write more idiomatic operations. 1384 namespace saturated_arith { 1385 struct Wrapper { 1386 explicit Wrapper(int64_t v) : v(v) {} 1387 operator int64_t() { return v; } 1388 int64_t v; 1389 }; 1390 Wrapper operator+(Wrapper a, int64_t b) { 1391 if (ShapedType::isDynamicStrideOrOffset(a) || 1392 ShapedType::isDynamicStrideOrOffset(b)) 1393 return Wrapper(ShapedType::kDynamicStrideOrOffset); 1394 return Wrapper(a.v + b); 1395 } 1396 Wrapper operator*(Wrapper a, int64_t b) { 1397 if (ShapedType::isDynamicStrideOrOffset(a) || 1398 ShapedType::isDynamicStrideOrOffset(b)) 1399 return Wrapper(ShapedType::kDynamicStrideOrOffset); 1400 return Wrapper(a.v * b); 1401 } 1402 } // end namespace saturated_arith 1403 } // end namespace 1404 1405 /// A subview result type can be fully inferred from the source type and the 1406 /// static representation of offsets, sizes and strides. Special sentinels 1407 /// encode the dynamic case. 1408 Type SubViewOp::inferResultType(MemRefType sourceMemRefType, 1409 ArrayRef<int64_t> leadingStaticOffsets, 1410 ArrayRef<int64_t> leadingStaticSizes, 1411 ArrayRef<int64_t> leadingStaticStrides) { 1412 // A subview may specify only a leading subset of offset/sizes/strides in 1413 // which case we complete with offset=0, sizes from memref type and strides=1. 1414 unsigned rank = sourceMemRefType.getRank(); 1415 assert(leadingStaticOffsets.size() <= rank && 1416 "unexpected leadingStaticOffsets overflow"); 1417 assert(leadingStaticSizes.size() <= rank && 1418 "unexpected leadingStaticSizes overflow"); 1419 assert(leadingStaticStrides.size() <= rank && 1420 "unexpected leadingStaticStrides overflow"); 1421 auto staticOffsets = llvm::to_vector<4>(leadingStaticOffsets); 1422 auto staticSizes = llvm::to_vector<4>(leadingStaticSizes); 1423 auto staticStrides = llvm::to_vector<4>(leadingStaticStrides); 1424 unsigned numTrailingOffsets = rank - staticOffsets.size(); 1425 unsigned numTrailingSizes = rank - staticSizes.size(); 1426 unsigned numTrailingStrides = rank - staticStrides.size(); 1427 staticOffsets.append(numTrailingOffsets, 0); 1428 llvm::append_range(staticSizes, 1429 sourceMemRefType.getShape().take_back(numTrailingSizes)); 1430 staticStrides.append(numTrailingStrides, 1); 1431 1432 // Extract source offset and strides. 1433 int64_t sourceOffset; 1434 SmallVector<int64_t, 4> sourceStrides; 1435 auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset); 1436 assert(succeeded(res) && "SubViewOp expected strided memref type"); 1437 (void)res; 1438 1439 // Compute target offset whose value is: 1440 // `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`. 
  int64_t targetOffset = sourceOffset;
  for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
    auto staticOffset = std::get<0>(it), sourceStride = std::get<1>(it);
    using namespace saturated_arith;
    targetOffset = Wrapper(targetOffset) + Wrapper(staticOffset) * sourceStride;
  }

  // Compute target stride whose value is:
  //   `sourceStrides_i * staticStrides_i`.
  SmallVector<int64_t, 4> targetStrides;
  targetStrides.reserve(staticStrides.size());
  for (auto it : llvm::zip(sourceStrides, staticStrides)) {
    auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
    using namespace saturated_arith;
    targetStrides.push_back(Wrapper(sourceStride) * staticStride);
  }

  // The type is now known.
  return MemRefType::get(
      staticSizes, sourceMemRefType.getElementType(),
      makeStridedLinearLayoutMap(targetStrides, targetOffset,
                                 sourceMemRefType.getContext()),
      sourceMemRefType.getMemorySpace());
}

Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
                                ArrayRef<OpFoldResult> leadingStaticOffsets,
                                ArrayRef<OpFoldResult> leadingStaticSizes,
                                ArrayRef<OpFoldResult> leadingStaticStrides) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
                             staticOffsets, ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
                             staticStrides, ShapedType::kDynamicStrideOrOffset);
  return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
                                    staticSizes, staticStrides)
      .cast<MemRefType>();
}

Type SubViewOp::inferRankReducedResultType(
    unsigned resultRank, MemRefType sourceMemRefType,
    ArrayRef<int64_t> leadingStaticOffsets,
    ArrayRef<int64_t> leadingStaticSizes,
    ArrayRef<int64_t> leadingStaticStrides) {
  auto inferredType =
      inferResultType(sourceMemRefType, leadingStaticOffsets,
                      leadingStaticSizes, leadingStaticStrides)
          .cast<MemRefType>();
  assert(inferredType.getRank() >= resultRank &&
         "expected inferred rank to be greater than or equal to result rank");
  int rankDiff = inferredType.getRank() - resultRank;
  if (rankDiff > 0) {
    auto shape = inferredType.getShape();
    llvm::SmallDenseSet<unsigned> dimsToProject;
    mlir::getPositionsOfShapeOne(rankDiff, shape, dimsToProject);
    SmallVector<int64_t> projectedShape;
    for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
      if (!dimsToProject.contains(pos))
        projectedShape.push_back(shape[pos]);

    AffineMap map;
    auto maps = inferredType.getAffineMaps();
    if (!maps.empty() && maps.front())
      map = getProjectedMap(maps.front(), dimsToProject);
    inferredType =
        MemRefType::get(projectedShape, inferredType.getElementType(), map,
                        inferredType.getMemorySpace());
  }
  return inferredType;
}

Type SubViewOp::inferRankReducedResultType(
    unsigned resultRank, MemRefType sourceMemRefType,
    ArrayRef<OpFoldResult> leadingStaticOffsets,
    ArrayRef<OpFoldResult> leadingStaticSizes,
    ArrayRef<OpFoldResult> leadingStaticStrides) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
                             staticOffsets, ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
                             staticStrides, ShapedType::kDynamicStrideOrOffset);
  return SubViewOp::inferRankReducedResultType(resultRank, sourceMemRefType,
                                               staticOffsets, staticSizes,
                                               staticStrides);
}

// Build a SubViewOp with mixed static and dynamic entries and custom result
// type. If the type passed is nullptr, it is inferred.
void SubViewOp::build(OpBuilder &b, OperationState &result,
                      MemRefType resultType, Value source,
                      ArrayRef<OpFoldResult> offsets,
                      ArrayRef<OpFoldResult> sizes,
                      ArrayRef<OpFoldResult> strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
                             ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
                             ShapedType::kDynamicStrideOrOffset);
  auto sourceMemRefType = source.getType().cast<MemRefType>();
  // Structuring implementation this way avoids duplication between builders.
  if (!resultType) {
    resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
                                            staticSizes, staticStrides)
                     .cast<MemRefType>();
  }
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getI64ArrayAttr(staticOffsets),
        b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
  result.addAttributes(attrs);
}

// Build a SubViewOp with mixed static and dynamic entries and inferred result
// type.
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ArrayRef<OpFoldResult> offsets,
                      ArrayRef<OpFoldResult> sizes,
                      ArrayRef<OpFoldResult> strides,
                      ArrayRef<NamedAttribute> attrs) {
  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
}

// Build a SubViewOp with static entries and inferred result type.
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
                      ArrayRef<int64_t> strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> sizeValues =
      llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
}

// Build a SubViewOp with static entries and custom result type. If the
// type passed is nullptr, it is inferred.
void SubViewOp::build(OpBuilder &b, OperationState &result,
                      MemRefType resultType, Value source,
                      ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
                      ArrayRef<int64_t> strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> sizeValues =
      llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
        return b.getI64IntegerAttr(v);
      }));
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
        attrs);
}

// Build a SubViewOp with dynamic entries and custom result type. If the type
// passed is nullptr, it is inferred.
void SubViewOp::build(OpBuilder &b, OperationState &result,
                      MemRefType resultType, Value source, ValueRange offsets,
                      ValueRange sizes, ValueRange strides,
                      ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
      llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
        attrs);
}

// Build a SubViewOp with dynamic entries and inferred result type.
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ValueRange offsets, ValueRange sizes, ValueRange strides,
                      ArrayRef<NamedAttribute> attrs) {
  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
}

/// For ViewLikeOpInterface.
Value SubViewOp::getViewSource() { return source(); }

enum SubViewVerificationResult {
  Success,
  RankTooLarge,
  SizeMismatch,
  ElemTypeMismatch,
  MemSpaceMismatch,
  AffineMapMismatch
};

/// Checks if `originalType` can be rank-reduced to `candidateReducedType`.
/// This function is a slight variant of the `is subsequence` algorithm,
/// where the non-matching dimensions must all be 1.
static SubViewVerificationResult
isRankReducedType(Type originalType, Type candidateReducedType,
                  std::string *errMsg = nullptr) {
  if (originalType == candidateReducedType)
    return SubViewVerificationResult::Success;
  if (!originalType.isa<MemRefType>())
    return SubViewVerificationResult::Success;
  if (originalType.isa<MemRefType>() && !candidateReducedType.isa<MemRefType>())
    return SubViewVerificationResult::Success;

  ShapedType originalShapedType = originalType.cast<ShapedType>();
  ShapedType candidateReducedShapedType =
      candidateReducedType.cast<ShapedType>();

  // Rank and size logic is valid for all ShapedTypes.
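  // For example (illustrative shapes): an original shape [1, 4, 1, 8] can be
  // rank-reduced to [4, 8] or [1, 4, 8] by dropping size-1 dimensions, but
  // not to [4, 4]: every dropped dimension must have static size 1.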
1665 ArrayRef<int64_t> originalShape = originalShapedType.getShape(); 1666 ArrayRef<int64_t> candidateReducedShape = 1667 candidateReducedShapedType.getShape(); 1668 unsigned originalRank = originalShape.size(), 1669 candidateReducedRank = candidateReducedShape.size(); 1670 if (candidateReducedRank > originalRank) 1671 return SubViewVerificationResult::RankTooLarge; 1672 1673 auto optionalUnusedDimsMask = 1674 computeRankReductionMask(originalShape, candidateReducedShape); 1675 1676 // Sizes cannot be matched in case empty vector is returned. 1677 if (!optionalUnusedDimsMask.hasValue()) 1678 return SubViewVerificationResult::SizeMismatch; 1679 1680 if (originalShapedType.getElementType() != 1681 candidateReducedShapedType.getElementType()) 1682 return SubViewVerificationResult::ElemTypeMismatch; 1683 1684 // Strided layout logic is relevant for MemRefType only. 1685 MemRefType original = originalType.cast<MemRefType>(); 1686 MemRefType candidateReduced = candidateReducedType.cast<MemRefType>(); 1687 if (original.getMemorySpace() != candidateReduced.getMemorySpace()) 1688 return SubViewVerificationResult::MemSpaceMismatch; 1689 1690 llvm::SmallDenseSet<unsigned> unusedDims = optionalUnusedDimsMask.getValue(); 1691 auto inferredType = 1692 getProjectedMap(getStridedLinearLayoutMap(original), unusedDims); 1693 AffineMap candidateLayout; 1694 if (candidateReduced.getAffineMaps().empty()) 1695 candidateLayout = getStridedLinearLayoutMap(candidateReduced); 1696 else 1697 candidateLayout = candidateReduced.getAffineMaps().front(); 1698 assert(inferredType.getNumResults() == 1 && 1699 candidateLayout.getNumResults() == 1); 1700 if (inferredType.getNumSymbols() != candidateLayout.getNumSymbols() || 1701 inferredType.getNumDims() != candidateLayout.getNumDims()) { 1702 if (errMsg) { 1703 llvm::raw_string_ostream os(*errMsg); 1704 os << "inferred type: " << inferredType; 1705 } 1706 return SubViewVerificationResult::AffineMapMismatch; 1707 } 1708 // Check that the difference of the affine maps simplifies to 0. 1709 AffineExpr diffExpr = 1710 inferredType.getResult(0) - candidateLayout.getResult(0); 1711 diffExpr = simplifyAffineExpr(diffExpr, inferredType.getNumDims(), 1712 inferredType.getNumSymbols()); 1713 auto cst = diffExpr.dyn_cast<AffineConstantExpr>(); 1714 if (!(cst && cst.getValue() == 0)) { 1715 if (errMsg) { 1716 llvm::raw_string_ostream os(*errMsg); 1717 os << "inferred type: " << inferredType; 1718 } 1719 return SubViewVerificationResult::AffineMapMismatch; 1720 } 1721 return SubViewVerificationResult::Success; 1722 } 1723 1724 template <typename OpTy> 1725 static LogicalResult produceSubViewErrorMsg(SubViewVerificationResult result, 1726 OpTy op, Type expectedType, 1727 StringRef errMsg = "") { 1728 auto memrefType = expectedType.cast<ShapedType>(); 1729 switch (result) { 1730 case SubViewVerificationResult::Success: 1731 return success(); 1732 case SubViewVerificationResult::RankTooLarge: 1733 return op.emitError("expected result rank to be smaller or equal to ") 1734 << "the source rank. " << errMsg; 1735 case SubViewVerificationResult::SizeMismatch: 1736 return op.emitError("expected result type to be ") 1737 << expectedType 1738 << " or a rank-reduced version. 
           << " or a rank-reduced version. (mismatch of result sizes) "
           << errMsg;
  case SubViewVerificationResult::ElemTypeMismatch:
    return op.emitError("expected result element type to be ")
           << memrefType.getElementType() << errMsg;
  case SubViewVerificationResult::MemSpaceMismatch:
    return op.emitError("expected result and source memory spaces to match. ")
           << errMsg;
  case SubViewVerificationResult::AffineMapMismatch:
    return op.emitError("expected result type to be ")
           << expectedType
           << " or a rank-reduced version. (mismatch of result affine map) "
           << errMsg;
  }
  llvm_unreachable("unexpected subview verification result");
}

/// Verifier for SubViewOp.
static LogicalResult verify(SubViewOp op) {
  MemRefType baseType = op.getSourceType();
  MemRefType subViewType = op.getType();

  // The base memref and the view memref should be in the same memory space.
  if (baseType.getMemorySpace() != subViewType.getMemorySpace())
    return op.emitError("different memory spaces specified for base memref "
                        "type ")
           << baseType << " and subview memref type " << subViewType;

  // Verify that the base memref type has a strided layout map.
  if (!isStrided(baseType))
    return op.emitError("base type ") << baseType << " is not strided";

  // Verify result type against inferred type.
  auto expectedType = SubViewOp::inferResultType(
      baseType, extractFromI64ArrayAttr(op.static_offsets()),
      extractFromI64ArrayAttr(op.static_sizes()),
      extractFromI64ArrayAttr(op.static_strides()));

  std::string errMsg;
  auto result = isRankReducedType(expectedType, subViewType, &errMsg);
  return produceSubViewErrorMsg(result, op, expectedType, errMsg);
}

raw_ostream &mlir::operator<<(raw_ostream &os, Range &range) {
  return os << "range " << range.offset << ":" << range.size << ":"
            << range.stride;
}

/// Return the list of Range (i.e. offset, size, stride). Each Range
/// entry contains either the dynamic value or a ConstantIndexOp constructed
/// with `b` at location `loc`.
SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
                                              OpBuilder &b, Location loc) {
  std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
  assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
  assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
  SmallVector<Range, 8> res;
  unsigned rank = ranks[0];
  res.reserve(rank);
  for (unsigned idx = 0; idx < rank; ++idx) {
    Value offset =
        op.isDynamicOffset(idx)
            ? op.getDynamicOffset(idx)
            : b.create<ConstantIndexOp>(loc, op.getStaticOffset(idx));
    Value size = op.isDynamicSize(idx)
                     ? op.getDynamicSize(idx)
                     : b.create<ConstantIndexOp>(loc, op.getStaticSize(idx));
    Value stride =
        op.isDynamicStride(idx)
            ? op.getDynamicStride(idx)
            : b.create<ConstantIndexOp>(loc, op.getStaticStride(idx));
    res.emplace_back(Range{offset, size, stride});
  }
  return res;
}

namespace {
/// Pattern to rewrite a subview op with MemRefCast arguments.
/// This essentially pushes memref.cast past its consuming subview when
/// `canFoldIntoConsumerOp` is true.
///
/// Example:
/// ```
/// %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
/// %1 = memref.subview %0[0, 0][3, 4][1, 1] :
///   memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]>
/// ```
/// is rewritten into:
/// ```
/// %0 = memref.subview %V: memref<16x16xf32> to memref<3x4xf32, #[[map0]]>
/// %1 = memref.cast %0: memref<3x4xf32, offset:0, strides:[16, 1]> to
///   memref<3x4xf32, offset:?, strides:[?, 1]>
/// ```
class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
public:
  using OpRewritePattern<SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubViewOp subViewOp,
                                PatternRewriter &rewriter) const override {
    // If any operand is constant, just return to let the
    // SubViewOpConstantFolder pattern kick in.
    if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto castOp = subViewOp.source().getDefiningOp<CastOp>();
    if (!castOp)
      return failure();

    if (!CastOp::canFoldIntoConsumerOp(castOp))
      return failure();

    /// Deduce the result type of the SubViewOp using
    /// `inferRankReducedResultType` on the cast source operand type and the
    /// SubViewOp static information. This is the resulting type if the
    /// MemRefCastOp were folded.
    auto resultType = SubViewOp::inferRankReducedResultType(
        subViewOp.getType().getRank(),
        castOp.source().getType().cast<MemRefType>(),
        subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
        subViewOp.getMixedStrides());
    Value newSubView = rewriter.create<SubViewOp>(
        subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(),
        subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),
        subViewOp.static_sizes(), subViewOp.static_strides());
    rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
                                        newSubView);
    return success();
  }
};
} // namespace

/// A canonicalizer wrapper to replace SubViewOps.
struct SubViewCanonicalizer {
  void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
    rewriter.replaceOpWithNewOp<CastOp>(op, newOp, op.getType());
  }
};

void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
                  SubViewOp, SubViewCanonicalizer>,
              SubViewOpMemRefCastFolder>(context);
}

OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
  auto resultShapedType = getResult().getType().cast<ShapedType>();
  auto sourceShapedType = source().getType().cast<ShapedType>();

  if (resultShapedType.hasStaticShape() &&
      resultShapedType == sourceShapedType) {
    return getViewSource();
  }

  return {};
}

//===----------------------------------------------------------------------===//
// TensorLoadOp
//===----------------------------------------------------------------------===//

OpFoldResult TensorLoadOp::fold(ArrayRef<Attribute>) {
  if (auto bufferCast = memref().getDefiningOp<BufferCastOp>())
    // Approximate alias analysis by conservatively folding only when there
    // is no interleaved operation.
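    //
    // For instance (illustrative IR), in
    //   %m = memref.buffer_cast %t : memref<4xf32>
    //   %v = memref.tensor_load %m : memref<4xf32>
    // %v folds to %t because the tensor_load immediately follows the
    // buffer_cast in the same block, so no interleaved operation can observe
    // or mutate the buffer in between.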
    if (bufferCast->getBlock() == this->getOperation()->getBlock() &&
        bufferCast->getNextNode() == this->getOperation())
      return bufferCast.tensor();
  return {};
}

//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//

/// Build a strided memref type by applying `permutationMap` to `memRefType`.
static MemRefType inferTransposeResultType(MemRefType memRefType,
                                           AffineMap permutationMap) {
  auto rank = memRefType.getRank();
  auto originalSizes = memRefType.getShape();
  // Compute permuted sizes.
  SmallVector<int64_t, 4> sizes(rank, 0);
  for (auto en : llvm::enumerate(permutationMap.getResults()))
    sizes[en.index()] =
        originalSizes[en.value().cast<AffineDimExpr>().getPosition()];

  // Compute permuted strides.
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto res = getStridesAndOffset(memRefType, strides, offset);
  assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank));
  (void)res;
  auto map =
      makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
  map = permutationMap ? map.compose(permutationMap) : map;
  return MemRefType::Builder(memRefType).setShape(sizes).setAffineMaps(map);
}

void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
                        AffineMapAttr permutation,
                        ArrayRef<NamedAttribute> attrs) {
  auto permutationMap = permutation.getValue();
  assert(permutationMap);

  auto memRefType = in.getType().cast<MemRefType>();
  // Compute result type.
  MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);

  build(b, result, resultType, in, attrs);
  result.addAttribute(TransposeOp::getPermutationAttrName(), permutation);
}

// transpose $in $permutation attr-dict : type($in) `to` type(results)
static void print(OpAsmPrinter &p, TransposeOp op) {
  p << "memref.transpose " << op.in() << " " << op.permutation();
  p.printOptionalAttrDict(op->getAttrs(),
                          {TransposeOp::getPermutationAttrName()});
  p << " : " << op.in().getType() << " to " << op.getType();
}

static ParseResult parseTransposeOp(OpAsmParser &parser,
                                    OperationState &result) {
  OpAsmParser::OperandType in;
  AffineMap permutation;
  MemRefType srcType, dstType;
  if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(srcType) ||
      parser.resolveOperand(in, srcType, result.operands) ||
      parser.parseKeywordType("to", dstType) ||
      parser.addTypeToList(dstType, result.types))
    return failure();

  result.addAttribute(TransposeOp::getPermutationAttrName(),
                      AffineMapAttr::get(permutation));
  return success();
}

static LogicalResult verify(TransposeOp op) {
  if (!op.permutation().isPermutation())
    return op.emitOpError("expected a permutation map");
  if (op.permutation().getNumDims() != op.getShapedType().getRank())
    return op.emitOpError(
        "expected a permutation map of same rank as the input");

  auto srcType = op.in().getType().cast<MemRefType>();
  auto dstType = op.getType().cast<MemRefType>();
  auto transposedType = inferTransposeResultType(srcType, op.permutation());
  if (dstType != transposedType)
op.emitOpError("output type ") 1988 << dstType << " does not match transposed input type " << srcType 1989 << ", " << transposedType; 1990 return success(); 1991 } 1992 1993 OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) { 1994 if (succeeded(foldMemRefCast(*this))) 1995 return getResult(); 1996 return {}; 1997 } 1998 1999 //===----------------------------------------------------------------------===// 2000 // ViewOp 2001 //===----------------------------------------------------------------------===// 2002 2003 static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) { 2004 OpAsmParser::OperandType srcInfo; 2005 SmallVector<OpAsmParser::OperandType, 1> offsetInfo; 2006 SmallVector<OpAsmParser::OperandType, 4> sizesInfo; 2007 auto indexType = parser.getBuilder().getIndexType(); 2008 Type srcType, dstType; 2009 llvm::SMLoc offsetLoc; 2010 if (parser.parseOperand(srcInfo) || parser.getCurrentLocation(&offsetLoc) || 2011 parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square)) 2012 return failure(); 2013 2014 if (offsetInfo.size() != 1) 2015 return parser.emitError(offsetLoc) << "expects 1 offset operand"; 2016 2017 return failure( 2018 parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) || 2019 parser.parseOptionalAttrDict(result.attributes) || 2020 parser.parseColonType(srcType) || 2021 parser.resolveOperand(srcInfo, srcType, result.operands) || 2022 parser.resolveOperands(offsetInfo, indexType, result.operands) || 2023 parser.resolveOperands(sizesInfo, indexType, result.operands) || 2024 parser.parseKeywordType("to", dstType) || 2025 parser.addTypeToList(dstType, result.types)); 2026 } 2027 2028 static void print(OpAsmPrinter &p, ViewOp op) { 2029 p << op.getOperationName() << ' ' << op.getOperand(0) << '['; 2030 p.printOperand(op.byte_shift()); 2031 p << "][" << op.sizes() << ']'; 2032 p.printOptionalAttrDict(op->getAttrs()); 2033 p << " : " << op.getOperand(0).getType() << " to " << op.getType(); 2034 } 2035 2036 static LogicalResult verify(ViewOp op) { 2037 auto baseType = op.getOperand(0).getType().cast<MemRefType>(); 2038 auto viewType = op.getType(); 2039 2040 // The base memref should have identity layout map (or none). 2041 if (baseType.getAffineMaps().size() > 1 || 2042 (baseType.getAffineMaps().size() == 1 && 2043 !baseType.getAffineMaps()[0].isIdentity())) 2044 return op.emitError("unsupported map for base memref type ") << baseType; 2045 2046 // The result memref should have identity layout map (or none). 2047 if (viewType.getAffineMaps().size() > 1 || 2048 (viewType.getAffineMaps().size() == 1 && 2049 !viewType.getAffineMaps()[0].isIdentity())) 2050 return op.emitError("unsupported map for result memref type ") << viewType; 2051 2052 // The base memref and the view memref should be in the same memory space. 2053 if (baseType.getMemorySpace() != viewType.getMemorySpace()) 2054 return op.emitError("different memory spaces specified for base memref " 2055 "type ") 2056 << baseType << " and view memref type " << viewType; 2057 2058 // Verify that we have the correct number of sizes for the result type. 
  unsigned numDynamicDims = viewType.getNumDynamicDims();
  if (op.sizes().size() != numDynamicDims)
    return op.emitError("incorrect number of size operands for type ")
           << viewType;

  return success();
}

Value ViewOp::getViewSource() { return source(); }

namespace {

struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
  using OpRewritePattern<ViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    // Return if none of the operands are constants.
    if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    // Get result memref type.
    auto memrefType = viewOp.getType();

    // Get offset from the old memref view type `memrefType`.
    int64_t oldOffset;
    SmallVector<int64_t, 4> oldStrides;
    if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
      return failure();
    assert(oldOffset == 0 && "Expected 0 offset");

    SmallVector<Value, 4> newOperands;

    // Offset cannot be folded into the result type.

    // Fold any dynamic dim operands which are produced by a constant.
    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());

    unsigned dynamicDimPos = 0;
    unsigned rank = memrefType.getRank();
    for (unsigned dim = 0, e = rank; dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already a static dimension, keep it.
      if (!ShapedType::isDynamic(dimSize)) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto *defOp = viewOp.sizes()[dynamicDimPos].getDefiningOp();
      if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
        // Dynamic shape dimension will be folded.
        newShapeConstants.push_back(constantIndexOp.getValue());
      } else {
        // Dynamic shape dimension not folded; copy the operand from the old
        // memref.
        newShapeConstants.push_back(dimSize);
        newOperands.push_back(viewOp.sizes()[dynamicDimPos]);
      }
      dynamicDimPos++;
    }

    // Create new memref type with constant folded dims.
    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    // Nothing new, don't fold.
    if (newMemRefType == memrefType)
      return failure();

    // Create new ViewOp.
    auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType,
                                             viewOp.getOperand(0),
                                             viewOp.byte_shift(), newOperands);
    // Insert a cast so we have the same type as the old memref type.
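    // For instance (illustrative IR), a view producing memref<?xf32> whose
    // size operand is a constant 16 becomes a view producing memref<16xf32>,
    // cast back to memref<?xf32> so existing users keep their expected type.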
    rewriter.replaceOpWithNewOp<CastOp>(viewOp, newViewOp, viewOp.getType());
    return success();
  }
};

struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
  using OpRewritePattern<ViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    Value memrefOperand = viewOp.getOperand(0);
    CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
    if (!memrefCastOp)
      return failure();
    Value allocOperand = memrefCastOp.getOperand();
    AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
    if (!allocOp)
      return failure();
    rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
                                        viewOp.byte_shift(), viewOp.sizes());
    return success();
  }
};

} // end anonymous namespace

void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"