1 //===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/BuiltinTypes.h" 10 #include "TypeDetail.h" 11 #include "mlir/IR/AffineExpr.h" 12 #include "mlir/IR/AffineMap.h" 13 #include "mlir/IR/Diagnostics.h" 14 #include "mlir/IR/Dialect.h" 15 #include "llvm/ADT/APFloat.h" 16 #include "llvm/ADT/BitVector.h" 17 #include "llvm/ADT/Sequence.h" 18 #include "llvm/ADT/Twine.h" 19 20 using namespace mlir; 21 using namespace mlir::detail; 22 23 //===----------------------------------------------------------------------===// 24 /// Tablegen Type Definitions 25 //===----------------------------------------------------------------------===// 26 27 #define GET_TYPEDEF_CLASSES 28 #include "mlir/IR/BuiltinTypes.cpp.inc" 29 30 //===----------------------------------------------------------------------===// 31 /// ComplexType 32 //===----------------------------------------------------------------------===// 33 34 /// Verify the construction of an integer type. 35 LogicalResult ComplexType::verifyConstructionInvariants(Location loc, 36 Type elementType) { 37 if (!elementType.isIntOrFloat()) 38 return emitError(loc, "invalid element type for complex"); 39 return success(); 40 } 41 42 //===----------------------------------------------------------------------===// 43 // Integer Type 44 //===----------------------------------------------------------------------===// 45 46 // static constexpr must have a definition (until in C++17 and inline variable). 47 constexpr unsigned IntegerType::kMaxWidth; 48 49 /// Verify the construction of an integer type. 50 LogicalResult 51 IntegerType::verifyConstructionInvariants(Location loc, unsigned width, 52 SignednessSemantics signedness) { 53 if (width > IntegerType::kMaxWidth) { 54 return emitError(loc) << "integer bitwidth is limited to " 55 << IntegerType::kMaxWidth << " bits"; 56 } 57 return success(); 58 } 59 60 unsigned IntegerType::getWidth() const { return getImpl()->width; } 61 62 IntegerType::SignednessSemantics IntegerType::getSignedness() const { 63 return getImpl()->signedness; 64 } 65 66 IntegerType IntegerType::scaleElementBitwidth(unsigned scale) { 67 if (!scale) 68 return IntegerType(); 69 return IntegerType::get(getContext(), scale * getWidth(), getSignedness()); 70 } 71 72 //===----------------------------------------------------------------------===// 73 // Float Type 74 //===----------------------------------------------------------------------===// 75 76 unsigned FloatType::getWidth() { 77 if (isa<Float16Type, BFloat16Type>()) 78 return 16; 79 if (isa<Float32Type>()) 80 return 32; 81 if (isa<Float64Type>()) 82 return 64; 83 if (isa<Float80Type>()) 84 return 80; 85 if (isa<Float128Type>()) 86 return 128; 87 llvm_unreachable("unexpected float type"); 88 } 89 90 /// Returns the floating semantics for the given type. 91 const llvm::fltSemantics &FloatType::getFloatSemantics() { 92 if (isa<BFloat16Type>()) 93 return APFloat::BFloat(); 94 if (isa<Float16Type>()) 95 return APFloat::IEEEhalf(); 96 if (isa<Float32Type>()) 97 return APFloat::IEEEsingle(); 98 if (isa<Float64Type>()) 99 return APFloat::IEEEdouble(); 100 if (isa<Float80Type>()) 101 return APFloat::x87DoubleExtended(); 102 if (isa<Float128Type>()) 103 return APFloat::IEEEquad(); 104 llvm_unreachable("non-floating point type used"); 105 } 106 107 FloatType FloatType::scaleElementBitwidth(unsigned scale) { 108 if (!scale) 109 return FloatType(); 110 MLIRContext *ctx = getContext(); 111 if (isF16() || isBF16()) { 112 if (scale == 2) 113 return FloatType::getF32(ctx); 114 if (scale == 4) 115 return FloatType::getF64(ctx); 116 } 117 if (isF32()) 118 if (scale == 2) 119 return FloatType::getF64(ctx); 120 return FloatType(); 121 } 122 123 //===----------------------------------------------------------------------===// 124 // FunctionType 125 //===----------------------------------------------------------------------===// 126 127 unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; } 128 129 ArrayRef<Type> FunctionType::getInputs() const { 130 return getImpl()->getInputs(); 131 } 132 133 unsigned FunctionType::getNumResults() const { return getImpl()->numResults; } 134 135 ArrayRef<Type> FunctionType::getResults() const { 136 return getImpl()->getResults(); 137 } 138 139 /// Helper to call a callback once on each index in the range 140 /// [0, `totalIndices`), *except* for the indices given in `indices`. 141 /// `indices` is allowed to have duplicates and can be in any order. 142 inline void iterateIndicesExcept(unsigned totalIndices, 143 ArrayRef<unsigned> indices, 144 function_ref<void(unsigned)> callback) { 145 llvm::BitVector skipIndices(totalIndices); 146 for (unsigned i : indices) 147 skipIndices.set(i); 148 149 for (unsigned i = 0; i < totalIndices; ++i) 150 if (!skipIndices.test(i)) 151 callback(i); 152 } 153 154 /// Returns a new function type without the specified arguments and results. 155 FunctionType 156 FunctionType::getWithoutArgsAndResults(ArrayRef<unsigned> argIndices, 157 ArrayRef<unsigned> resultIndices) { 158 ArrayRef<Type> newInputTypes = getInputs(); 159 SmallVector<Type, 4> newInputTypesBuffer; 160 if (!argIndices.empty()) { 161 unsigned originalNumArgs = getNumInputs(); 162 iterateIndicesExcept(originalNumArgs, argIndices, [&](unsigned i) { 163 newInputTypesBuffer.emplace_back(getInput(i)); 164 }); 165 newInputTypes = newInputTypesBuffer; 166 } 167 168 ArrayRef<Type> newResultTypes = getResults(); 169 SmallVector<Type, 4> newResultTypesBuffer; 170 if (!resultIndices.empty()) { 171 unsigned originalNumResults = getNumResults(); 172 iterateIndicesExcept(originalNumResults, resultIndices, [&](unsigned i) { 173 newResultTypesBuffer.emplace_back(getResult(i)); 174 }); 175 newResultTypes = newResultTypesBuffer; 176 } 177 178 return get(getContext(), newInputTypes, newResultTypes); 179 } 180 181 //===----------------------------------------------------------------------===// 182 // OpaqueType 183 //===----------------------------------------------------------------------===// 184 185 /// Verify the construction of an opaque type. 186 LogicalResult OpaqueType::verifyConstructionInvariants(Location loc, 187 Identifier dialect, 188 StringRef typeData) { 189 if (!Dialect::isValidNamespace(dialect.strref())) 190 return emitError(loc, "invalid dialect namespace '") << dialect << "'"; 191 return success(); 192 } 193 194 //===----------------------------------------------------------------------===// 195 // ShapedType 196 //===----------------------------------------------------------------------===// 197 constexpr int64_t ShapedType::kDynamicSize; 198 constexpr int64_t ShapedType::kDynamicStrideOrOffset; 199 200 Type ShapedType::getElementType() const { 201 return static_cast<ImplType *>(impl)->elementType; 202 } 203 204 unsigned ShapedType::getElementTypeBitWidth() const { 205 return getElementType().getIntOrFloatBitWidth(); 206 } 207 208 int64_t ShapedType::getNumElements() const { 209 assert(hasStaticShape() && "cannot get element count of dynamic shaped type"); 210 auto shape = getShape(); 211 int64_t num = 1; 212 for (auto dim : shape) 213 num *= dim; 214 return num; 215 } 216 217 int64_t ShapedType::getRank() const { return getShape().size(); } 218 219 bool ShapedType::hasRank() const { 220 return !isa<UnrankedMemRefType, UnrankedTensorType>(); 221 } 222 223 int64_t ShapedType::getDimSize(unsigned idx) const { 224 assert(idx < getRank() && "invalid index for shaped type"); 225 return getShape()[idx]; 226 } 227 228 bool ShapedType::isDynamicDim(unsigned idx) const { 229 assert(idx < getRank() && "invalid index for shaped type"); 230 return isDynamic(getShape()[idx]); 231 } 232 233 unsigned ShapedType::getDynamicDimIndex(unsigned index) const { 234 assert(index < getRank() && "invalid index"); 235 assert(ShapedType::isDynamic(getDimSize(index)) && "invalid index"); 236 return llvm::count_if(getShape().take_front(index), ShapedType::isDynamic); 237 } 238 239 /// Get the number of bits require to store a value of the given shaped type. 240 /// Compute the value recursively since tensors are allowed to have vectors as 241 /// elements. 242 int64_t ShapedType::getSizeInBits() const { 243 assert(hasStaticShape() && 244 "cannot get the bit size of an aggregate with a dynamic shape"); 245 246 auto elementType = getElementType(); 247 if (elementType.isIntOrFloat()) 248 return elementType.getIntOrFloatBitWidth() * getNumElements(); 249 250 if (auto complexType = elementType.dyn_cast<ComplexType>()) { 251 elementType = complexType.getElementType(); 252 return elementType.getIntOrFloatBitWidth() * getNumElements() * 2; 253 } 254 255 // Tensors can have vectors and other tensors as elements, other shaped types 256 // cannot. 257 assert(isa<TensorType>() && "unsupported element type"); 258 assert((elementType.isa<VectorType, TensorType>()) && 259 "unsupported tensor element type"); 260 return getNumElements() * elementType.cast<ShapedType>().getSizeInBits(); 261 } 262 263 ArrayRef<int64_t> ShapedType::getShape() const { 264 if (auto vectorType = dyn_cast<VectorType>()) 265 return vectorType.getShape(); 266 if (auto tensorType = dyn_cast<RankedTensorType>()) 267 return tensorType.getShape(); 268 return cast<MemRefType>().getShape(); 269 } 270 271 int64_t ShapedType::getNumDynamicDims() const { 272 return llvm::count_if(getShape(), isDynamic); 273 } 274 275 bool ShapedType::hasStaticShape() const { 276 return hasRank() && llvm::none_of(getShape(), isDynamic); 277 } 278 279 bool ShapedType::hasStaticShape(ArrayRef<int64_t> shape) const { 280 return hasStaticShape() && getShape() == shape; 281 } 282 283 //===----------------------------------------------------------------------===// 284 // VectorType 285 //===----------------------------------------------------------------------===// 286 287 VectorType VectorType::get(ArrayRef<int64_t> shape, Type elementType) { 288 return Base::get(elementType.getContext(), shape, elementType); 289 } 290 291 VectorType VectorType::getChecked(Location location, ArrayRef<int64_t> shape, 292 Type elementType) { 293 return Base::getChecked(location, shape, elementType); 294 } 295 296 LogicalResult VectorType::verifyConstructionInvariants(Location loc, 297 ArrayRef<int64_t> shape, 298 Type elementType) { 299 if (shape.empty()) 300 return emitError(loc, "vector types must have at least one dimension"); 301 302 if (!isValidElementType(elementType)) 303 return emitError(loc, "vector elements must be int or float type"); 304 305 if (any_of(shape, [](int64_t i) { return i <= 0; })) 306 return emitError(loc, "vector types must have positive constant sizes"); 307 308 return success(); 309 } 310 311 ArrayRef<int64_t> VectorType::getShape() const { return getImpl()->getShape(); } 312 313 VectorType VectorType::scaleElementBitwidth(unsigned scale) { 314 if (!scale) 315 return VectorType(); 316 if (auto et = getElementType().dyn_cast<IntegerType>()) 317 if (auto scaledEt = et.scaleElementBitwidth(scale)) 318 return VectorType::get(getShape(), scaledEt); 319 if (auto et = getElementType().dyn_cast<FloatType>()) 320 if (auto scaledEt = et.scaleElementBitwidth(scale)) 321 return VectorType::get(getShape(), scaledEt); 322 return VectorType(); 323 } 324 325 //===----------------------------------------------------------------------===// 326 // TensorType 327 //===----------------------------------------------------------------------===// 328 329 // Check if "elementType" can be an element type of a tensor. Emit errors if 330 // location is not nullptr. Returns failure if check failed. 331 static LogicalResult checkTensorElementType(Location location, 332 Type elementType) { 333 if (!TensorType::isValidElementType(elementType)) 334 return emitError(location, "invalid tensor element type: ") << elementType; 335 return success(); 336 } 337 338 /// Return true if the specified element type is ok in a tensor. 339 bool TensorType::isValidElementType(Type type) { 340 // Note: Non standard/builtin types are allowed to exist within tensor 341 // types. Dialects are expected to verify that tensor types have a valid 342 // element type within that dialect. 343 return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType, 344 IndexType>() || 345 !type.getDialect().getNamespace().empty(); 346 } 347 348 //===----------------------------------------------------------------------===// 349 // RankedTensorType 350 //===----------------------------------------------------------------------===// 351 352 RankedTensorType RankedTensorType::get(ArrayRef<int64_t> shape, 353 Type elementType) { 354 return Base::get(elementType.getContext(), shape, elementType); 355 } 356 357 RankedTensorType RankedTensorType::getChecked(Location location, 358 ArrayRef<int64_t> shape, 359 Type elementType) { 360 return Base::getChecked(location, shape, elementType); 361 } 362 363 LogicalResult RankedTensorType::verifyConstructionInvariants( 364 Location loc, ArrayRef<int64_t> shape, Type elementType) { 365 for (int64_t s : shape) { 366 if (s < -1) 367 return emitError(loc, "invalid tensor dimension size"); 368 } 369 return checkTensorElementType(loc, elementType); 370 } 371 372 ArrayRef<int64_t> RankedTensorType::getShape() const { 373 return getImpl()->getShape(); 374 } 375 376 //===----------------------------------------------------------------------===// 377 // UnrankedTensorType 378 //===----------------------------------------------------------------------===// 379 380 UnrankedTensorType UnrankedTensorType::get(Type elementType) { 381 return Base::get(elementType.getContext(), elementType); 382 } 383 384 UnrankedTensorType UnrankedTensorType::getChecked(Location location, 385 Type elementType) { 386 return Base::getChecked(location, elementType); 387 } 388 389 LogicalResult 390 UnrankedTensorType::verifyConstructionInvariants(Location loc, 391 Type elementType) { 392 return checkTensorElementType(loc, elementType); 393 } 394 395 //===----------------------------------------------------------------------===// 396 // BaseMemRefType 397 //===----------------------------------------------------------------------===// 398 399 unsigned BaseMemRefType::getMemorySpace() const { 400 return static_cast<ImplType *>(impl)->memorySpace; 401 } 402 403 //===----------------------------------------------------------------------===// 404 // MemRefType 405 //===----------------------------------------------------------------------===// 406 407 /// Get or create a new MemRefType based on shape, element type, affine 408 /// map composition, and memory space. Assumes the arguments define a 409 /// well-formed MemRef type. Use getChecked to gracefully handle MemRefType 410 /// construction failures. 411 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType, 412 ArrayRef<AffineMap> affineMapComposition, 413 unsigned memorySpace) { 414 auto result = getImpl(shape, elementType, affineMapComposition, memorySpace, 415 /*location=*/llvm::None); 416 assert(result && "Failed to construct instance of MemRefType."); 417 return result; 418 } 419 420 /// Get or create a new MemRefType based on shape, element type, affine 421 /// map composition, and memory space declared at the given location. 422 /// If the location is unknown, the last argument should be an instance of 423 /// UnknownLoc. If the MemRefType defined by the arguments would be 424 /// ill-formed, emits errors (to the handler registered with the context or to 425 /// the error stream) and returns nullptr. 426 MemRefType MemRefType::getChecked(Location location, ArrayRef<int64_t> shape, 427 Type elementType, 428 ArrayRef<AffineMap> affineMapComposition, 429 unsigned memorySpace) { 430 return getImpl(shape, elementType, affineMapComposition, memorySpace, 431 location); 432 } 433 434 /// Get or create a new MemRefType defined by the arguments. If the resulting 435 /// type would be ill-formed, return nullptr. If the location is provided, 436 /// emit detailed error messages. To emit errors when the location is unknown, 437 /// pass in an instance of UnknownLoc. 438 MemRefType MemRefType::getImpl(ArrayRef<int64_t> shape, Type elementType, 439 ArrayRef<AffineMap> affineMapComposition, 440 unsigned memorySpace, 441 Optional<Location> location) { 442 auto *context = elementType.getContext(); 443 444 if (!BaseMemRefType::isValidElementType(elementType)) 445 return emitOptionalError(location, "invalid memref element type"), 446 MemRefType(); 447 448 for (int64_t s : shape) { 449 // Negative sizes are not allowed except for `-1` that means dynamic size. 450 if (s < -1) 451 return emitOptionalError(location, "invalid memref size"), MemRefType(); 452 } 453 454 // Check that the structure of the composition is valid, i.e. that each 455 // subsequent affine map has as many inputs as the previous map has results. 456 // Take the dimensionality of the MemRef for the first map. 457 auto dim = shape.size(); 458 unsigned i = 0; 459 for (const auto &affineMap : affineMapComposition) { 460 if (affineMap.getNumDims() != dim) { 461 if (location) 462 emitError(*location) 463 << "memref affine map dimension mismatch between " 464 << (i == 0 ? Twine("memref rank") : "affine map " + Twine(i)) 465 << " and affine map" << i + 1 << ": " << dim 466 << " != " << affineMap.getNumDims(); 467 return nullptr; 468 } 469 470 dim = affineMap.getNumResults(); 471 ++i; 472 } 473 474 // Drop identity maps from the composition. 475 // This may lead to the composition becoming empty, which is interpreted as an 476 // implicit identity. 477 SmallVector<AffineMap, 2> cleanedAffineMapComposition; 478 for (const auto &map : affineMapComposition) { 479 if (map.isIdentity()) 480 continue; 481 cleanedAffineMapComposition.push_back(map); 482 } 483 484 return Base::get(context, shape, elementType, cleanedAffineMapComposition, 485 memorySpace); 486 } 487 488 ArrayRef<int64_t> MemRefType::getShape() const { return getImpl()->getShape(); } 489 490 ArrayRef<AffineMap> MemRefType::getAffineMaps() const { 491 return getImpl()->getAffineMaps(); 492 } 493 494 //===----------------------------------------------------------------------===// 495 // UnrankedMemRefType 496 //===----------------------------------------------------------------------===// 497 498 UnrankedMemRefType UnrankedMemRefType::get(Type elementType, 499 unsigned memorySpace) { 500 return Base::get(elementType.getContext(), elementType, memorySpace); 501 } 502 503 UnrankedMemRefType UnrankedMemRefType::getChecked(Location location, 504 Type elementType, 505 unsigned memorySpace) { 506 return Base::getChecked(location, elementType, memorySpace); 507 } 508 509 LogicalResult 510 UnrankedMemRefType::verifyConstructionInvariants(Location loc, Type elementType, 511 unsigned memorySpace) { 512 if (!BaseMemRefType::isValidElementType(elementType)) 513 return emitError(loc, "invalid memref element type"); 514 return success(); 515 } 516 517 // Fallback cases for terminal dim/sym/cst that are not part of a binary op ( 518 // i.e. single term). Accumulate the AffineExpr into the existing one. 519 static void extractStridesFromTerm(AffineExpr e, 520 AffineExpr multiplicativeFactor, 521 MutableArrayRef<AffineExpr> strides, 522 AffineExpr &offset) { 523 if (auto dim = e.dyn_cast<AffineDimExpr>()) 524 strides[dim.getPosition()] = 525 strides[dim.getPosition()] + multiplicativeFactor; 526 else 527 offset = offset + e * multiplicativeFactor; 528 } 529 530 /// Takes a single AffineExpr `e` and populates the `strides` array with the 531 /// strides expressions for each dim position. 532 /// The convention is that the strides for dimensions d0, .. dn appear in 533 /// order to make indexing intuitive into the result. 534 static LogicalResult extractStrides(AffineExpr e, 535 AffineExpr multiplicativeFactor, 536 MutableArrayRef<AffineExpr> strides, 537 AffineExpr &offset) { 538 auto bin = e.dyn_cast<AffineBinaryOpExpr>(); 539 if (!bin) { 540 extractStridesFromTerm(e, multiplicativeFactor, strides, offset); 541 return success(); 542 } 543 544 if (bin.getKind() == AffineExprKind::CeilDiv || 545 bin.getKind() == AffineExprKind::FloorDiv || 546 bin.getKind() == AffineExprKind::Mod) 547 return failure(); 548 549 if (bin.getKind() == AffineExprKind::Mul) { 550 auto dim = bin.getLHS().dyn_cast<AffineDimExpr>(); 551 if (dim) { 552 strides[dim.getPosition()] = 553 strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor; 554 return success(); 555 } 556 // LHS and RHS may both contain complex expressions of dims. Try one path 557 // and if it fails try the other. This is guaranteed to succeed because 558 // only one path may have a `dim`, otherwise this is not an AffineExpr in 559 // the first place. 560 if (bin.getLHS().isSymbolicOrConstant()) 561 return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(), 562 strides, offset); 563 return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(), 564 strides, offset); 565 } 566 567 if (bin.getKind() == AffineExprKind::Add) { 568 auto res1 = 569 extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset); 570 auto res2 = 571 extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset); 572 return success(succeeded(res1) && succeeded(res2)); 573 } 574 575 llvm_unreachable("unexpected binary operation"); 576 } 577 578 LogicalResult mlir::getStridesAndOffset(MemRefType t, 579 SmallVectorImpl<AffineExpr> &strides, 580 AffineExpr &offset) { 581 auto affineMaps = t.getAffineMaps(); 582 // For now strides are only computed on a single affine map with a single 583 // result (i.e. the closed subset of linearization maps that are compatible 584 // with striding semantics). 585 // TODO: support more forms on a per-need basis. 586 if (affineMaps.size() > 1) 587 return failure(); 588 if (affineMaps.size() == 1 && affineMaps[0].getNumResults() != 1) 589 return failure(); 590 591 auto zero = getAffineConstantExpr(0, t.getContext()); 592 auto one = getAffineConstantExpr(1, t.getContext()); 593 offset = zero; 594 strides.assign(t.getRank(), zero); 595 596 AffineMap m; 597 if (!affineMaps.empty()) { 598 m = affineMaps.front(); 599 assert(!m.isIdentity() && "unexpected identity map"); 600 } 601 602 // Canonical case for empty map. 603 if (!m) { 604 // 0-D corner case, offset is already 0. 605 if (t.getRank() == 0) 606 return success(); 607 auto stridedExpr = 608 makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext()); 609 if (succeeded(extractStrides(stridedExpr, one, strides, offset))) 610 return success(); 611 assert(false && "unexpected failure: extract strides in canonical layout"); 612 } 613 614 // Non-canonical case requires more work. 615 auto stridedExpr = 616 simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols()); 617 if (failed(extractStrides(stridedExpr, one, strides, offset))) { 618 offset = AffineExpr(); 619 strides.clear(); 620 return failure(); 621 } 622 623 // Simplify results to allow folding to constants and simple checks. 624 unsigned numDims = m.getNumDims(); 625 unsigned numSymbols = m.getNumSymbols(); 626 offset = simplifyAffineExpr(offset, numDims, numSymbols); 627 for (auto &stride : strides) 628 stride = simplifyAffineExpr(stride, numDims, numSymbols); 629 630 /// In practice, a strided memref must be internally non-aliasing. Test 631 /// against 0 as a proxy. 632 /// TODO: static cases can have more advanced checks. 633 /// TODO: dynamic cases would require a way to compare symbolic 634 /// expressions and would probably need an affine set context propagated 635 /// everywhere. 636 if (llvm::any_of(strides, [](AffineExpr e) { 637 return e == getAffineConstantExpr(0, e.getContext()); 638 })) { 639 offset = AffineExpr(); 640 strides.clear(); 641 return failure(); 642 } 643 644 return success(); 645 } 646 647 LogicalResult mlir::getStridesAndOffset(MemRefType t, 648 SmallVectorImpl<int64_t> &strides, 649 int64_t &offset) { 650 AffineExpr offsetExpr; 651 SmallVector<AffineExpr, 4> strideExprs; 652 if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr))) 653 return failure(); 654 if (auto cst = offsetExpr.dyn_cast<AffineConstantExpr>()) 655 offset = cst.getValue(); 656 else 657 offset = ShapedType::kDynamicStrideOrOffset; 658 for (auto e : strideExprs) { 659 if (auto c = e.dyn_cast<AffineConstantExpr>()) 660 strides.push_back(c.getValue()); 661 else 662 strides.push_back(ShapedType::kDynamicStrideOrOffset); 663 } 664 return success(); 665 } 666 667 //===----------------------------------------------------------------------===// 668 /// TupleType 669 //===----------------------------------------------------------------------===// 670 671 /// Return the elements types for this tuple. 672 ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); } 673 674 /// Accumulate the types contained in this tuple and tuples nested within it. 675 /// Note that this only flattens nested tuples, not any other container type, 676 /// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to 677 /// (i32, tensor<i32>, f32, i64) 678 void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) { 679 for (Type type : getTypes()) { 680 if (auto nestedTuple = type.dyn_cast<TupleType>()) 681 nestedTuple.getFlattenedTypes(types); 682 else 683 types.push_back(type); 684 } 685 } 686 687 /// Return the number of element types. 688 size_t TupleType::size() const { return getImpl()->size(); } 689 690 //===----------------------------------------------------------------------===// 691 // Type Utilities 692 //===----------------------------------------------------------------------===// 693 694 AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides, 695 int64_t offset, 696 MLIRContext *context) { 697 AffineExpr expr; 698 unsigned nSymbols = 0; 699 700 // AffineExpr for offset. 701 // Static case. 702 if (offset != MemRefType::getDynamicStrideOrOffset()) { 703 auto cst = getAffineConstantExpr(offset, context); 704 expr = cst; 705 } else { 706 // Dynamic case, new symbol for the offset. 707 auto sym = getAffineSymbolExpr(nSymbols++, context); 708 expr = sym; 709 } 710 711 // AffineExpr for strides. 712 for (auto en : llvm::enumerate(strides)) { 713 auto dim = en.index(); 714 auto stride = en.value(); 715 assert(stride != 0 && "Invalid stride specification"); 716 auto d = getAffineDimExpr(dim, context); 717 AffineExpr mult; 718 // Static case. 719 if (stride != MemRefType::getDynamicStrideOrOffset()) 720 mult = getAffineConstantExpr(stride, context); 721 else 722 // Dynamic case, new symbol for each new stride. 723 mult = getAffineSymbolExpr(nSymbols++, context); 724 expr = expr + d * mult; 725 } 726 727 return AffineMap::get(strides.size(), nSymbols, expr); 728 } 729 730 /// Return a version of `t` with identity layout if it can be determined 731 /// statically that the layout is the canonical contiguous strided layout. 732 /// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of 733 /// `t` with simplified layout. 734 /// If `t` has multiple layout maps or a multi-result layout, just return `t`. 735 MemRefType mlir::canonicalizeStridedLayout(MemRefType t) { 736 auto affineMaps = t.getAffineMaps(); 737 // Already in canonical form. 738 if (affineMaps.empty()) 739 return t; 740 741 // Can't reduce to canonical identity form, return in canonical form. 742 if (affineMaps.size() > 1 || affineMaps[0].getNumResults() > 1) 743 return t; 744 745 // If the canonical strided layout for the sizes of `t` is equal to the 746 // simplified layout of `t` we can just return an empty layout. Otherwise, 747 // just simplify the existing layout. 748 AffineExpr expr = 749 makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext()); 750 auto m = affineMaps[0]; 751 auto simplifiedLayoutExpr = 752 simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols()); 753 if (expr != simplifiedLayoutExpr) 754 return MemRefType::Builder(t).setAffineMaps({AffineMap::get( 755 m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)}); 756 return MemRefType::Builder(t).setAffineMaps({}); 757 } 758 759 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes, 760 ArrayRef<AffineExpr> exprs, 761 MLIRContext *context) { 762 // Size 0 corner case is useful for canonicalizations. 763 if (llvm::is_contained(sizes, 0)) 764 return getAffineConstantExpr(0, context); 765 766 auto maps = AffineMap::inferFromExprList(exprs); 767 assert(!maps.empty() && "Expected one non-empty map"); 768 unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols(); 769 770 AffineExpr expr; 771 bool dynamicPoisonBit = false; 772 int64_t runningSize = 1; 773 for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) { 774 int64_t size = std::get<1>(en); 775 // Degenerate case, no size =-> no stride 776 if (size == 0) 777 continue; 778 AffineExpr dimExpr = std::get<0>(en); 779 AffineExpr stride = dynamicPoisonBit 780 ? getAffineSymbolExpr(nSymbols++, context) 781 : getAffineConstantExpr(runningSize, context); 782 expr = expr ? expr + dimExpr * stride : dimExpr * stride; 783 if (size > 0) 784 runningSize *= size; 785 else 786 dynamicPoisonBit = true; 787 } 788 return simplifyAffineExpr(expr, numDims, nSymbols); 789 } 790 791 /// Return a version of `t` with a layout that has all dynamic offset and 792 /// strides. This is used to erase the static layout. 793 MemRefType mlir::eraseStridedLayout(MemRefType t) { 794 auto val = ShapedType::kDynamicStrideOrOffset; 795 return MemRefType::Builder(t).setAffineMaps(makeStridedLinearLayoutMap( 796 SmallVector<int64_t, 4>(t.getRank(), val), val, t.getContext())); 797 } 798 799 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes, 800 MLIRContext *context) { 801 SmallVector<AffineExpr, 4> exprs; 802 exprs.reserve(sizes.size()); 803 for (auto dim : llvm::seq<unsigned>(0, sizes.size())) 804 exprs.push_back(getAffineDimExpr(dim, context)); 805 return makeCanonicalStridedLayoutExpr(sizes, exprs, context); 806 } 807 808 /// Return true if the layout for `t` is compatible with strided semantics. 809 bool mlir::isStrided(MemRefType t) { 810 int64_t offset; 811 SmallVector<int64_t, 4> stridesAndOffset; 812 auto res = getStridesAndOffset(t, stridesAndOffset, offset); 813 return succeeded(res); 814 } 815