1 //===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/BuiltinTypes.h" 10 #include "TypeDetail.h" 11 #include "mlir/IR/AffineExpr.h" 12 #include "mlir/IR/AffineMap.h" 13 #include "mlir/IR/Diagnostics.h" 14 #include "mlir/IR/Dialect.h" 15 #include "llvm/ADT/APFloat.h" 16 #include "llvm/ADT/BitVector.h" 17 #include "llvm/ADT/Sequence.h" 18 #include "llvm/ADT/Twine.h" 19 20 using namespace mlir; 21 using namespace mlir::detail; 22 23 //===----------------------------------------------------------------------===// 24 /// Tablegen Type Definitions 25 //===----------------------------------------------------------------------===// 26 27 #define GET_TYPEDEF_CLASSES 28 #include "mlir/IR/BuiltinTypes.cpp.inc" 29 30 //===----------------------------------------------------------------------===// 31 /// ComplexType 32 //===----------------------------------------------------------------------===// 33 34 /// Verify the construction of an integer type. 35 LogicalResult ComplexType::verifyConstructionInvariants(Location loc, 36 Type elementType) { 37 if (!elementType.isIntOrFloat()) 38 return emitError(loc, "invalid element type for complex"); 39 return success(); 40 } 41 42 //===----------------------------------------------------------------------===// 43 // Integer Type 44 //===----------------------------------------------------------------------===// 45 46 // static constexpr must have a definition (until in C++17 and inline variable). 47 constexpr unsigned IntegerType::kMaxWidth; 48 49 /// Verify the construction of an integer type. 50 LogicalResult 51 IntegerType::verifyConstructionInvariants(Location loc, unsigned width, 52 SignednessSemantics signedness) { 53 if (width > IntegerType::kMaxWidth) { 54 return emitError(loc) << "integer bitwidth is limited to " 55 << IntegerType::kMaxWidth << " bits"; 56 } 57 return success(); 58 } 59 60 unsigned IntegerType::getWidth() const { return getImpl()->width; } 61 62 IntegerType::SignednessSemantics IntegerType::getSignedness() const { 63 return getImpl()->signedness; 64 } 65 66 IntegerType IntegerType::scaleElementBitwidth(unsigned scale) { 67 if (!scale) 68 return IntegerType(); 69 return IntegerType::get(getContext(), scale * getWidth(), getSignedness()); 70 } 71 72 //===----------------------------------------------------------------------===// 73 // Float Type 74 //===----------------------------------------------------------------------===// 75 76 unsigned FloatType::getWidth() { 77 if (isa<Float16Type, BFloat16Type>()) 78 return 16; 79 if (isa<Float32Type>()) 80 return 32; 81 if (isa<Float64Type>()) 82 return 64; 83 llvm_unreachable("unexpected float type"); 84 } 85 86 /// Returns the floating semantics for the given type. 87 const llvm::fltSemantics &FloatType::getFloatSemantics() { 88 if (isa<BFloat16Type>()) 89 return APFloat::BFloat(); 90 if (isa<Float16Type>()) 91 return APFloat::IEEEhalf(); 92 if (isa<Float32Type>()) 93 return APFloat::IEEEsingle(); 94 if (isa<Float64Type>()) 95 return APFloat::IEEEdouble(); 96 llvm_unreachable("non-floating point type used"); 97 } 98 99 FloatType FloatType::scaleElementBitwidth(unsigned scale) { 100 if (!scale) 101 return FloatType(); 102 MLIRContext *ctx = getContext(); 103 if (isF16() || isBF16()) { 104 if (scale == 2) 105 return FloatType::getF32(ctx); 106 if (scale == 4) 107 return FloatType::getF64(ctx); 108 } 109 if (isF32()) 110 if (scale == 2) 111 return FloatType::getF64(ctx); 112 return FloatType(); 113 } 114 115 //===----------------------------------------------------------------------===// 116 // FunctionType 117 //===----------------------------------------------------------------------===// 118 119 unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; } 120 121 ArrayRef<Type> FunctionType::getInputs() const { 122 return getImpl()->getInputs(); 123 } 124 125 unsigned FunctionType::getNumResults() const { return getImpl()->numResults; } 126 127 ArrayRef<Type> FunctionType::getResults() const { 128 return getImpl()->getResults(); 129 } 130 131 /// Helper to call a callback once on each index in the range 132 /// [0, `totalIndices`), *except* for the indices given in `indices`. 133 /// `indices` is allowed to have duplicates and can be in any order. 134 inline void iterateIndicesExcept(unsigned totalIndices, 135 ArrayRef<unsigned> indices, 136 function_ref<void(unsigned)> callback) { 137 llvm::BitVector skipIndices(totalIndices); 138 for (unsigned i : indices) 139 skipIndices.set(i); 140 141 for (unsigned i = 0; i < totalIndices; ++i) 142 if (!skipIndices.test(i)) 143 callback(i); 144 } 145 146 /// Returns a new function type without the specified arguments and results. 147 FunctionType 148 FunctionType::getWithoutArgsAndResults(ArrayRef<unsigned> argIndices, 149 ArrayRef<unsigned> resultIndices) { 150 ArrayRef<Type> newInputTypes = getInputs(); 151 SmallVector<Type, 4> newInputTypesBuffer; 152 if (!argIndices.empty()) { 153 unsigned originalNumArgs = getNumInputs(); 154 iterateIndicesExcept(originalNumArgs, argIndices, [&](unsigned i) { 155 newInputTypesBuffer.emplace_back(getInput(i)); 156 }); 157 newInputTypes = newInputTypesBuffer; 158 } 159 160 ArrayRef<Type> newResultTypes = getResults(); 161 SmallVector<Type, 4> newResultTypesBuffer; 162 if (!resultIndices.empty()) { 163 unsigned originalNumResults = getNumResults(); 164 iterateIndicesExcept(originalNumResults, resultIndices, [&](unsigned i) { 165 newResultTypesBuffer.emplace_back(getResult(i)); 166 }); 167 newResultTypes = newResultTypesBuffer; 168 } 169 170 return get(getContext(), newInputTypes, newResultTypes); 171 } 172 173 //===----------------------------------------------------------------------===// 174 // OpaqueType 175 //===----------------------------------------------------------------------===// 176 177 /// Verify the construction of an opaque type. 178 LogicalResult OpaqueType::verifyConstructionInvariants(Location loc, 179 Identifier dialect, 180 StringRef typeData) { 181 if (!Dialect::isValidNamespace(dialect.strref())) 182 return emitError(loc, "invalid dialect namespace '") << dialect << "'"; 183 return success(); 184 } 185 186 //===----------------------------------------------------------------------===// 187 // ShapedType 188 //===----------------------------------------------------------------------===// 189 constexpr int64_t ShapedType::kDynamicSize; 190 constexpr int64_t ShapedType::kDynamicStrideOrOffset; 191 192 Type ShapedType::getElementType() const { 193 return static_cast<ImplType *>(impl)->elementType; 194 } 195 196 unsigned ShapedType::getElementTypeBitWidth() const { 197 return getElementType().getIntOrFloatBitWidth(); 198 } 199 200 int64_t ShapedType::getNumElements() const { 201 assert(hasStaticShape() && "cannot get element count of dynamic shaped type"); 202 auto shape = getShape(); 203 int64_t num = 1; 204 for (auto dim : shape) 205 num *= dim; 206 return num; 207 } 208 209 int64_t ShapedType::getRank() const { return getShape().size(); } 210 211 bool ShapedType::hasRank() const { 212 return !isa<UnrankedMemRefType, UnrankedTensorType>(); 213 } 214 215 int64_t ShapedType::getDimSize(unsigned idx) const { 216 assert(idx < getRank() && "invalid index for shaped type"); 217 return getShape()[idx]; 218 } 219 220 bool ShapedType::isDynamicDim(unsigned idx) const { 221 assert(idx < getRank() && "invalid index for shaped type"); 222 return isDynamic(getShape()[idx]); 223 } 224 225 unsigned ShapedType::getDynamicDimIndex(unsigned index) const { 226 assert(index < getRank() && "invalid index"); 227 assert(ShapedType::isDynamic(getDimSize(index)) && "invalid index"); 228 return llvm::count_if(getShape().take_front(index), ShapedType::isDynamic); 229 } 230 231 /// Get the number of bits require to store a value of the given shaped type. 232 /// Compute the value recursively since tensors are allowed to have vectors as 233 /// elements. 234 int64_t ShapedType::getSizeInBits() const { 235 assert(hasStaticShape() && 236 "cannot get the bit size of an aggregate with a dynamic shape"); 237 238 auto elementType = getElementType(); 239 if (elementType.isIntOrFloat()) 240 return elementType.getIntOrFloatBitWidth() * getNumElements(); 241 242 if (auto complexType = elementType.dyn_cast<ComplexType>()) { 243 elementType = complexType.getElementType(); 244 return elementType.getIntOrFloatBitWidth() * getNumElements() * 2; 245 } 246 247 // Tensors can have vectors and other tensors as elements, other shaped types 248 // cannot. 249 assert(isa<TensorType>() && "unsupported element type"); 250 assert((elementType.isa<VectorType, TensorType>()) && 251 "unsupported tensor element type"); 252 return getNumElements() * elementType.cast<ShapedType>().getSizeInBits(); 253 } 254 255 ArrayRef<int64_t> ShapedType::getShape() const { 256 if (auto vectorType = dyn_cast<VectorType>()) 257 return vectorType.getShape(); 258 if (auto tensorType = dyn_cast<RankedTensorType>()) 259 return tensorType.getShape(); 260 return cast<MemRefType>().getShape(); 261 } 262 263 int64_t ShapedType::getNumDynamicDims() const { 264 return llvm::count_if(getShape(), isDynamic); 265 } 266 267 bool ShapedType::hasStaticShape() const { 268 return hasRank() && llvm::none_of(getShape(), isDynamic); 269 } 270 271 bool ShapedType::hasStaticShape(ArrayRef<int64_t> shape) const { 272 return hasStaticShape() && getShape() == shape; 273 } 274 275 //===----------------------------------------------------------------------===// 276 // VectorType 277 //===----------------------------------------------------------------------===// 278 279 VectorType VectorType::get(ArrayRef<int64_t> shape, Type elementType) { 280 return Base::get(elementType.getContext(), shape, elementType); 281 } 282 283 VectorType VectorType::getChecked(Location location, ArrayRef<int64_t> shape, 284 Type elementType) { 285 return Base::getChecked(location, shape, elementType); 286 } 287 288 LogicalResult VectorType::verifyConstructionInvariants(Location loc, 289 ArrayRef<int64_t> shape, 290 Type elementType) { 291 if (shape.empty()) 292 return emitError(loc, "vector types must have at least one dimension"); 293 294 if (!isValidElementType(elementType)) 295 return emitError(loc, "vector elements must be int or float type"); 296 297 if (any_of(shape, [](int64_t i) { return i <= 0; })) 298 return emitError(loc, "vector types must have positive constant sizes"); 299 300 return success(); 301 } 302 303 ArrayRef<int64_t> VectorType::getShape() const { return getImpl()->getShape(); } 304 305 VectorType VectorType::scaleElementBitwidth(unsigned scale) { 306 if (!scale) 307 return VectorType(); 308 if (auto et = getElementType().dyn_cast<IntegerType>()) 309 if (auto scaledEt = et.scaleElementBitwidth(scale)) 310 return VectorType::get(getShape(), scaledEt); 311 if (auto et = getElementType().dyn_cast<FloatType>()) 312 if (auto scaledEt = et.scaleElementBitwidth(scale)) 313 return VectorType::get(getShape(), scaledEt); 314 return VectorType(); 315 } 316 317 //===----------------------------------------------------------------------===// 318 // TensorType 319 //===----------------------------------------------------------------------===// 320 321 // Check if "elementType" can be an element type of a tensor. Emit errors if 322 // location is not nullptr. Returns failure if check failed. 323 static LogicalResult checkTensorElementType(Location location, 324 Type elementType) { 325 if (!TensorType::isValidElementType(elementType)) 326 return emitError(location, "invalid tensor element type: ") << elementType; 327 return success(); 328 } 329 330 /// Return true if the specified element type is ok in a tensor. 331 bool TensorType::isValidElementType(Type type) { 332 // Note: Non standard/builtin types are allowed to exist within tensor 333 // types. Dialects are expected to verify that tensor types have a valid 334 // element type within that dialect. 335 return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType, 336 IndexType>() || 337 !type.getDialect().getNamespace().empty(); 338 } 339 340 //===----------------------------------------------------------------------===// 341 // RankedTensorType 342 //===----------------------------------------------------------------------===// 343 344 RankedTensorType RankedTensorType::get(ArrayRef<int64_t> shape, 345 Type elementType) { 346 return Base::get(elementType.getContext(), shape, elementType); 347 } 348 349 RankedTensorType RankedTensorType::getChecked(Location location, 350 ArrayRef<int64_t> shape, 351 Type elementType) { 352 return Base::getChecked(location, shape, elementType); 353 } 354 355 LogicalResult RankedTensorType::verifyConstructionInvariants( 356 Location loc, ArrayRef<int64_t> shape, Type elementType) { 357 for (int64_t s : shape) { 358 if (s < -1) 359 return emitError(loc, "invalid tensor dimension size"); 360 } 361 return checkTensorElementType(loc, elementType); 362 } 363 364 ArrayRef<int64_t> RankedTensorType::getShape() const { 365 return getImpl()->getShape(); 366 } 367 368 //===----------------------------------------------------------------------===// 369 // UnrankedTensorType 370 //===----------------------------------------------------------------------===// 371 372 UnrankedTensorType UnrankedTensorType::get(Type elementType) { 373 return Base::get(elementType.getContext(), elementType); 374 } 375 376 UnrankedTensorType UnrankedTensorType::getChecked(Location location, 377 Type elementType) { 378 return Base::getChecked(location, elementType); 379 } 380 381 LogicalResult 382 UnrankedTensorType::verifyConstructionInvariants(Location loc, 383 Type elementType) { 384 return checkTensorElementType(loc, elementType); 385 } 386 387 //===----------------------------------------------------------------------===// 388 // BaseMemRefType 389 //===----------------------------------------------------------------------===// 390 391 unsigned BaseMemRefType::getMemorySpace() const { 392 return static_cast<ImplType *>(impl)->memorySpace; 393 } 394 395 //===----------------------------------------------------------------------===// 396 // MemRefType 397 //===----------------------------------------------------------------------===// 398 399 /// Get or create a new MemRefType based on shape, element type, affine 400 /// map composition, and memory space. Assumes the arguments define a 401 /// well-formed MemRef type. Use getChecked to gracefully handle MemRefType 402 /// construction failures. 403 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType, 404 ArrayRef<AffineMap> affineMapComposition, 405 unsigned memorySpace) { 406 auto result = getImpl(shape, elementType, affineMapComposition, memorySpace, 407 /*location=*/llvm::None); 408 assert(result && "Failed to construct instance of MemRefType."); 409 return result; 410 } 411 412 /// Get or create a new MemRefType based on shape, element type, affine 413 /// map composition, and memory space declared at the given location. 414 /// If the location is unknown, the last argument should be an instance of 415 /// UnknownLoc. If the MemRefType defined by the arguments would be 416 /// ill-formed, emits errors (to the handler registered with the context or to 417 /// the error stream) and returns nullptr. 418 MemRefType MemRefType::getChecked(Location location, ArrayRef<int64_t> shape, 419 Type elementType, 420 ArrayRef<AffineMap> affineMapComposition, 421 unsigned memorySpace) { 422 return getImpl(shape, elementType, affineMapComposition, memorySpace, 423 location); 424 } 425 426 /// Get or create a new MemRefType defined by the arguments. If the resulting 427 /// type would be ill-formed, return nullptr. If the location is provided, 428 /// emit detailed error messages. To emit errors when the location is unknown, 429 /// pass in an instance of UnknownLoc. 430 MemRefType MemRefType::getImpl(ArrayRef<int64_t> shape, Type elementType, 431 ArrayRef<AffineMap> affineMapComposition, 432 unsigned memorySpace, 433 Optional<Location> location) { 434 auto *context = elementType.getContext(); 435 436 if (!BaseMemRefType::isValidElementType(elementType)) 437 return emitOptionalError(location, "invalid memref element type"), 438 MemRefType(); 439 440 for (int64_t s : shape) { 441 // Negative sizes are not allowed except for `-1` that means dynamic size. 442 if (s < -1) 443 return emitOptionalError(location, "invalid memref size"), MemRefType(); 444 } 445 446 // Check that the structure of the composition is valid, i.e. that each 447 // subsequent affine map has as many inputs as the previous map has results. 448 // Take the dimensionality of the MemRef for the first map. 449 auto dim = shape.size(); 450 unsigned i = 0; 451 for (const auto &affineMap : affineMapComposition) { 452 if (affineMap.getNumDims() != dim) { 453 if (location) 454 emitError(*location) 455 << "memref affine map dimension mismatch between " 456 << (i == 0 ? Twine("memref rank") : "affine map " + Twine(i)) 457 << " and affine map" << i + 1 << ": " << dim 458 << " != " << affineMap.getNumDims(); 459 return nullptr; 460 } 461 462 dim = affineMap.getNumResults(); 463 ++i; 464 } 465 466 // Drop identity maps from the composition. 467 // This may lead to the composition becoming empty, which is interpreted as an 468 // implicit identity. 469 SmallVector<AffineMap, 2> cleanedAffineMapComposition; 470 for (const auto &map : affineMapComposition) { 471 if (map.isIdentity()) 472 continue; 473 cleanedAffineMapComposition.push_back(map); 474 } 475 476 return Base::get(context, shape, elementType, cleanedAffineMapComposition, 477 memorySpace); 478 } 479 480 ArrayRef<int64_t> MemRefType::getShape() const { return getImpl()->getShape(); } 481 482 ArrayRef<AffineMap> MemRefType::getAffineMaps() const { 483 return getImpl()->getAffineMaps(); 484 } 485 486 //===----------------------------------------------------------------------===// 487 // UnrankedMemRefType 488 //===----------------------------------------------------------------------===// 489 490 UnrankedMemRefType UnrankedMemRefType::get(Type elementType, 491 unsigned memorySpace) { 492 return Base::get(elementType.getContext(), elementType, memorySpace); 493 } 494 495 UnrankedMemRefType UnrankedMemRefType::getChecked(Location location, 496 Type elementType, 497 unsigned memorySpace) { 498 return Base::getChecked(location, elementType, memorySpace); 499 } 500 501 LogicalResult 502 UnrankedMemRefType::verifyConstructionInvariants(Location loc, Type elementType, 503 unsigned memorySpace) { 504 if (!BaseMemRefType::isValidElementType(elementType)) 505 return emitError(loc, "invalid memref element type"); 506 return success(); 507 } 508 509 // Fallback cases for terminal dim/sym/cst that are not part of a binary op ( 510 // i.e. single term). Accumulate the AffineExpr into the existing one. 511 static void extractStridesFromTerm(AffineExpr e, 512 AffineExpr multiplicativeFactor, 513 MutableArrayRef<AffineExpr> strides, 514 AffineExpr &offset) { 515 if (auto dim = e.dyn_cast<AffineDimExpr>()) 516 strides[dim.getPosition()] = 517 strides[dim.getPosition()] + multiplicativeFactor; 518 else 519 offset = offset + e * multiplicativeFactor; 520 } 521 522 /// Takes a single AffineExpr `e` and populates the `strides` array with the 523 /// strides expressions for each dim position. 524 /// The convention is that the strides for dimensions d0, .. dn appear in 525 /// order to make indexing intuitive into the result. 526 static LogicalResult extractStrides(AffineExpr e, 527 AffineExpr multiplicativeFactor, 528 MutableArrayRef<AffineExpr> strides, 529 AffineExpr &offset) { 530 auto bin = e.dyn_cast<AffineBinaryOpExpr>(); 531 if (!bin) { 532 extractStridesFromTerm(e, multiplicativeFactor, strides, offset); 533 return success(); 534 } 535 536 if (bin.getKind() == AffineExprKind::CeilDiv || 537 bin.getKind() == AffineExprKind::FloorDiv || 538 bin.getKind() == AffineExprKind::Mod) 539 return failure(); 540 541 if (bin.getKind() == AffineExprKind::Mul) { 542 auto dim = bin.getLHS().dyn_cast<AffineDimExpr>(); 543 if (dim) { 544 strides[dim.getPosition()] = 545 strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor; 546 return success(); 547 } 548 // LHS and RHS may both contain complex expressions of dims. Try one path 549 // and if it fails try the other. This is guaranteed to succeed because 550 // only one path may have a `dim`, otherwise this is not an AffineExpr in 551 // the first place. 552 if (bin.getLHS().isSymbolicOrConstant()) 553 return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(), 554 strides, offset); 555 return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(), 556 strides, offset); 557 } 558 559 if (bin.getKind() == AffineExprKind::Add) { 560 auto res1 = 561 extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset); 562 auto res2 = 563 extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset); 564 return success(succeeded(res1) && succeeded(res2)); 565 } 566 567 llvm_unreachable("unexpected binary operation"); 568 } 569 570 LogicalResult mlir::getStridesAndOffset(MemRefType t, 571 SmallVectorImpl<AffineExpr> &strides, 572 AffineExpr &offset) { 573 auto affineMaps = t.getAffineMaps(); 574 // For now strides are only computed on a single affine map with a single 575 // result (i.e. the closed subset of linearization maps that are compatible 576 // with striding semantics). 577 // TODO: support more forms on a per-need basis. 578 if (affineMaps.size() > 1) 579 return failure(); 580 if (affineMaps.size() == 1 && affineMaps[0].getNumResults() != 1) 581 return failure(); 582 583 auto zero = getAffineConstantExpr(0, t.getContext()); 584 auto one = getAffineConstantExpr(1, t.getContext()); 585 offset = zero; 586 strides.assign(t.getRank(), zero); 587 588 AffineMap m; 589 if (!affineMaps.empty()) { 590 m = affineMaps.front(); 591 assert(!m.isIdentity() && "unexpected identity map"); 592 } 593 594 // Canonical case for empty map. 595 if (!m) { 596 // 0-D corner case, offset is already 0. 597 if (t.getRank() == 0) 598 return success(); 599 auto stridedExpr = 600 makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext()); 601 if (succeeded(extractStrides(stridedExpr, one, strides, offset))) 602 return success(); 603 assert(false && "unexpected failure: extract strides in canonical layout"); 604 } 605 606 // Non-canonical case requires more work. 607 auto stridedExpr = 608 simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols()); 609 if (failed(extractStrides(stridedExpr, one, strides, offset))) { 610 offset = AffineExpr(); 611 strides.clear(); 612 return failure(); 613 } 614 615 // Simplify results to allow folding to constants and simple checks. 616 unsigned numDims = m.getNumDims(); 617 unsigned numSymbols = m.getNumSymbols(); 618 offset = simplifyAffineExpr(offset, numDims, numSymbols); 619 for (auto &stride : strides) 620 stride = simplifyAffineExpr(stride, numDims, numSymbols); 621 622 /// In practice, a strided memref must be internally non-aliasing. Test 623 /// against 0 as a proxy. 624 /// TODO: static cases can have more advanced checks. 625 /// TODO: dynamic cases would require a way to compare symbolic 626 /// expressions and would probably need an affine set context propagated 627 /// everywhere. 628 if (llvm::any_of(strides, [](AffineExpr e) { 629 return e == getAffineConstantExpr(0, e.getContext()); 630 })) { 631 offset = AffineExpr(); 632 strides.clear(); 633 return failure(); 634 } 635 636 return success(); 637 } 638 639 LogicalResult mlir::getStridesAndOffset(MemRefType t, 640 SmallVectorImpl<int64_t> &strides, 641 int64_t &offset) { 642 AffineExpr offsetExpr; 643 SmallVector<AffineExpr, 4> strideExprs; 644 if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr))) 645 return failure(); 646 if (auto cst = offsetExpr.dyn_cast<AffineConstantExpr>()) 647 offset = cst.getValue(); 648 else 649 offset = ShapedType::kDynamicStrideOrOffset; 650 for (auto e : strideExprs) { 651 if (auto c = e.dyn_cast<AffineConstantExpr>()) 652 strides.push_back(c.getValue()); 653 else 654 strides.push_back(ShapedType::kDynamicStrideOrOffset); 655 } 656 return success(); 657 } 658 659 //===----------------------------------------------------------------------===// 660 /// TupleType 661 //===----------------------------------------------------------------------===// 662 663 /// Return the elements types for this tuple. 664 ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); } 665 666 /// Accumulate the types contained in this tuple and tuples nested within it. 667 /// Note that this only flattens nested tuples, not any other container type, 668 /// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to 669 /// (i32, tensor<i32>, f32, i64) 670 void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) { 671 for (Type type : getTypes()) { 672 if (auto nestedTuple = type.dyn_cast<TupleType>()) 673 nestedTuple.getFlattenedTypes(types); 674 else 675 types.push_back(type); 676 } 677 } 678 679 /// Return the number of element types. 680 size_t TupleType::size() const { return getImpl()->size(); } 681 682 //===----------------------------------------------------------------------===// 683 // Type Utilities 684 //===----------------------------------------------------------------------===// 685 686 AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides, 687 int64_t offset, 688 MLIRContext *context) { 689 AffineExpr expr; 690 unsigned nSymbols = 0; 691 692 // AffineExpr for offset. 693 // Static case. 694 if (offset != MemRefType::getDynamicStrideOrOffset()) { 695 auto cst = getAffineConstantExpr(offset, context); 696 expr = cst; 697 } else { 698 // Dynamic case, new symbol for the offset. 699 auto sym = getAffineSymbolExpr(nSymbols++, context); 700 expr = sym; 701 } 702 703 // AffineExpr for strides. 704 for (auto en : llvm::enumerate(strides)) { 705 auto dim = en.index(); 706 auto stride = en.value(); 707 assert(stride != 0 && "Invalid stride specification"); 708 auto d = getAffineDimExpr(dim, context); 709 AffineExpr mult; 710 // Static case. 711 if (stride != MemRefType::getDynamicStrideOrOffset()) 712 mult = getAffineConstantExpr(stride, context); 713 else 714 // Dynamic case, new symbol for each new stride. 715 mult = getAffineSymbolExpr(nSymbols++, context); 716 expr = expr + d * mult; 717 } 718 719 return AffineMap::get(strides.size(), nSymbols, expr); 720 } 721 722 /// Return a version of `t` with identity layout if it can be determined 723 /// statically that the layout is the canonical contiguous strided layout. 724 /// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of 725 /// `t` with simplified layout. 726 /// If `t` has multiple layout maps or a multi-result layout, just return `t`. 727 MemRefType mlir::canonicalizeStridedLayout(MemRefType t) { 728 auto affineMaps = t.getAffineMaps(); 729 // Already in canonical form. 730 if (affineMaps.empty()) 731 return t; 732 733 // Can't reduce to canonical identity form, return in canonical form. 734 if (affineMaps.size() > 1 || affineMaps[0].getNumResults() > 1) 735 return t; 736 737 // If the canonical strided layout for the sizes of `t` is equal to the 738 // simplified layout of `t` we can just return an empty layout. Otherwise, 739 // just simplify the existing layout. 740 AffineExpr expr = 741 makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext()); 742 auto m = affineMaps[0]; 743 auto simplifiedLayoutExpr = 744 simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols()); 745 if (expr != simplifiedLayoutExpr) 746 return MemRefType::Builder(t).setAffineMaps({AffineMap::get( 747 m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)}); 748 return MemRefType::Builder(t).setAffineMaps({}); 749 } 750 751 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes, 752 ArrayRef<AffineExpr> exprs, 753 MLIRContext *context) { 754 // Size 0 corner case is useful for canonicalizations. 755 if (llvm::is_contained(sizes, 0)) 756 return getAffineConstantExpr(0, context); 757 758 auto maps = AffineMap::inferFromExprList(exprs); 759 assert(!maps.empty() && "Expected one non-empty map"); 760 unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols(); 761 762 AffineExpr expr; 763 bool dynamicPoisonBit = false; 764 int64_t runningSize = 1; 765 for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) { 766 int64_t size = std::get<1>(en); 767 // Degenerate case, no size =-> no stride 768 if (size == 0) 769 continue; 770 AffineExpr dimExpr = std::get<0>(en); 771 AffineExpr stride = dynamicPoisonBit 772 ? getAffineSymbolExpr(nSymbols++, context) 773 : getAffineConstantExpr(runningSize, context); 774 expr = expr ? expr + dimExpr * stride : dimExpr * stride; 775 if (size > 0) 776 runningSize *= size; 777 else 778 dynamicPoisonBit = true; 779 } 780 return simplifyAffineExpr(expr, numDims, nSymbols); 781 } 782 783 /// Return a version of `t` with a layout that has all dynamic offset and 784 /// strides. This is used to erase the static layout. 785 MemRefType mlir::eraseStridedLayout(MemRefType t) { 786 auto val = ShapedType::kDynamicStrideOrOffset; 787 return MemRefType::Builder(t).setAffineMaps(makeStridedLinearLayoutMap( 788 SmallVector<int64_t, 4>(t.getRank(), val), val, t.getContext())); 789 } 790 791 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes, 792 MLIRContext *context) { 793 SmallVector<AffineExpr, 4> exprs; 794 exprs.reserve(sizes.size()); 795 for (auto dim : llvm::seq<unsigned>(0, sizes.size())) 796 exprs.push_back(getAffineDimExpr(dim, context)); 797 return makeCanonicalStridedLayoutExpr(sizes, exprs, context); 798 } 799 800 /// Return true if the layout for `t` is compatible with strided semantics. 801 bool mlir::isStrided(MemRefType t) { 802 int64_t offset; 803 SmallVector<int64_t, 4> stridesAndOffset; 804 auto res = getStridesAndOffset(t, stridesAndOffset, offset); 805 return succeeded(res); 806 } 807