1 //===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/BuiltinTypes.h" 10 #include "TypeDetail.h" 11 #include "mlir/IR/AffineExpr.h" 12 #include "mlir/IR/AffineMap.h" 13 #include "mlir/IR/Diagnostics.h" 14 #include "mlir/IR/Dialect.h" 15 #include "llvm/ADT/APFloat.h" 16 #include "llvm/ADT/BitVector.h" 17 #include "llvm/ADT/Sequence.h" 18 #include "llvm/ADT/Twine.h" 19 20 using namespace mlir; 21 using namespace mlir::detail; 22 23 //===----------------------------------------------------------------------===// 24 /// Tablegen Type Definitions 25 //===----------------------------------------------------------------------===// 26 27 #define GET_TYPEDEF_CLASSES 28 #include "mlir/IR/BuiltinTypes.cpp.inc" 29 30 //===----------------------------------------------------------------------===// 31 /// ComplexType 32 //===----------------------------------------------------------------------===// 33 34 /// Verify the construction of an integer type. 35 LogicalResult ComplexType::verifyConstructionInvariants(Location loc, 36 Type elementType) { 37 if (!elementType.isIntOrFloat()) 38 return emitError(loc, "invalid element type for complex"); 39 return success(); 40 } 41 42 //===----------------------------------------------------------------------===// 43 // Integer Type 44 //===----------------------------------------------------------------------===// 45 46 // static constexpr must have a definition (until in C++17 and inline variable). 47 constexpr unsigned IntegerType::kMaxWidth; 48 49 /// Verify the construction of an integer type. 50 LogicalResult 51 IntegerType::verifyConstructionInvariants(Location loc, unsigned width, 52 SignednessSemantics signedness) { 53 if (width > IntegerType::kMaxWidth) { 54 return emitError(loc) << "integer bitwidth is limited to " 55 << IntegerType::kMaxWidth << " bits"; 56 } 57 return success(); 58 } 59 60 unsigned IntegerType::getWidth() const { return getImpl()->width; } 61 62 IntegerType::SignednessSemantics IntegerType::getSignedness() const { 63 return getImpl()->signedness; 64 } 65 66 IntegerType IntegerType::scaleElementBitwidth(unsigned scale) { 67 if (!scale) 68 return IntegerType(); 69 return IntegerType::get(getContext(), scale * getWidth(), getSignedness()); 70 } 71 72 //===----------------------------------------------------------------------===// 73 // Float Type 74 //===----------------------------------------------------------------------===// 75 76 unsigned FloatType::getWidth() { 77 if (isa<Float16Type, BFloat16Type>()) 78 return 16; 79 if (isa<Float32Type>()) 80 return 32; 81 if (isa<Float64Type>()) 82 return 64; 83 if (isa<Float80Type>()) 84 return 80; 85 if (isa<Float128Type>()) 86 return 128; 87 llvm_unreachable("unexpected float type"); 88 } 89 90 /// Returns the floating semantics for the given type. 91 const llvm::fltSemantics &FloatType::getFloatSemantics() { 92 if (isa<BFloat16Type>()) 93 return APFloat::BFloat(); 94 if (isa<Float16Type>()) 95 return APFloat::IEEEhalf(); 96 if (isa<Float32Type>()) 97 return APFloat::IEEEsingle(); 98 if (isa<Float64Type>()) 99 return APFloat::IEEEdouble(); 100 if (isa<Float80Type>()) 101 return APFloat::x87DoubleExtended(); 102 if (isa<Float128Type>()) 103 return APFloat::IEEEquad(); 104 llvm_unreachable("non-floating point type used"); 105 } 106 107 FloatType FloatType::scaleElementBitwidth(unsigned scale) { 108 if (!scale) 109 return FloatType(); 110 MLIRContext *ctx = getContext(); 111 if (isF16() || isBF16()) { 112 if (scale == 2) 113 return FloatType::getF32(ctx); 114 if (scale == 4) 115 return FloatType::getF64(ctx); 116 } 117 if (isF32()) 118 if (scale == 2) 119 return FloatType::getF64(ctx); 120 return FloatType(); 121 } 122 123 //===----------------------------------------------------------------------===// 124 // FunctionType 125 //===----------------------------------------------------------------------===// 126 127 unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; } 128 129 ArrayRef<Type> FunctionType::getInputs() const { 130 return getImpl()->getInputs(); 131 } 132 133 unsigned FunctionType::getNumResults() const { return getImpl()->numResults; } 134 135 ArrayRef<Type> FunctionType::getResults() const { 136 return getImpl()->getResults(); 137 } 138 139 /// Helper to call a callback once on each index in the range 140 /// [0, `totalIndices`), *except* for the indices given in `indices`. 141 /// `indices` is allowed to have duplicates and can be in any order. 142 inline void iterateIndicesExcept(unsigned totalIndices, 143 ArrayRef<unsigned> indices, 144 function_ref<void(unsigned)> callback) { 145 llvm::BitVector skipIndices(totalIndices); 146 for (unsigned i : indices) 147 skipIndices.set(i); 148 149 for (unsigned i = 0; i < totalIndices; ++i) 150 if (!skipIndices.test(i)) 151 callback(i); 152 } 153 154 /// Returns a new function type without the specified arguments and results. 155 FunctionType 156 FunctionType::getWithoutArgsAndResults(ArrayRef<unsigned> argIndices, 157 ArrayRef<unsigned> resultIndices) { 158 ArrayRef<Type> newInputTypes = getInputs(); 159 SmallVector<Type, 4> newInputTypesBuffer; 160 if (!argIndices.empty()) { 161 unsigned originalNumArgs = getNumInputs(); 162 iterateIndicesExcept(originalNumArgs, argIndices, [&](unsigned i) { 163 newInputTypesBuffer.emplace_back(getInput(i)); 164 }); 165 newInputTypes = newInputTypesBuffer; 166 } 167 168 ArrayRef<Type> newResultTypes = getResults(); 169 SmallVector<Type, 4> newResultTypesBuffer; 170 if (!resultIndices.empty()) { 171 unsigned originalNumResults = getNumResults(); 172 iterateIndicesExcept(originalNumResults, resultIndices, [&](unsigned i) { 173 newResultTypesBuffer.emplace_back(getResult(i)); 174 }); 175 newResultTypes = newResultTypesBuffer; 176 } 177 178 return get(getContext(), newInputTypes, newResultTypes); 179 } 180 181 //===----------------------------------------------------------------------===// 182 // OpaqueType 183 //===----------------------------------------------------------------------===// 184 185 /// Verify the construction of an opaque type. 186 LogicalResult OpaqueType::verifyConstructionInvariants(Location loc, 187 Identifier dialect, 188 StringRef typeData) { 189 if (!Dialect::isValidNamespace(dialect.strref())) 190 return emitError(loc, "invalid dialect namespace '") << dialect << "'"; 191 return success(); 192 } 193 194 //===----------------------------------------------------------------------===// 195 // ShapedType 196 //===----------------------------------------------------------------------===// 197 constexpr int64_t ShapedType::kDynamicSize; 198 constexpr int64_t ShapedType::kDynamicStrideOrOffset; 199 200 ShapedType ShapedType::clone(ArrayRef<int64_t> shape, Type elementType) { 201 if (auto other = dyn_cast<MemRefType>()) { 202 MemRefType::Builder b(other); 203 b.setShape(shape); 204 b.setElementType(elementType); 205 return b; 206 } 207 208 if (auto other = dyn_cast<UnrankedMemRefType>()) { 209 MemRefType::Builder b(shape, elementType); 210 b.setMemorySpace(other.getMemorySpace()); 211 return b; 212 } 213 214 if (isa<TensorType>()) 215 return RankedTensorType::get(shape, elementType); 216 217 if (isa<VectorType>()) 218 return VectorType::get(shape, elementType); 219 220 llvm_unreachable("Unhandled ShapedType clone case"); 221 } 222 223 ShapedType ShapedType::clone(ArrayRef<int64_t> shape) { 224 if (auto other = dyn_cast<MemRefType>()) { 225 MemRefType::Builder b(other); 226 b.setShape(shape); 227 return b; 228 } 229 230 if (auto other = dyn_cast<UnrankedMemRefType>()) { 231 MemRefType::Builder b(shape, other.getElementType()); 232 b.setShape(shape); 233 b.setMemorySpace(other.getMemorySpace()); 234 return b; 235 } 236 237 if (isa<TensorType>()) 238 return RankedTensorType::get(shape, getElementType()); 239 240 if (isa<VectorType>()) 241 return VectorType::get(shape, getElementType()); 242 243 llvm_unreachable("Unhandled ShapedType clone case"); 244 } 245 246 ShapedType ShapedType::clone(Type elementType) { 247 if (auto other = dyn_cast<MemRefType>()) { 248 MemRefType::Builder b(other); 249 b.setElementType(elementType); 250 return b; 251 } 252 253 if (auto other = dyn_cast<UnrankedMemRefType>()) { 254 return UnrankedMemRefType::get(elementType, other.getMemorySpace()); 255 } 256 257 if (isa<TensorType>()) { 258 if (hasRank()) 259 return RankedTensorType::get(getShape(), elementType); 260 return UnrankedTensorType::get(elementType); 261 } 262 263 if (isa<VectorType>()) 264 return VectorType::get(getShape(), elementType); 265 266 llvm_unreachable("Unhandled ShapedType clone hit"); 267 } 268 269 Type ShapedType::getElementType() const { 270 return static_cast<ImplType *>(impl)->elementType; 271 } 272 273 unsigned ShapedType::getElementTypeBitWidth() const { 274 return getElementType().getIntOrFloatBitWidth(); 275 } 276 277 int64_t ShapedType::getNumElements() const { 278 assert(hasStaticShape() && "cannot get element count of dynamic shaped type"); 279 auto shape = getShape(); 280 int64_t num = 1; 281 for (auto dim : shape) { 282 num *= dim; 283 assert(num >= 0 && "integer overflow in element count computation"); 284 } 285 return num; 286 } 287 288 int64_t ShapedType::getRank() const { 289 assert(hasRank() && "cannot query rank of unranked shaped type"); 290 return getShape().size(); 291 } 292 293 bool ShapedType::hasRank() const { 294 return !isa<UnrankedMemRefType, UnrankedTensorType>(); 295 } 296 297 int64_t ShapedType::getDimSize(unsigned idx) const { 298 assert(idx < getRank() && "invalid index for shaped type"); 299 return getShape()[idx]; 300 } 301 302 bool ShapedType::isDynamicDim(unsigned idx) const { 303 assert(idx < getRank() && "invalid index for shaped type"); 304 return isDynamic(getShape()[idx]); 305 } 306 307 unsigned ShapedType::getDynamicDimIndex(unsigned index) const { 308 assert(index < getRank() && "invalid index"); 309 assert(ShapedType::isDynamic(getDimSize(index)) && "invalid index"); 310 return llvm::count_if(getShape().take_front(index), ShapedType::isDynamic); 311 } 312 313 /// Get the number of bits require to store a value of the given shaped type. 314 /// Compute the value recursively since tensors are allowed to have vectors as 315 /// elements. 316 int64_t ShapedType::getSizeInBits() const { 317 assert(hasStaticShape() && 318 "cannot get the bit size of an aggregate with a dynamic shape"); 319 320 auto elementType = getElementType(); 321 if (elementType.isIntOrFloat()) 322 return elementType.getIntOrFloatBitWidth() * getNumElements(); 323 324 if (auto complexType = elementType.dyn_cast<ComplexType>()) { 325 elementType = complexType.getElementType(); 326 return elementType.getIntOrFloatBitWidth() * getNumElements() * 2; 327 } 328 329 // Tensors can have vectors and other tensors as elements, other shaped types 330 // cannot. 331 assert(isa<TensorType>() && "unsupported element type"); 332 assert((elementType.isa<VectorType, TensorType>()) && 333 "unsupported tensor element type"); 334 return getNumElements() * elementType.cast<ShapedType>().getSizeInBits(); 335 } 336 337 ArrayRef<int64_t> ShapedType::getShape() const { 338 if (auto vectorType = dyn_cast<VectorType>()) 339 return vectorType.getShape(); 340 if (auto tensorType = dyn_cast<RankedTensorType>()) 341 return tensorType.getShape(); 342 return cast<MemRefType>().getShape(); 343 } 344 345 int64_t ShapedType::getNumDynamicDims() const { 346 return llvm::count_if(getShape(), isDynamic); 347 } 348 349 bool ShapedType::hasStaticShape() const { 350 return hasRank() && llvm::none_of(getShape(), isDynamic); 351 } 352 353 bool ShapedType::hasStaticShape(ArrayRef<int64_t> shape) const { 354 return hasStaticShape() && getShape() == shape; 355 } 356 357 //===----------------------------------------------------------------------===// 358 // VectorType 359 //===----------------------------------------------------------------------===// 360 361 VectorType VectorType::get(ArrayRef<int64_t> shape, Type elementType) { 362 return Base::get(elementType.getContext(), shape, elementType); 363 } 364 365 VectorType VectorType::getChecked(Location location, ArrayRef<int64_t> shape, 366 Type elementType) { 367 return Base::getChecked(location, shape, elementType); 368 } 369 370 LogicalResult VectorType::verifyConstructionInvariants(Location loc, 371 ArrayRef<int64_t> shape, 372 Type elementType) { 373 if (shape.empty()) 374 return emitError(loc, "vector types must have at least one dimension"); 375 376 if (!isValidElementType(elementType)) 377 return emitError(loc, "vector elements must be int or float type"); 378 379 if (any_of(shape, [](int64_t i) { return i <= 0; })) 380 return emitError(loc, "vector types must have positive constant sizes"); 381 382 return success(); 383 } 384 385 ArrayRef<int64_t> VectorType::getShape() const { return getImpl()->getShape(); } 386 387 VectorType VectorType::scaleElementBitwidth(unsigned scale) { 388 if (!scale) 389 return VectorType(); 390 if (auto et = getElementType().dyn_cast<IntegerType>()) 391 if (auto scaledEt = et.scaleElementBitwidth(scale)) 392 return VectorType::get(getShape(), scaledEt); 393 if (auto et = getElementType().dyn_cast<FloatType>()) 394 if (auto scaledEt = et.scaleElementBitwidth(scale)) 395 return VectorType::get(getShape(), scaledEt); 396 return VectorType(); 397 } 398 399 //===----------------------------------------------------------------------===// 400 // TensorType 401 //===----------------------------------------------------------------------===// 402 403 // Check if "elementType" can be an element type of a tensor. Emit errors if 404 // location is not nullptr. Returns failure if check failed. 405 static LogicalResult checkTensorElementType(Location location, 406 Type elementType) { 407 if (!TensorType::isValidElementType(elementType)) 408 return emitError(location, "invalid tensor element type: ") << elementType; 409 return success(); 410 } 411 412 /// Return true if the specified element type is ok in a tensor. 413 bool TensorType::isValidElementType(Type type) { 414 // Note: Non standard/builtin types are allowed to exist within tensor 415 // types. Dialects are expected to verify that tensor types have a valid 416 // element type within that dialect. 417 return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType, 418 IndexType>() || 419 !type.getDialect().getNamespace().empty(); 420 } 421 422 //===----------------------------------------------------------------------===// 423 // RankedTensorType 424 //===----------------------------------------------------------------------===// 425 426 RankedTensorType RankedTensorType::get(ArrayRef<int64_t> shape, 427 Type elementType) { 428 return Base::get(elementType.getContext(), shape, elementType); 429 } 430 431 RankedTensorType RankedTensorType::getChecked(Location location, 432 ArrayRef<int64_t> shape, 433 Type elementType) { 434 return Base::getChecked(location, shape, elementType); 435 } 436 437 LogicalResult RankedTensorType::verifyConstructionInvariants( 438 Location loc, ArrayRef<int64_t> shape, Type elementType) { 439 for (int64_t s : shape) { 440 if (s < -1) 441 return emitError(loc, "invalid tensor dimension size"); 442 } 443 return checkTensorElementType(loc, elementType); 444 } 445 446 ArrayRef<int64_t> RankedTensorType::getShape() const { 447 return getImpl()->getShape(); 448 } 449 450 //===----------------------------------------------------------------------===// 451 // UnrankedTensorType 452 //===----------------------------------------------------------------------===// 453 454 UnrankedTensorType UnrankedTensorType::get(Type elementType) { 455 return Base::get(elementType.getContext(), elementType); 456 } 457 458 UnrankedTensorType UnrankedTensorType::getChecked(Location location, 459 Type elementType) { 460 return Base::getChecked(location, elementType); 461 } 462 463 LogicalResult 464 UnrankedTensorType::verifyConstructionInvariants(Location loc, 465 Type elementType) { 466 return checkTensorElementType(loc, elementType); 467 } 468 469 //===----------------------------------------------------------------------===// 470 // BaseMemRefType 471 //===----------------------------------------------------------------------===// 472 473 unsigned BaseMemRefType::getMemorySpace() const { 474 return static_cast<ImplType *>(impl)->memorySpace; 475 } 476 477 //===----------------------------------------------------------------------===// 478 // MemRefType 479 //===----------------------------------------------------------------------===// 480 481 /// Get or create a new MemRefType based on shape, element type, affine 482 /// map composition, and memory space. Assumes the arguments define a 483 /// well-formed MemRef type. Use getChecked to gracefully handle MemRefType 484 /// construction failures. 485 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType, 486 ArrayRef<AffineMap> affineMapComposition, 487 unsigned memorySpace) { 488 auto result = getImpl(shape, elementType, affineMapComposition, memorySpace, 489 /*location=*/llvm::None); 490 assert(result && "Failed to construct instance of MemRefType."); 491 return result; 492 } 493 494 /// Get or create a new MemRefType based on shape, element type, affine 495 /// map composition, and memory space declared at the given location. 496 /// If the location is unknown, the last argument should be an instance of 497 /// UnknownLoc. If the MemRefType defined by the arguments would be 498 /// ill-formed, emits errors (to the handler registered with the context or to 499 /// the error stream) and returns nullptr. 500 MemRefType MemRefType::getChecked(Location location, ArrayRef<int64_t> shape, 501 Type elementType, 502 ArrayRef<AffineMap> affineMapComposition, 503 unsigned memorySpace) { 504 return getImpl(shape, elementType, affineMapComposition, memorySpace, 505 location); 506 } 507 508 /// Get or create a new MemRefType defined by the arguments. If the resulting 509 /// type would be ill-formed, return nullptr. If the location is provided, 510 /// emit detailed error messages. To emit errors when the location is unknown, 511 /// pass in an instance of UnknownLoc. 512 MemRefType MemRefType::getImpl(ArrayRef<int64_t> shape, Type elementType, 513 ArrayRef<AffineMap> affineMapComposition, 514 unsigned memorySpace, 515 Optional<Location> location) { 516 auto *context = elementType.getContext(); 517 518 if (!BaseMemRefType::isValidElementType(elementType)) 519 return (void)emitOptionalError(location, "invalid memref element type"), 520 MemRefType(); 521 522 for (int64_t s : shape) { 523 // Negative sizes are not allowed except for `-1` that means dynamic size. 524 if (s < -1) 525 return (void)emitOptionalError(location, "invalid memref size"), 526 MemRefType(); 527 } 528 529 // Check that the structure of the composition is valid, i.e. that each 530 // subsequent affine map has as many inputs as the previous map has results. 531 // Take the dimensionality of the MemRef for the first map. 532 auto dim = shape.size(); 533 unsigned i = 0; 534 for (const auto &affineMap : affineMapComposition) { 535 if (affineMap.getNumDims() != dim) { 536 if (location) 537 emitError(*location) 538 << "memref affine map dimension mismatch between " 539 << (i == 0 ? Twine("memref rank") : "affine map " + Twine(i)) 540 << " and affine map" << i + 1 << ": " << dim 541 << " != " << affineMap.getNumDims(); 542 return nullptr; 543 } 544 545 dim = affineMap.getNumResults(); 546 ++i; 547 } 548 549 // Drop identity maps from the composition. 550 // This may lead to the composition becoming empty, which is interpreted as an 551 // implicit identity. 552 SmallVector<AffineMap, 2> cleanedAffineMapComposition; 553 for (const auto &map : affineMapComposition) { 554 if (map.isIdentity()) 555 continue; 556 cleanedAffineMapComposition.push_back(map); 557 } 558 559 return Base::get(context, shape, elementType, cleanedAffineMapComposition, 560 memorySpace); 561 } 562 563 ArrayRef<int64_t> MemRefType::getShape() const { return getImpl()->getShape(); } 564 565 ArrayRef<AffineMap> MemRefType::getAffineMaps() const { 566 return getImpl()->getAffineMaps(); 567 } 568 569 //===----------------------------------------------------------------------===// 570 // UnrankedMemRefType 571 //===----------------------------------------------------------------------===// 572 573 UnrankedMemRefType UnrankedMemRefType::get(Type elementType, 574 unsigned memorySpace) { 575 return Base::get(elementType.getContext(), elementType, memorySpace); 576 } 577 578 UnrankedMemRefType UnrankedMemRefType::getChecked(Location location, 579 Type elementType, 580 unsigned memorySpace) { 581 return Base::getChecked(location, elementType, memorySpace); 582 } 583 584 LogicalResult 585 UnrankedMemRefType::verifyConstructionInvariants(Location loc, Type elementType, 586 unsigned memorySpace) { 587 if (!BaseMemRefType::isValidElementType(elementType)) 588 return emitError(loc, "invalid memref element type"); 589 return success(); 590 } 591 592 // Fallback cases for terminal dim/sym/cst that are not part of a binary op ( 593 // i.e. single term). Accumulate the AffineExpr into the existing one. 594 static void extractStridesFromTerm(AffineExpr e, 595 AffineExpr multiplicativeFactor, 596 MutableArrayRef<AffineExpr> strides, 597 AffineExpr &offset) { 598 if (auto dim = e.dyn_cast<AffineDimExpr>()) 599 strides[dim.getPosition()] = 600 strides[dim.getPosition()] + multiplicativeFactor; 601 else 602 offset = offset + e * multiplicativeFactor; 603 } 604 605 /// Takes a single AffineExpr `e` and populates the `strides` array with the 606 /// strides expressions for each dim position. 607 /// The convention is that the strides for dimensions d0, .. dn appear in 608 /// order to make indexing intuitive into the result. 609 static LogicalResult extractStrides(AffineExpr e, 610 AffineExpr multiplicativeFactor, 611 MutableArrayRef<AffineExpr> strides, 612 AffineExpr &offset) { 613 auto bin = e.dyn_cast<AffineBinaryOpExpr>(); 614 if (!bin) { 615 extractStridesFromTerm(e, multiplicativeFactor, strides, offset); 616 return success(); 617 } 618 619 if (bin.getKind() == AffineExprKind::CeilDiv || 620 bin.getKind() == AffineExprKind::FloorDiv || 621 bin.getKind() == AffineExprKind::Mod) 622 return failure(); 623 624 if (bin.getKind() == AffineExprKind::Mul) { 625 auto dim = bin.getLHS().dyn_cast<AffineDimExpr>(); 626 if (dim) { 627 strides[dim.getPosition()] = 628 strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor; 629 return success(); 630 } 631 // LHS and RHS may both contain complex expressions of dims. Try one path 632 // and if it fails try the other. This is guaranteed to succeed because 633 // only one path may have a `dim`, otherwise this is not an AffineExpr in 634 // the first place. 635 if (bin.getLHS().isSymbolicOrConstant()) 636 return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(), 637 strides, offset); 638 return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(), 639 strides, offset); 640 } 641 642 if (bin.getKind() == AffineExprKind::Add) { 643 auto res1 = 644 extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset); 645 auto res2 = 646 extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset); 647 return success(succeeded(res1) && succeeded(res2)); 648 } 649 650 llvm_unreachable("unexpected binary operation"); 651 } 652 653 LogicalResult mlir::getStridesAndOffset(MemRefType t, 654 SmallVectorImpl<AffineExpr> &strides, 655 AffineExpr &offset) { 656 auto affineMaps = t.getAffineMaps(); 657 // For now strides are only computed on a single affine map with a single 658 // result (i.e. the closed subset of linearization maps that are compatible 659 // with striding semantics). 660 // TODO: support more forms on a per-need basis. 661 if (affineMaps.size() > 1) 662 return failure(); 663 if (affineMaps.size() == 1 && affineMaps[0].getNumResults() != 1) 664 return failure(); 665 666 auto zero = getAffineConstantExpr(0, t.getContext()); 667 auto one = getAffineConstantExpr(1, t.getContext()); 668 offset = zero; 669 strides.assign(t.getRank(), zero); 670 671 AffineMap m; 672 if (!affineMaps.empty()) { 673 m = affineMaps.front(); 674 assert(!m.isIdentity() && "unexpected identity map"); 675 } 676 677 // Canonical case for empty map. 678 if (!m) { 679 // 0-D corner case, offset is already 0. 680 if (t.getRank() == 0) 681 return success(); 682 auto stridedExpr = 683 makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext()); 684 if (succeeded(extractStrides(stridedExpr, one, strides, offset))) 685 return success(); 686 assert(false && "unexpected failure: extract strides in canonical layout"); 687 } 688 689 // Non-canonical case requires more work. 690 auto stridedExpr = 691 simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols()); 692 if (failed(extractStrides(stridedExpr, one, strides, offset))) { 693 offset = AffineExpr(); 694 strides.clear(); 695 return failure(); 696 } 697 698 // Simplify results to allow folding to constants and simple checks. 699 unsigned numDims = m.getNumDims(); 700 unsigned numSymbols = m.getNumSymbols(); 701 offset = simplifyAffineExpr(offset, numDims, numSymbols); 702 for (auto &stride : strides) 703 stride = simplifyAffineExpr(stride, numDims, numSymbols); 704 705 /// In practice, a strided memref must be internally non-aliasing. Test 706 /// against 0 as a proxy. 707 /// TODO: static cases can have more advanced checks. 708 /// TODO: dynamic cases would require a way to compare symbolic 709 /// expressions and would probably need an affine set context propagated 710 /// everywhere. 711 if (llvm::any_of(strides, [](AffineExpr e) { 712 return e == getAffineConstantExpr(0, e.getContext()); 713 })) { 714 offset = AffineExpr(); 715 strides.clear(); 716 return failure(); 717 } 718 719 return success(); 720 } 721 722 LogicalResult mlir::getStridesAndOffset(MemRefType t, 723 SmallVectorImpl<int64_t> &strides, 724 int64_t &offset) { 725 AffineExpr offsetExpr; 726 SmallVector<AffineExpr, 4> strideExprs; 727 if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr))) 728 return failure(); 729 if (auto cst = offsetExpr.dyn_cast<AffineConstantExpr>()) 730 offset = cst.getValue(); 731 else 732 offset = ShapedType::kDynamicStrideOrOffset; 733 for (auto e : strideExprs) { 734 if (auto c = e.dyn_cast<AffineConstantExpr>()) 735 strides.push_back(c.getValue()); 736 else 737 strides.push_back(ShapedType::kDynamicStrideOrOffset); 738 } 739 return success(); 740 } 741 742 //===----------------------------------------------------------------------===// 743 /// TupleType 744 //===----------------------------------------------------------------------===// 745 746 /// Return the elements types for this tuple. 747 ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); } 748 749 /// Accumulate the types contained in this tuple and tuples nested within it. 750 /// Note that this only flattens nested tuples, not any other container type, 751 /// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to 752 /// (i32, tensor<i32>, f32, i64) 753 void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) { 754 for (Type type : getTypes()) { 755 if (auto nestedTuple = type.dyn_cast<TupleType>()) 756 nestedTuple.getFlattenedTypes(types); 757 else 758 types.push_back(type); 759 } 760 } 761 762 /// Return the number of element types. 763 size_t TupleType::size() const { return getImpl()->size(); } 764 765 //===----------------------------------------------------------------------===// 766 // Type Utilities 767 //===----------------------------------------------------------------------===// 768 769 AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides, 770 int64_t offset, 771 MLIRContext *context) { 772 AffineExpr expr; 773 unsigned nSymbols = 0; 774 775 // AffineExpr for offset. 776 // Static case. 777 if (offset != MemRefType::getDynamicStrideOrOffset()) { 778 auto cst = getAffineConstantExpr(offset, context); 779 expr = cst; 780 } else { 781 // Dynamic case, new symbol for the offset. 782 auto sym = getAffineSymbolExpr(nSymbols++, context); 783 expr = sym; 784 } 785 786 // AffineExpr for strides. 787 for (auto en : llvm::enumerate(strides)) { 788 auto dim = en.index(); 789 auto stride = en.value(); 790 assert(stride != 0 && "Invalid stride specification"); 791 auto d = getAffineDimExpr(dim, context); 792 AffineExpr mult; 793 // Static case. 794 if (stride != MemRefType::getDynamicStrideOrOffset()) 795 mult = getAffineConstantExpr(stride, context); 796 else 797 // Dynamic case, new symbol for each new stride. 798 mult = getAffineSymbolExpr(nSymbols++, context); 799 expr = expr + d * mult; 800 } 801 802 return AffineMap::get(strides.size(), nSymbols, expr); 803 } 804 805 /// Return a version of `t` with identity layout if it can be determined 806 /// statically that the layout is the canonical contiguous strided layout. 807 /// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of 808 /// `t` with simplified layout. 809 /// If `t` has multiple layout maps or a multi-result layout, just return `t`. 810 MemRefType mlir::canonicalizeStridedLayout(MemRefType t) { 811 auto affineMaps = t.getAffineMaps(); 812 // Already in canonical form. 813 if (affineMaps.empty()) 814 return t; 815 816 // Can't reduce to canonical identity form, return in canonical form. 817 if (affineMaps.size() > 1 || affineMaps[0].getNumResults() > 1) 818 return t; 819 820 // Corner-case for 0-D affine maps. 821 auto m = affineMaps[0]; 822 if (m.getNumDims() == 0 && m.getNumSymbols() == 0) { 823 if (auto cst = m.getResult(0).dyn_cast<AffineConstantExpr>()) 824 if (cst.getValue() == 0) 825 return MemRefType::Builder(t).setAffineMaps({}); 826 return t; 827 } 828 829 // 0-D corner case for empty shape that still have an affine map. Example: 830 // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a 1 element memref whose 831 // offset needs to remain, just return t. 832 if (t.getShape().empty()) 833 return t; 834 835 // If the canonical strided layout for the sizes of `t` is equal to the 836 // simplified layout of `t` we can just return an empty layout. Otherwise, 837 // just simplify the existing layout. 838 AffineExpr expr = 839 makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext()); 840 auto simplifiedLayoutExpr = 841 simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols()); 842 if (expr != simplifiedLayoutExpr) 843 return MemRefType::Builder(t).setAffineMaps({AffineMap::get( 844 m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)}); 845 return MemRefType::Builder(t).setAffineMaps({}); 846 } 847 848 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes, 849 ArrayRef<AffineExpr> exprs, 850 MLIRContext *context) { 851 assert(!sizes.empty() && !exprs.empty() && 852 "expected non-empty sizes and exprs"); 853 854 // Size 0 corner case is useful for canonicalizations. 855 if (llvm::is_contained(sizes, 0)) 856 return getAffineConstantExpr(0, context); 857 858 auto maps = AffineMap::inferFromExprList(exprs); 859 assert(!maps.empty() && "Expected one non-empty map"); 860 unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols(); 861 862 AffineExpr expr; 863 bool dynamicPoisonBit = false; 864 int64_t runningSize = 1; 865 for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) { 866 int64_t size = std::get<1>(en); 867 // Degenerate case, no size =-> no stride 868 if (size == 0) 869 continue; 870 AffineExpr dimExpr = std::get<0>(en); 871 AffineExpr stride = dynamicPoisonBit 872 ? getAffineSymbolExpr(nSymbols++, context) 873 : getAffineConstantExpr(runningSize, context); 874 expr = expr ? expr + dimExpr * stride : dimExpr * stride; 875 if (size > 0) { 876 runningSize *= size; 877 assert(runningSize > 0 && "integer overflow in size computation"); 878 } else { 879 dynamicPoisonBit = true; 880 } 881 } 882 return simplifyAffineExpr(expr, numDims, nSymbols); 883 } 884 885 /// Return a version of `t` with a layout that has all dynamic offset and 886 /// strides. This is used to erase the static layout. 887 MemRefType mlir::eraseStridedLayout(MemRefType t) { 888 auto val = ShapedType::kDynamicStrideOrOffset; 889 return MemRefType::Builder(t).setAffineMaps(makeStridedLinearLayoutMap( 890 SmallVector<int64_t, 4>(t.getRank(), val), val, t.getContext())); 891 } 892 893 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes, 894 MLIRContext *context) { 895 SmallVector<AffineExpr, 4> exprs; 896 exprs.reserve(sizes.size()); 897 for (auto dim : llvm::seq<unsigned>(0, sizes.size())) 898 exprs.push_back(getAffineDimExpr(dim, context)); 899 return makeCanonicalStridedLayoutExpr(sizes, exprs, context); 900 } 901 902 /// Return true if the layout for `t` is compatible with strided semantics. 903 bool mlir::isStrided(MemRefType t) { 904 int64_t offset; 905 SmallVector<int64_t, 4> strides; 906 auto res = getStridesAndOffset(t, strides, offset); 907 return succeeded(res); 908 } 909 910 /// Return the layout map in strided linear layout AffineMap form. 911 /// Return null if the layout is not compatible with a strided layout. 912 AffineMap mlir::getStridedLinearLayoutMap(MemRefType t) { 913 int64_t offset; 914 SmallVector<int64_t, 4> strides; 915 if (failed(getStridesAndOffset(t, strides, offset))) 916 return AffineMap(); 917 return makeStridedLinearLayoutMap(strides, offset, t.getContext()); 918 } 919