1 //===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/BuiltinTypes.h" 10 #include "TypeDetail.h" 11 #include "mlir/IR/AffineExpr.h" 12 #include "mlir/IR/AffineMap.h" 13 #include "mlir/IR/Diagnostics.h" 14 #include "mlir/IR/Dialect.h" 15 #include "llvm/ADT/APFloat.h" 16 #include "llvm/ADT/BitVector.h" 17 #include "llvm/ADT/Sequence.h" 18 #include "llvm/ADT/Twine.h" 19 #include "llvm/ADT/TypeSwitch.h" 20 21 using namespace mlir; 22 using namespace mlir::detail; 23 24 //===----------------------------------------------------------------------===// 25 /// Tablegen Type Definitions 26 //===----------------------------------------------------------------------===// 27 28 #define GET_TYPEDEF_CLASSES 29 #include "mlir/IR/BuiltinTypes.cpp.inc" 30 31 //===----------------------------------------------------------------------===// 32 /// ComplexType 33 //===----------------------------------------------------------------------===// 34 35 /// Verify the construction of an integer type. 36 LogicalResult ComplexType::verify(function_ref<InFlightDiagnostic()> emitError, 37 Type elementType) { 38 if (!elementType.isIntOrFloat()) 39 return emitError() << "invalid element type for complex"; 40 return success(); 41 } 42 43 //===----------------------------------------------------------------------===// 44 // Integer Type 45 //===----------------------------------------------------------------------===// 46 47 // static constexpr must have a definition (until in C++17 and inline variable). 48 constexpr unsigned IntegerType::kMaxWidth; 49 50 /// Verify the construction of an integer type. 51 LogicalResult IntegerType::verify(function_ref<InFlightDiagnostic()> emitError, 52 unsigned width, 53 SignednessSemantics signedness) { 54 if (width > IntegerType::kMaxWidth) { 55 return emitError() << "integer bitwidth is limited to " 56 << IntegerType::kMaxWidth << " bits"; 57 } 58 return success(); 59 } 60 61 unsigned IntegerType::getWidth() const { return getImpl()->width; } 62 63 IntegerType::SignednessSemantics IntegerType::getSignedness() const { 64 return getImpl()->signedness; 65 } 66 67 IntegerType IntegerType::scaleElementBitwidth(unsigned scale) { 68 if (!scale) 69 return IntegerType(); 70 return IntegerType::get(getContext(), scale * getWidth(), getSignedness()); 71 } 72 73 //===----------------------------------------------------------------------===// 74 // Float Type 75 //===----------------------------------------------------------------------===// 76 77 unsigned FloatType::getWidth() { 78 if (isa<Float16Type, BFloat16Type>()) 79 return 16; 80 if (isa<Float32Type>()) 81 return 32; 82 if (isa<Float64Type>()) 83 return 64; 84 if (isa<Float80Type>()) 85 return 80; 86 if (isa<Float128Type>()) 87 return 128; 88 llvm_unreachable("unexpected float type"); 89 } 90 91 /// Returns the floating semantics for the given type. 92 const llvm::fltSemantics &FloatType::getFloatSemantics() { 93 if (isa<BFloat16Type>()) 94 return APFloat::BFloat(); 95 if (isa<Float16Type>()) 96 return APFloat::IEEEhalf(); 97 if (isa<Float32Type>()) 98 return APFloat::IEEEsingle(); 99 if (isa<Float64Type>()) 100 return APFloat::IEEEdouble(); 101 if (isa<Float80Type>()) 102 return APFloat::x87DoubleExtended(); 103 if (isa<Float128Type>()) 104 return APFloat::IEEEquad(); 105 llvm_unreachable("non-floating point type used"); 106 } 107 108 FloatType FloatType::scaleElementBitwidth(unsigned scale) { 109 if (!scale) 110 return FloatType(); 111 MLIRContext *ctx = getContext(); 112 if (isF16() || isBF16()) { 113 if (scale == 2) 114 return FloatType::getF32(ctx); 115 if (scale == 4) 116 return FloatType::getF64(ctx); 117 } 118 if (isF32()) 119 if (scale == 2) 120 return FloatType::getF64(ctx); 121 return FloatType(); 122 } 123 124 //===----------------------------------------------------------------------===// 125 // FunctionType 126 //===----------------------------------------------------------------------===// 127 128 unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; } 129 130 ArrayRef<Type> FunctionType::getInputs() const { 131 return getImpl()->getInputs(); 132 } 133 134 unsigned FunctionType::getNumResults() const { return getImpl()->numResults; } 135 136 ArrayRef<Type> FunctionType::getResults() const { 137 return getImpl()->getResults(); 138 } 139 140 /// Helper to call a callback once on each index in the range 141 /// [0, `totalIndices`), *except* for the indices given in `indices`. 142 /// `indices` is allowed to have duplicates and can be in any order. 143 inline void iterateIndicesExcept(unsigned totalIndices, 144 ArrayRef<unsigned> indices, 145 function_ref<void(unsigned)> callback) { 146 llvm::BitVector skipIndices(totalIndices); 147 for (unsigned i : indices) 148 skipIndices.set(i); 149 150 for (unsigned i = 0; i < totalIndices; ++i) 151 if (!skipIndices.test(i)) 152 callback(i); 153 } 154 155 /// Returns a new function type without the specified arguments and results. 156 FunctionType 157 FunctionType::getWithoutArgsAndResults(ArrayRef<unsigned> argIndices, 158 ArrayRef<unsigned> resultIndices) { 159 ArrayRef<Type> newInputTypes = getInputs(); 160 SmallVector<Type, 4> newInputTypesBuffer; 161 if (!argIndices.empty()) { 162 unsigned originalNumArgs = getNumInputs(); 163 iterateIndicesExcept(originalNumArgs, argIndices, [&](unsigned i) { 164 newInputTypesBuffer.emplace_back(getInput(i)); 165 }); 166 newInputTypes = newInputTypesBuffer; 167 } 168 169 ArrayRef<Type> newResultTypes = getResults(); 170 SmallVector<Type, 4> newResultTypesBuffer; 171 if (!resultIndices.empty()) { 172 unsigned originalNumResults = getNumResults(); 173 iterateIndicesExcept(originalNumResults, resultIndices, [&](unsigned i) { 174 newResultTypesBuffer.emplace_back(getResult(i)); 175 }); 176 newResultTypes = newResultTypesBuffer; 177 } 178 179 return get(getContext(), newInputTypes, newResultTypes); 180 } 181 182 //===----------------------------------------------------------------------===// 183 // OpaqueType 184 //===----------------------------------------------------------------------===// 185 186 /// Verify the construction of an opaque type. 187 LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError, 188 Identifier dialect, StringRef typeData) { 189 if (!Dialect::isValidNamespace(dialect.strref())) 190 return emitError() << "invalid dialect namespace '" << dialect << "'"; 191 return success(); 192 } 193 194 //===----------------------------------------------------------------------===// 195 // ShapedType 196 //===----------------------------------------------------------------------===// 197 constexpr int64_t ShapedType::kDynamicSize; 198 constexpr int64_t ShapedType::kDynamicStrideOrOffset; 199 200 ShapedType ShapedType::clone(ArrayRef<int64_t> shape, Type elementType) { 201 if (auto other = dyn_cast<MemRefType>()) { 202 MemRefType::Builder b(other); 203 b.setShape(shape); 204 b.setElementType(elementType); 205 return b; 206 } 207 208 if (auto other = dyn_cast<UnrankedMemRefType>()) { 209 MemRefType::Builder b(shape, elementType); 210 b.setMemorySpace(other.getMemorySpaceAsInt()); 211 return b; 212 } 213 214 if (isa<TensorType>()) 215 return RankedTensorType::get(shape, elementType); 216 217 if (isa<VectorType>()) 218 return VectorType::get(shape, elementType); 219 220 llvm_unreachable("Unhandled ShapedType clone case"); 221 } 222 223 ShapedType ShapedType::clone(ArrayRef<int64_t> shape) { 224 if (auto other = dyn_cast<MemRefType>()) { 225 MemRefType::Builder b(other); 226 b.setShape(shape); 227 return b; 228 } 229 230 if (auto other = dyn_cast<UnrankedMemRefType>()) { 231 MemRefType::Builder b(shape, other.getElementType()); 232 b.setShape(shape); 233 b.setMemorySpace(other.getMemorySpaceAsInt()); 234 return b; 235 } 236 237 if (isa<TensorType>()) 238 return RankedTensorType::get(shape, getElementType()); 239 240 if (isa<VectorType>()) 241 return VectorType::get(shape, getElementType()); 242 243 llvm_unreachable("Unhandled ShapedType clone case"); 244 } 245 246 ShapedType ShapedType::clone(Type elementType) { 247 if (auto other = dyn_cast<MemRefType>()) { 248 MemRefType::Builder b(other); 249 b.setElementType(elementType); 250 return b; 251 } 252 253 if (auto other = dyn_cast<UnrankedMemRefType>()) { 254 return UnrankedMemRefType::get(elementType, other.getMemorySpaceAsInt()); 255 } 256 257 if (isa<TensorType>()) { 258 if (hasRank()) 259 return RankedTensorType::get(getShape(), elementType); 260 return UnrankedTensorType::get(elementType); 261 } 262 263 if (isa<VectorType>()) 264 return VectorType::get(getShape(), elementType); 265 266 llvm_unreachable("Unhandled ShapedType clone hit"); 267 } 268 269 Type ShapedType::getElementType() const { 270 return TypeSwitch<Type, Type>(*this) 271 .Case<VectorType, RankedTensorType, UnrankedTensorType, MemRefType, 272 UnrankedMemRefType>([](auto ty) { return ty.getElementType(); }); 273 } 274 275 unsigned ShapedType::getElementTypeBitWidth() const { 276 return getElementType().getIntOrFloatBitWidth(); 277 } 278 279 int64_t ShapedType::getNumElements() const { 280 assert(hasStaticShape() && "cannot get element count of dynamic shaped type"); 281 auto shape = getShape(); 282 int64_t num = 1; 283 for (auto dim : shape) { 284 num *= dim; 285 assert(num >= 0 && "integer overflow in element count computation"); 286 } 287 return num; 288 } 289 290 int64_t ShapedType::getRank() const { 291 assert(hasRank() && "cannot query rank of unranked shaped type"); 292 return getShape().size(); 293 } 294 295 bool ShapedType::hasRank() const { 296 return !isa<UnrankedMemRefType, UnrankedTensorType>(); 297 } 298 299 int64_t ShapedType::getDimSize(unsigned idx) const { 300 assert(idx < getRank() && "invalid index for shaped type"); 301 return getShape()[idx]; 302 } 303 304 bool ShapedType::isDynamicDim(unsigned idx) const { 305 assert(idx < getRank() && "invalid index for shaped type"); 306 return isDynamic(getShape()[idx]); 307 } 308 309 unsigned ShapedType::getDynamicDimIndex(unsigned index) const { 310 assert(index < getRank() && "invalid index"); 311 assert(ShapedType::isDynamic(getDimSize(index)) && "invalid index"); 312 return llvm::count_if(getShape().take_front(index), ShapedType::isDynamic); 313 } 314 315 /// Get the number of bits require to store a value of the given shaped type. 316 /// Compute the value recursively since tensors are allowed to have vectors as 317 /// elements. 318 int64_t ShapedType::getSizeInBits() const { 319 assert(hasStaticShape() && 320 "cannot get the bit size of an aggregate with a dynamic shape"); 321 322 auto elementType = getElementType(); 323 if (elementType.isIntOrFloat()) 324 return elementType.getIntOrFloatBitWidth() * getNumElements(); 325 326 if (auto complexType = elementType.dyn_cast<ComplexType>()) { 327 elementType = complexType.getElementType(); 328 return elementType.getIntOrFloatBitWidth() * getNumElements() * 2; 329 } 330 331 // Tensors can have vectors and other tensors as elements, other shaped types 332 // cannot. 333 assert(isa<TensorType>() && "unsupported element type"); 334 assert((elementType.isa<VectorType, TensorType>()) && 335 "unsupported tensor element type"); 336 return getNumElements() * elementType.cast<ShapedType>().getSizeInBits(); 337 } 338 339 ArrayRef<int64_t> ShapedType::getShape() const { 340 if (auto vectorType = dyn_cast<VectorType>()) 341 return vectorType.getShape(); 342 if (auto tensorType = dyn_cast<RankedTensorType>()) 343 return tensorType.getShape(); 344 return cast<MemRefType>().getShape(); 345 } 346 347 int64_t ShapedType::getNumDynamicDims() const { 348 return llvm::count_if(getShape(), isDynamic); 349 } 350 351 bool ShapedType::hasStaticShape() const { 352 return hasRank() && llvm::none_of(getShape(), isDynamic); 353 } 354 355 bool ShapedType::hasStaticShape(ArrayRef<int64_t> shape) const { 356 return hasStaticShape() && getShape() == shape; 357 } 358 359 //===----------------------------------------------------------------------===// 360 // VectorType 361 //===----------------------------------------------------------------------===// 362 363 LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError, 364 ArrayRef<int64_t> shape, Type elementType) { 365 if (shape.empty()) 366 return emitError() << "vector types must have at least one dimension"; 367 368 if (!isValidElementType(elementType)) 369 return emitError() << "vector elements must be int or float type"; 370 371 if (any_of(shape, [](int64_t i) { return i <= 0; })) 372 return emitError() << "vector types must have positive constant sizes"; 373 374 return success(); 375 } 376 377 VectorType VectorType::scaleElementBitwidth(unsigned scale) { 378 if (!scale) 379 return VectorType(); 380 if (auto et = getElementType().dyn_cast<IntegerType>()) 381 if (auto scaledEt = et.scaleElementBitwidth(scale)) 382 return VectorType::get(getShape(), scaledEt); 383 if (auto et = getElementType().dyn_cast<FloatType>()) 384 if (auto scaledEt = et.scaleElementBitwidth(scale)) 385 return VectorType::get(getShape(), scaledEt); 386 return VectorType(); 387 } 388 389 //===----------------------------------------------------------------------===// 390 // TensorType 391 //===----------------------------------------------------------------------===// 392 393 // Check if "elementType" can be an element type of a tensor. 394 static LogicalResult 395 checkTensorElementType(function_ref<InFlightDiagnostic()> emitError, 396 Type elementType) { 397 if (!TensorType::isValidElementType(elementType)) 398 return emitError() << "invalid tensor element type: " << elementType; 399 return success(); 400 } 401 402 /// Return true if the specified element type is ok in a tensor. 403 bool TensorType::isValidElementType(Type type) { 404 // Note: Non standard/builtin types are allowed to exist within tensor 405 // types. Dialects are expected to verify that tensor types have a valid 406 // element type within that dialect. 407 return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType, 408 IndexType>() || 409 !type.getDialect().getNamespace().empty(); 410 } 411 412 //===----------------------------------------------------------------------===// 413 // RankedTensorType 414 //===----------------------------------------------------------------------===// 415 416 LogicalResult 417 RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError, 418 ArrayRef<int64_t> shape, Type elementType) { 419 for (int64_t s : shape) 420 if (s < -1) 421 return emitError() << "invalid tensor dimension size"; 422 return checkTensorElementType(emitError, elementType); 423 } 424 425 //===----------------------------------------------------------------------===// 426 // UnrankedTensorType 427 //===----------------------------------------------------------------------===// 428 429 LogicalResult 430 UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError, 431 Type elementType) { 432 return checkTensorElementType(emitError, elementType); 433 } 434 435 //===----------------------------------------------------------------------===// 436 // BaseMemRefType 437 //===----------------------------------------------------------------------===// 438 439 unsigned BaseMemRefType::getMemorySpaceAsInt() const { 440 if (auto rankedMemRefTy = dyn_cast<MemRefType>()) 441 return rankedMemRefTy.getMemorySpaceAsInt(); 442 return cast<UnrankedMemRefType>().getMemorySpaceAsInt(); 443 } 444 445 //===----------------------------------------------------------------------===// 446 // MemRefType 447 //===----------------------------------------------------------------------===// 448 449 LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError, 450 ArrayRef<int64_t> shape, Type elementType, 451 ArrayRef<AffineMap> affineMapComposition, 452 unsigned memorySpace) { 453 if (!BaseMemRefType::isValidElementType(elementType)) 454 return emitError() << "invalid memref element type"; 455 456 // Negative sizes are not allowed except for `-1` that means dynamic size. 457 for (int64_t s : shape) 458 if (s < -1) 459 return emitError() << "invalid memref size"; 460 461 // Check that the structure of the composition is valid, i.e. that each 462 // subsequent affine map has as many inputs as the previous map has results. 463 // Take the dimensionality of the MemRef for the first map. 464 size_t dim = shape.size(); 465 for (auto it : llvm::enumerate(affineMapComposition)) { 466 AffineMap map = it.value(); 467 if (map.getNumDims() == dim) { 468 dim = map.getNumResults(); 469 continue; 470 } 471 return emitError() << "memref affine map dimension mismatch between " 472 << (it.index() == 0 ? Twine("memref rank") 473 : "affine map " + Twine(it.index())) 474 << " and affine map" << it.index() + 1 << ": " << dim 475 << " != " << map.getNumDims(); 476 } 477 return success(); 478 } 479 480 //===----------------------------------------------------------------------===// 481 // UnrankedMemRefType 482 //===----------------------------------------------------------------------===// 483 484 LogicalResult 485 UnrankedMemRefType::verify(function_ref<InFlightDiagnostic()> emitError, 486 Type elementType, unsigned memorySpace) { 487 if (!BaseMemRefType::isValidElementType(elementType)) 488 return emitError() << "invalid memref element type"; 489 return success(); 490 } 491 492 // Fallback cases for terminal dim/sym/cst that are not part of a binary op ( 493 // i.e. single term). Accumulate the AffineExpr into the existing one. 494 static void extractStridesFromTerm(AffineExpr e, 495 AffineExpr multiplicativeFactor, 496 MutableArrayRef<AffineExpr> strides, 497 AffineExpr &offset) { 498 if (auto dim = e.dyn_cast<AffineDimExpr>()) 499 strides[dim.getPosition()] = 500 strides[dim.getPosition()] + multiplicativeFactor; 501 else 502 offset = offset + e * multiplicativeFactor; 503 } 504 505 /// Takes a single AffineExpr `e` and populates the `strides` array with the 506 /// strides expressions for each dim position. 507 /// The convention is that the strides for dimensions d0, .. dn appear in 508 /// order to make indexing intuitive into the result. 509 static LogicalResult extractStrides(AffineExpr e, 510 AffineExpr multiplicativeFactor, 511 MutableArrayRef<AffineExpr> strides, 512 AffineExpr &offset) { 513 auto bin = e.dyn_cast<AffineBinaryOpExpr>(); 514 if (!bin) { 515 extractStridesFromTerm(e, multiplicativeFactor, strides, offset); 516 return success(); 517 } 518 519 if (bin.getKind() == AffineExprKind::CeilDiv || 520 bin.getKind() == AffineExprKind::FloorDiv || 521 bin.getKind() == AffineExprKind::Mod) 522 return failure(); 523 524 if (bin.getKind() == AffineExprKind::Mul) { 525 auto dim = bin.getLHS().dyn_cast<AffineDimExpr>(); 526 if (dim) { 527 strides[dim.getPosition()] = 528 strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor; 529 return success(); 530 } 531 // LHS and RHS may both contain complex expressions of dims. Try one path 532 // and if it fails try the other. This is guaranteed to succeed because 533 // only one path may have a `dim`, otherwise this is not an AffineExpr in 534 // the first place. 535 if (bin.getLHS().isSymbolicOrConstant()) 536 return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(), 537 strides, offset); 538 return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(), 539 strides, offset); 540 } 541 542 if (bin.getKind() == AffineExprKind::Add) { 543 auto res1 = 544 extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset); 545 auto res2 = 546 extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset); 547 return success(succeeded(res1) && succeeded(res2)); 548 } 549 550 llvm_unreachable("unexpected binary operation"); 551 } 552 553 LogicalResult mlir::getStridesAndOffset(MemRefType t, 554 SmallVectorImpl<AffineExpr> &strides, 555 AffineExpr &offset) { 556 auto affineMaps = t.getAffineMaps(); 557 // For now strides are only computed on a single affine map with a single 558 // result (i.e. the closed subset of linearization maps that are compatible 559 // with striding semantics). 560 // TODO: support more forms on a per-need basis. 561 if (affineMaps.size() > 1) 562 return failure(); 563 if (affineMaps.size() == 1 && affineMaps[0].getNumResults() != 1) 564 return failure(); 565 566 auto zero = getAffineConstantExpr(0, t.getContext()); 567 auto one = getAffineConstantExpr(1, t.getContext()); 568 offset = zero; 569 strides.assign(t.getRank(), zero); 570 571 AffineMap m; 572 if (!affineMaps.empty()) { 573 m = affineMaps.front(); 574 assert(!m.isIdentity() && "unexpected identity map"); 575 } 576 577 // Canonical case for empty map. 578 if (!m) { 579 // 0-D corner case, offset is already 0. 580 if (t.getRank() == 0) 581 return success(); 582 auto stridedExpr = 583 makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext()); 584 if (succeeded(extractStrides(stridedExpr, one, strides, offset))) 585 return success(); 586 assert(false && "unexpected failure: extract strides in canonical layout"); 587 } 588 589 // Non-canonical case requires more work. 590 auto stridedExpr = 591 simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols()); 592 if (failed(extractStrides(stridedExpr, one, strides, offset))) { 593 offset = AffineExpr(); 594 strides.clear(); 595 return failure(); 596 } 597 598 // Simplify results to allow folding to constants and simple checks. 599 unsigned numDims = m.getNumDims(); 600 unsigned numSymbols = m.getNumSymbols(); 601 offset = simplifyAffineExpr(offset, numDims, numSymbols); 602 for (auto &stride : strides) 603 stride = simplifyAffineExpr(stride, numDims, numSymbols); 604 605 /// In practice, a strided memref must be internally non-aliasing. Test 606 /// against 0 as a proxy. 607 /// TODO: static cases can have more advanced checks. 608 /// TODO: dynamic cases would require a way to compare symbolic 609 /// expressions and would probably need an affine set context propagated 610 /// everywhere. 611 if (llvm::any_of(strides, [](AffineExpr e) { 612 return e == getAffineConstantExpr(0, e.getContext()); 613 })) { 614 offset = AffineExpr(); 615 strides.clear(); 616 return failure(); 617 } 618 619 return success(); 620 } 621 622 LogicalResult mlir::getStridesAndOffset(MemRefType t, 623 SmallVectorImpl<int64_t> &strides, 624 int64_t &offset) { 625 AffineExpr offsetExpr; 626 SmallVector<AffineExpr, 4> strideExprs; 627 if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr))) 628 return failure(); 629 if (auto cst = offsetExpr.dyn_cast<AffineConstantExpr>()) 630 offset = cst.getValue(); 631 else 632 offset = ShapedType::kDynamicStrideOrOffset; 633 for (auto e : strideExprs) { 634 if (auto c = e.dyn_cast<AffineConstantExpr>()) 635 strides.push_back(c.getValue()); 636 else 637 strides.push_back(ShapedType::kDynamicStrideOrOffset); 638 } 639 return success(); 640 } 641 642 //===----------------------------------------------------------------------===// 643 /// TupleType 644 //===----------------------------------------------------------------------===// 645 646 /// Return the elements types for this tuple. 647 ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); } 648 649 /// Accumulate the types contained in this tuple and tuples nested within it. 650 /// Note that this only flattens nested tuples, not any other container type, 651 /// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to 652 /// (i32, tensor<i32>, f32, i64) 653 void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) { 654 for (Type type : getTypes()) { 655 if (auto nestedTuple = type.dyn_cast<TupleType>()) 656 nestedTuple.getFlattenedTypes(types); 657 else 658 types.push_back(type); 659 } 660 } 661 662 /// Return the number of element types. 663 size_t TupleType::size() const { return getImpl()->size(); } 664 665 //===----------------------------------------------------------------------===// 666 // Type Utilities 667 //===----------------------------------------------------------------------===// 668 669 AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides, 670 int64_t offset, 671 MLIRContext *context) { 672 AffineExpr expr; 673 unsigned nSymbols = 0; 674 675 // AffineExpr for offset. 676 // Static case. 677 if (offset != MemRefType::getDynamicStrideOrOffset()) { 678 auto cst = getAffineConstantExpr(offset, context); 679 expr = cst; 680 } else { 681 // Dynamic case, new symbol for the offset. 682 auto sym = getAffineSymbolExpr(nSymbols++, context); 683 expr = sym; 684 } 685 686 // AffineExpr for strides. 687 for (auto en : llvm::enumerate(strides)) { 688 auto dim = en.index(); 689 auto stride = en.value(); 690 assert(stride != 0 && "Invalid stride specification"); 691 auto d = getAffineDimExpr(dim, context); 692 AffineExpr mult; 693 // Static case. 694 if (stride != MemRefType::getDynamicStrideOrOffset()) 695 mult = getAffineConstantExpr(stride, context); 696 else 697 // Dynamic case, new symbol for each new stride. 698 mult = getAffineSymbolExpr(nSymbols++, context); 699 expr = expr + d * mult; 700 } 701 702 return AffineMap::get(strides.size(), nSymbols, expr); 703 } 704 705 /// Return a version of `t` with identity layout if it can be determined 706 /// statically that the layout is the canonical contiguous strided layout. 707 /// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of 708 /// `t` with simplified layout. 709 /// If `t` has multiple layout maps or a multi-result layout, just return `t`. 710 MemRefType mlir::canonicalizeStridedLayout(MemRefType t) { 711 auto affineMaps = t.getAffineMaps(); 712 // Already in canonical form. 713 if (affineMaps.empty()) 714 return t; 715 716 // Can't reduce to canonical identity form, return in canonical form. 717 if (affineMaps.size() > 1 || affineMaps[0].getNumResults() > 1) 718 return t; 719 720 // Corner-case for 0-D affine maps. 721 auto m = affineMaps[0]; 722 if (m.getNumDims() == 0 && m.getNumSymbols() == 0) { 723 if (auto cst = m.getResult(0).dyn_cast<AffineConstantExpr>()) 724 if (cst.getValue() == 0) 725 return MemRefType::Builder(t).setAffineMaps({}); 726 return t; 727 } 728 729 // 0-D corner case for empty shape that still have an affine map. Example: 730 // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a 1 element memref whose 731 // offset needs to remain, just return t. 732 if (t.getShape().empty()) 733 return t; 734 735 // If the canonical strided layout for the sizes of `t` is equal to the 736 // simplified layout of `t` we can just return an empty layout. Otherwise, 737 // just simplify the existing layout. 738 AffineExpr expr = 739 makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext()); 740 auto simplifiedLayoutExpr = 741 simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols()); 742 if (expr != simplifiedLayoutExpr) 743 return MemRefType::Builder(t).setAffineMaps({AffineMap::get( 744 m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)}); 745 return MemRefType::Builder(t).setAffineMaps({}); 746 } 747 748 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes, 749 ArrayRef<AffineExpr> exprs, 750 MLIRContext *context) { 751 assert(!sizes.empty() && !exprs.empty() && 752 "expected non-empty sizes and exprs"); 753 754 // Size 0 corner case is useful for canonicalizations. 755 if (llvm::is_contained(sizes, 0)) 756 return getAffineConstantExpr(0, context); 757 758 auto maps = AffineMap::inferFromExprList(exprs); 759 assert(!maps.empty() && "Expected one non-empty map"); 760 unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols(); 761 762 AffineExpr expr; 763 bool dynamicPoisonBit = false; 764 int64_t runningSize = 1; 765 for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) { 766 int64_t size = std::get<1>(en); 767 // Degenerate case, no size =-> no stride 768 if (size == 0) 769 continue; 770 AffineExpr dimExpr = std::get<0>(en); 771 AffineExpr stride = dynamicPoisonBit 772 ? getAffineSymbolExpr(nSymbols++, context) 773 : getAffineConstantExpr(runningSize, context); 774 expr = expr ? expr + dimExpr * stride : dimExpr * stride; 775 if (size > 0) { 776 runningSize *= size; 777 assert(runningSize > 0 && "integer overflow in size computation"); 778 } else { 779 dynamicPoisonBit = true; 780 } 781 } 782 return simplifyAffineExpr(expr, numDims, nSymbols); 783 } 784 785 /// Return a version of `t` with a layout that has all dynamic offset and 786 /// strides. This is used to erase the static layout. 787 MemRefType mlir::eraseStridedLayout(MemRefType t) { 788 auto val = ShapedType::kDynamicStrideOrOffset; 789 return MemRefType::Builder(t).setAffineMaps(makeStridedLinearLayoutMap( 790 SmallVector<int64_t, 4>(t.getRank(), val), val, t.getContext())); 791 } 792 793 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes, 794 MLIRContext *context) { 795 SmallVector<AffineExpr, 4> exprs; 796 exprs.reserve(sizes.size()); 797 for (auto dim : llvm::seq<unsigned>(0, sizes.size())) 798 exprs.push_back(getAffineDimExpr(dim, context)); 799 return makeCanonicalStridedLayoutExpr(sizes, exprs, context); 800 } 801 802 /// Return true if the layout for `t` is compatible with strided semantics. 803 bool mlir::isStrided(MemRefType t) { 804 int64_t offset; 805 SmallVector<int64_t, 4> strides; 806 auto res = getStridesAndOffset(t, strides, offset); 807 return succeeded(res); 808 } 809 810 /// Return the layout map in strided linear layout AffineMap form. 811 /// Return null if the layout is not compatible with a strided layout. 812 AffineMap mlir::getStridedLinearLayoutMap(MemRefType t) { 813 int64_t offset; 814 SmallVector<int64_t, 4> strides; 815 if (failed(getStridesAndOffset(t, strides, offset))) 816 return AffineMap(); 817 return makeStridedLinearLayoutMap(strides, offset, t.getContext()); 818 } 819