1 //===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/BuiltinTypes.h" 10 #include "TypeDetail.h" 11 #include "mlir/IR/AffineExpr.h" 12 #include "mlir/IR/AffineMap.h" 13 #include "mlir/IR/BuiltinAttributes.h" 14 #include "mlir/IR/BuiltinDialect.h" 15 #include "mlir/IR/Diagnostics.h" 16 #include "mlir/IR/Dialect.h" 17 #include "mlir/IR/TensorEncoding.h" 18 #include "llvm/ADT/APFloat.h" 19 #include "llvm/ADT/BitVector.h" 20 #include "llvm/ADT/Sequence.h" 21 #include "llvm/ADT/Twine.h" 22 #include "llvm/ADT/TypeSwitch.h" 23 24 using namespace mlir; 25 using namespace mlir::detail; 26 27 //===----------------------------------------------------------------------===// 28 /// Tablegen Type Definitions 29 //===----------------------------------------------------------------------===// 30 31 #define GET_TYPEDEF_CLASSES 32 #include "mlir/IR/BuiltinTypes.cpp.inc" 33 34 //===----------------------------------------------------------------------===// 35 /// Tablegen Interface Definitions 36 //===----------------------------------------------------------------------===// 37 38 #include "mlir/IR/BuiltinTypeInterfaces.cpp.inc" 39 40 //===----------------------------------------------------------------------===// 41 // BuiltinDialect 42 //===----------------------------------------------------------------------===// 43 44 void BuiltinDialect::registerTypes() { 45 addTypes< 46 #define GET_TYPEDEF_LIST 47 #include "mlir/IR/BuiltinTypes.cpp.inc" 48 >(); 49 } 50 51 //===----------------------------------------------------------------------===// 52 /// ComplexType 53 //===----------------------------------------------------------------------===// 54 55 /// Verify the construction of an integer type. 56 LogicalResult ComplexType::verify(function_ref<InFlightDiagnostic()> emitError, 57 Type elementType) { 58 if (!elementType.isIntOrFloat()) 59 return emitError() << "invalid element type for complex"; 60 return success(); 61 } 62 63 //===----------------------------------------------------------------------===// 64 // Integer Type 65 //===----------------------------------------------------------------------===// 66 67 // static constexpr must have a definition (until in C++17 and inline variable). 68 constexpr unsigned IntegerType::kMaxWidth; 69 70 /// Verify the construction of an integer type. 71 LogicalResult IntegerType::verify(function_ref<InFlightDiagnostic()> emitError, 72 unsigned width, 73 SignednessSemantics signedness) { 74 if (width > IntegerType::kMaxWidth) { 75 return emitError() << "integer bitwidth is limited to " 76 << IntegerType::kMaxWidth << " bits"; 77 } 78 return success(); 79 } 80 81 unsigned IntegerType::getWidth() const { return getImpl()->width; } 82 83 IntegerType::SignednessSemantics IntegerType::getSignedness() const { 84 return getImpl()->signedness; 85 } 86 87 IntegerType IntegerType::scaleElementBitwidth(unsigned scale) { 88 if (!scale) 89 return IntegerType(); 90 return IntegerType::get(getContext(), scale * getWidth(), getSignedness()); 91 } 92 93 //===----------------------------------------------------------------------===// 94 // Float Type 95 //===----------------------------------------------------------------------===// 96 97 unsigned FloatType::getWidth() { 98 if (isa<Float16Type, BFloat16Type>()) 99 return 16; 100 if (isa<Float32Type>()) 101 return 32; 102 if (isa<Float64Type>()) 103 return 64; 104 if (isa<Float80Type>()) 105 return 80; 106 if (isa<Float128Type>()) 107 return 128; 108 llvm_unreachable("unexpected float type"); 109 } 110 111 /// Returns the floating semantics for the given type. 112 const llvm::fltSemantics &FloatType::getFloatSemantics() { 113 if (isa<BFloat16Type>()) 114 return APFloat::BFloat(); 115 if (isa<Float16Type>()) 116 return APFloat::IEEEhalf(); 117 if (isa<Float32Type>()) 118 return APFloat::IEEEsingle(); 119 if (isa<Float64Type>()) 120 return APFloat::IEEEdouble(); 121 if (isa<Float80Type>()) 122 return APFloat::x87DoubleExtended(); 123 if (isa<Float128Type>()) 124 return APFloat::IEEEquad(); 125 llvm_unreachable("non-floating point type used"); 126 } 127 128 FloatType FloatType::scaleElementBitwidth(unsigned scale) { 129 if (!scale) 130 return FloatType(); 131 MLIRContext *ctx = getContext(); 132 if (isF16() || isBF16()) { 133 if (scale == 2) 134 return FloatType::getF32(ctx); 135 if (scale == 4) 136 return FloatType::getF64(ctx); 137 } 138 if (isF32()) 139 if (scale == 2) 140 return FloatType::getF64(ctx); 141 return FloatType(); 142 } 143 144 //===----------------------------------------------------------------------===// 145 // FunctionType 146 //===----------------------------------------------------------------------===// 147 148 unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; } 149 150 ArrayRef<Type> FunctionType::getInputs() const { 151 return getImpl()->getInputs(); 152 } 153 154 unsigned FunctionType::getNumResults() const { return getImpl()->numResults; } 155 156 ArrayRef<Type> FunctionType::getResults() const { 157 return getImpl()->getResults(); 158 } 159 160 /// Helper to call a callback once on each index in the range 161 /// [0, `totalIndices`), *except* for the indices given in `indices`. 162 /// `indices` is allowed to have duplicates and can be in any order. 163 inline void iterateIndicesExcept(unsigned totalIndices, 164 ArrayRef<unsigned> indices, 165 function_ref<void(unsigned)> callback) { 166 llvm::BitVector skipIndices(totalIndices); 167 for (unsigned i : indices) 168 skipIndices.set(i); 169 170 for (unsigned i = 0; i < totalIndices; ++i) 171 if (!skipIndices.test(i)) 172 callback(i); 173 } 174 175 /// Returns a new function type without the specified arguments and results. 176 FunctionType 177 FunctionType::getWithoutArgsAndResults(ArrayRef<unsigned> argIndices, 178 ArrayRef<unsigned> resultIndices) { 179 ArrayRef<Type> newInputTypes = getInputs(); 180 SmallVector<Type, 4> newInputTypesBuffer; 181 if (!argIndices.empty()) { 182 unsigned originalNumArgs = getNumInputs(); 183 iterateIndicesExcept(originalNumArgs, argIndices, [&](unsigned i) { 184 newInputTypesBuffer.emplace_back(getInput(i)); 185 }); 186 newInputTypes = newInputTypesBuffer; 187 } 188 189 ArrayRef<Type> newResultTypes = getResults(); 190 SmallVector<Type, 4> newResultTypesBuffer; 191 if (!resultIndices.empty()) { 192 unsigned originalNumResults = getNumResults(); 193 iterateIndicesExcept(originalNumResults, resultIndices, [&](unsigned i) { 194 newResultTypesBuffer.emplace_back(getResult(i)); 195 }); 196 newResultTypes = newResultTypesBuffer; 197 } 198 199 return get(getContext(), newInputTypes, newResultTypes); 200 } 201 202 void FunctionType::walkImmediateSubElements( 203 function_ref<void(Attribute)> walkAttrsFn, 204 function_ref<void(Type)> walkTypesFn) const { 205 for (Type type : llvm::concat<const Type>(getInputs(), getResults())) 206 walkTypesFn(type); 207 } 208 209 //===----------------------------------------------------------------------===// 210 // OpaqueType 211 //===----------------------------------------------------------------------===// 212 213 /// Verify the construction of an opaque type. 214 LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError, 215 Identifier dialect, StringRef typeData) { 216 if (!Dialect::isValidNamespace(dialect.strref())) 217 return emitError() << "invalid dialect namespace '" << dialect << "'"; 218 219 // Check that the dialect is actually registered. 220 MLIRContext *context = dialect.getContext(); 221 if (!context->allowsUnregisteredDialects() && 222 !context->getLoadedDialect(dialect.strref())) { 223 return emitError() 224 << "`!" << dialect << "<\"" << typeData << "\">" 225 << "` type created with unregistered dialect. If this is " 226 "intended, please call allowUnregisteredDialects() on the " 227 "MLIRContext, or use -allow-unregistered-dialect with " 228 "mlir-opt"; 229 } 230 231 return success(); 232 } 233 234 //===----------------------------------------------------------------------===// 235 // ShapedType 236 //===----------------------------------------------------------------------===// 237 constexpr int64_t ShapedType::kDynamicSize; 238 constexpr int64_t ShapedType::kDynamicStrideOrOffset; 239 240 ShapedType ShapedType::clone(ArrayRef<int64_t> shape, Type elementType) { 241 if (auto other = dyn_cast<MemRefType>()) { 242 MemRefType::Builder b(other); 243 b.setShape(shape); 244 b.setElementType(elementType); 245 return b; 246 } 247 248 if (auto other = dyn_cast<UnrankedMemRefType>()) { 249 MemRefType::Builder b(shape, elementType); 250 b.setMemorySpace(other.getMemorySpace()); 251 return b; 252 } 253 254 if (isa<TensorType>()) 255 return RankedTensorType::get(shape, elementType); 256 257 if (isa<VectorType>()) 258 return VectorType::get(shape, elementType); 259 260 llvm_unreachable("Unhandled ShapedType clone case"); 261 } 262 263 ShapedType ShapedType::clone(ArrayRef<int64_t> shape) { 264 if (auto other = dyn_cast<MemRefType>()) { 265 MemRefType::Builder b(other); 266 b.setShape(shape); 267 return b; 268 } 269 270 if (auto other = dyn_cast<UnrankedMemRefType>()) { 271 MemRefType::Builder b(shape, other.getElementType()); 272 b.setShape(shape); 273 b.setMemorySpace(other.getMemorySpace()); 274 return b; 275 } 276 277 if (isa<TensorType>()) 278 return RankedTensorType::get(shape, getElementType()); 279 280 if (isa<VectorType>()) 281 return VectorType::get(shape, getElementType()); 282 283 llvm_unreachable("Unhandled ShapedType clone case"); 284 } 285 286 ShapedType ShapedType::clone(Type elementType) { 287 if (auto other = dyn_cast<MemRefType>()) { 288 MemRefType::Builder b(other); 289 b.setElementType(elementType); 290 return b; 291 } 292 293 if (auto other = dyn_cast<UnrankedMemRefType>()) { 294 return UnrankedMemRefType::get(elementType, other.getMemorySpace()); 295 } 296 297 if (isa<TensorType>()) { 298 if (hasRank()) 299 return RankedTensorType::get(getShape(), elementType); 300 return UnrankedTensorType::get(elementType); 301 } 302 303 if (isa<VectorType>()) 304 return VectorType::get(getShape(), elementType); 305 306 llvm_unreachable("Unhandled ShapedType clone hit"); 307 } 308 309 Type ShapedType::getElementType() const { 310 return TypeSwitch<Type, Type>(*this) 311 .Case<VectorType, RankedTensorType, UnrankedTensorType, MemRefType, 312 UnrankedMemRefType>([](auto ty) { return ty.getElementType(); }); 313 } 314 315 unsigned ShapedType::getElementTypeBitWidth() const { 316 return getElementType().getIntOrFloatBitWidth(); 317 } 318 319 int64_t ShapedType::getNumElements() const { 320 assert(hasStaticShape() && "cannot get element count of dynamic shaped type"); 321 auto shape = getShape(); 322 int64_t num = 1; 323 for (auto dim : shape) { 324 num *= dim; 325 assert(num >= 0 && "integer overflow in element count computation"); 326 } 327 return num; 328 } 329 330 int64_t ShapedType::getRank() const { 331 assert(hasRank() && "cannot query rank of unranked shaped type"); 332 return getShape().size(); 333 } 334 335 bool ShapedType::hasRank() const { 336 return !isa<UnrankedMemRefType, UnrankedTensorType>(); 337 } 338 339 int64_t ShapedType::getDimSize(unsigned idx) const { 340 assert(idx < getRank() && "invalid index for shaped type"); 341 return getShape()[idx]; 342 } 343 344 bool ShapedType::isDynamicDim(unsigned idx) const { 345 assert(idx < getRank() && "invalid index for shaped type"); 346 return isDynamic(getShape()[idx]); 347 } 348 349 unsigned ShapedType::getDynamicDimIndex(unsigned index) const { 350 assert(index < getRank() && "invalid index"); 351 assert(ShapedType::isDynamic(getDimSize(index)) && "invalid index"); 352 return llvm::count_if(getShape().take_front(index), ShapedType::isDynamic); 353 } 354 355 /// Get the number of bits require to store a value of the given shaped type. 356 /// Compute the value recursively since tensors are allowed to have vectors as 357 /// elements. 358 int64_t ShapedType::getSizeInBits() const { 359 assert(hasStaticShape() && 360 "cannot get the bit size of an aggregate with a dynamic shape"); 361 362 auto elementType = getElementType(); 363 if (elementType.isIntOrFloat()) 364 return elementType.getIntOrFloatBitWidth() * getNumElements(); 365 366 if (auto complexType = elementType.dyn_cast<ComplexType>()) { 367 elementType = complexType.getElementType(); 368 return elementType.getIntOrFloatBitWidth() * getNumElements() * 2; 369 } 370 371 // Tensors can have vectors and other tensors as elements, other shaped types 372 // cannot. 373 assert(isa<TensorType>() && "unsupported element type"); 374 assert((elementType.isa<VectorType, TensorType>()) && 375 "unsupported tensor element type"); 376 return getNumElements() * elementType.cast<ShapedType>().getSizeInBits(); 377 } 378 379 ArrayRef<int64_t> ShapedType::getShape() const { 380 if (auto vectorType = dyn_cast<VectorType>()) 381 return vectorType.getShape(); 382 if (auto tensorType = dyn_cast<RankedTensorType>()) 383 return tensorType.getShape(); 384 return cast<MemRefType>().getShape(); 385 } 386 387 int64_t ShapedType::getNumDynamicDims() const { 388 return llvm::count_if(getShape(), isDynamic); 389 } 390 391 bool ShapedType::hasStaticShape() const { 392 return hasRank() && llvm::none_of(getShape(), isDynamic); 393 } 394 395 bool ShapedType::hasStaticShape(ArrayRef<int64_t> shape) const { 396 return hasStaticShape() && getShape() == shape; 397 } 398 399 //===----------------------------------------------------------------------===// 400 // VectorType 401 //===----------------------------------------------------------------------===// 402 403 LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError, 404 ArrayRef<int64_t> shape, Type elementType) { 405 if (shape.empty()) 406 return emitError() << "vector types must have at least one dimension"; 407 408 if (!isValidElementType(elementType)) 409 return emitError() << "vector elements must be int/index/float type"; 410 411 if (any_of(shape, [](int64_t i) { return i <= 0; })) 412 return emitError() << "vector types must have positive constant sizes"; 413 414 return success(); 415 } 416 417 VectorType VectorType::scaleElementBitwidth(unsigned scale) { 418 if (!scale) 419 return VectorType(); 420 if (auto et = getElementType().dyn_cast<IntegerType>()) 421 if (auto scaledEt = et.scaleElementBitwidth(scale)) 422 return VectorType::get(getShape(), scaledEt); 423 if (auto et = getElementType().dyn_cast<FloatType>()) 424 if (auto scaledEt = et.scaleElementBitwidth(scale)) 425 return VectorType::get(getShape(), scaledEt); 426 return VectorType(); 427 } 428 429 void VectorType::walkImmediateSubElements( 430 function_ref<void(Attribute)> walkAttrsFn, 431 function_ref<void(Type)> walkTypesFn) const { 432 walkTypesFn(getElementType()); 433 } 434 435 //===----------------------------------------------------------------------===// 436 // TensorType 437 //===----------------------------------------------------------------------===// 438 439 // Check if "elementType" can be an element type of a tensor. 440 static LogicalResult 441 checkTensorElementType(function_ref<InFlightDiagnostic()> emitError, 442 Type elementType) { 443 if (!TensorType::isValidElementType(elementType)) 444 return emitError() << "invalid tensor element type: " << elementType; 445 return success(); 446 } 447 448 /// Return true if the specified element type is ok in a tensor. 449 bool TensorType::isValidElementType(Type type) { 450 // Note: Non standard/builtin types are allowed to exist within tensor 451 // types. Dialects are expected to verify that tensor types have a valid 452 // element type within that dialect. 453 return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType, 454 IndexType>() || 455 !type.getDialect().getNamespace().empty(); 456 } 457 458 //===----------------------------------------------------------------------===// 459 // RankedTensorType 460 //===----------------------------------------------------------------------===// 461 462 LogicalResult 463 RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError, 464 ArrayRef<int64_t> shape, Type elementType, 465 Attribute encoding) { 466 for (int64_t s : shape) 467 if (s < -1) 468 return emitError() << "invalid tensor dimension size"; 469 if (auto v = encoding.dyn_cast_or_null<VerifiableTensorEncoding>()) 470 if (failed(v.verifyEncoding(shape, elementType, emitError))) 471 return failure(); 472 return checkTensorElementType(emitError, elementType); 473 } 474 475 void RankedTensorType::walkImmediateSubElements( 476 function_ref<void(Attribute)> walkAttrsFn, 477 function_ref<void(Type)> walkTypesFn) const { 478 walkTypesFn(getElementType()); 479 } 480 481 //===----------------------------------------------------------------------===// 482 // UnrankedTensorType 483 //===----------------------------------------------------------------------===// 484 485 LogicalResult 486 UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError, 487 Type elementType) { 488 return checkTensorElementType(emitError, elementType); 489 } 490 491 void UnrankedTensorType::walkImmediateSubElements( 492 function_ref<void(Attribute)> walkAttrsFn, 493 function_ref<void(Type)> walkTypesFn) const { 494 walkTypesFn(getElementType()); 495 } 496 497 //===----------------------------------------------------------------------===// 498 // BaseMemRefType 499 //===----------------------------------------------------------------------===// 500 501 Attribute BaseMemRefType::getMemorySpace() const { 502 if (auto rankedMemRefTy = dyn_cast<MemRefType>()) 503 return rankedMemRefTy.getMemorySpace(); 504 return cast<UnrankedMemRefType>().getMemorySpace(); 505 } 506 507 unsigned BaseMemRefType::getMemorySpaceAsInt() const { 508 if (auto rankedMemRefTy = dyn_cast<MemRefType>()) 509 return rankedMemRefTy.getMemorySpaceAsInt(); 510 return cast<UnrankedMemRefType>().getMemorySpaceAsInt(); 511 } 512 513 //===----------------------------------------------------------------------===// 514 // MemRefType 515 //===----------------------------------------------------------------------===// 516 517 /// Given an `originalShape` and a `reducedShape` assumed to be a subset of 518 /// `originalShape` with some `1` entries erased, return the set of indices 519 /// that specifies which of the entries of `originalShape` are dropped to obtain 520 /// `reducedShape`. The returned mask can be applied as a projection to 521 /// `originalShape` to obtain the `reducedShape`. This mask is useful to track 522 /// which dimensions must be kept when e.g. compute MemRef strides under 523 /// rank-reducing operations. Return None if reducedShape cannot be obtained 524 /// by dropping only `1` entries in `originalShape`. 525 llvm::Optional<llvm::SmallDenseSet<unsigned>> 526 mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape, 527 ArrayRef<int64_t> reducedShape) { 528 size_t originalRank = originalShape.size(), reducedRank = reducedShape.size(); 529 llvm::SmallDenseSet<unsigned> unusedDims; 530 unsigned reducedIdx = 0; 531 for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) { 532 // Greedily insert `originalIdx` if no match. 533 if (reducedIdx < reducedRank && 534 originalShape[originalIdx] == reducedShape[reducedIdx]) { 535 reducedIdx++; 536 continue; 537 } 538 539 unusedDims.insert(originalIdx); 540 // If no match on `originalIdx`, the `originalShape` at this dimension 541 // must be 1, otherwise we bail. 542 if (originalShape[originalIdx] != 1) 543 return llvm::None; 544 } 545 // The whole reducedShape must be scanned, otherwise we bail. 546 if (reducedIdx != reducedRank) 547 return llvm::None; 548 return unusedDims; 549 } 550 551 bool mlir::detail::isSupportedMemorySpace(Attribute memorySpace) { 552 // Empty attribute is allowed as default memory space. 553 if (!memorySpace) 554 return true; 555 556 // Supported built-in attributes. 557 if (memorySpace.isa<IntegerAttr, StringAttr, DictionaryAttr>()) 558 return true; 559 560 // Allow custom dialect attributes. 561 if (!::mlir::isa<BuiltinDialect>(memorySpace.getDialect())) 562 return true; 563 564 return false; 565 } 566 567 Attribute mlir::detail::wrapIntegerMemorySpace(unsigned memorySpace, 568 MLIRContext *ctx) { 569 if (memorySpace == 0) 570 return nullptr; 571 572 return IntegerAttr::get(IntegerType::get(ctx, 64), memorySpace); 573 } 574 575 Attribute mlir::detail::skipDefaultMemorySpace(Attribute memorySpace) { 576 IntegerAttr intMemorySpace = memorySpace.dyn_cast_or_null<IntegerAttr>(); 577 if (intMemorySpace && intMemorySpace.getValue() == 0) 578 return nullptr; 579 580 return memorySpace; 581 } 582 583 unsigned mlir::detail::getMemorySpaceAsInt(Attribute memorySpace) { 584 if (!memorySpace) 585 return 0; 586 587 assert(memorySpace.isa<IntegerAttr>() && 588 "Using `getMemorySpaceInteger` with non-Integer attribute"); 589 590 return static_cast<unsigned>(memorySpace.cast<IntegerAttr>().getInt()); 591 } 592 593 MemRefType::Builder & 594 MemRefType::Builder::setMemorySpace(unsigned newMemorySpace) { 595 memorySpace = 596 wrapIntegerMemorySpace(newMemorySpace, elementType.getContext()); 597 return *this; 598 } 599 600 unsigned MemRefType::getMemorySpaceAsInt() const { 601 return detail::getMemorySpaceAsInt(getMemorySpace()); 602 } 603 604 LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError, 605 ArrayRef<int64_t> shape, Type elementType, 606 ArrayRef<AffineMap> affineMapComposition, 607 Attribute memorySpace) { 608 if (!BaseMemRefType::isValidElementType(elementType)) 609 return emitError() << "invalid memref element type"; 610 611 // Negative sizes are not allowed except for `-1` that means dynamic size. 612 for (int64_t s : shape) 613 if (s < -1) 614 return emitError() << "invalid memref size"; 615 616 // Check that the structure of the composition is valid, i.e. that each 617 // subsequent affine map has as many inputs as the previous map has results. 618 // Take the dimensionality of the MemRef for the first map. 619 size_t dim = shape.size(); 620 for (auto it : llvm::enumerate(affineMapComposition)) { 621 AffineMap map = it.value(); 622 if (map.getNumDims() == dim) { 623 dim = map.getNumResults(); 624 continue; 625 } 626 return emitError() << "memref affine map dimension mismatch between " 627 << (it.index() == 0 ? Twine("memref rank") 628 : "affine map " + Twine(it.index())) 629 << " and affine map" << it.index() + 1 << ": " << dim 630 << " != " << map.getNumDims(); 631 } 632 633 if (!isSupportedMemorySpace(memorySpace)) { 634 return emitError() << "unsupported memory space Attribute"; 635 } 636 637 return success(); 638 } 639 640 void MemRefType::walkImmediateSubElements( 641 function_ref<void(Attribute)> walkAttrsFn, 642 function_ref<void(Type)> walkTypesFn) const { 643 walkTypesFn(getElementType()); 644 walkAttrsFn(getMemorySpace()); 645 for (AffineMap map : getAffineMaps()) 646 walkAttrsFn(AffineMapAttr::get(map)); 647 } 648 649 //===----------------------------------------------------------------------===// 650 // UnrankedMemRefType 651 //===----------------------------------------------------------------------===// 652 653 unsigned UnrankedMemRefType::getMemorySpaceAsInt() const { 654 return detail::getMemorySpaceAsInt(getMemorySpace()); 655 } 656 657 LogicalResult 658 UnrankedMemRefType::verify(function_ref<InFlightDiagnostic()> emitError, 659 Type elementType, Attribute memorySpace) { 660 if (!BaseMemRefType::isValidElementType(elementType)) 661 return emitError() << "invalid memref element type"; 662 663 if (!isSupportedMemorySpace(memorySpace)) 664 return emitError() << "unsupported memory space Attribute"; 665 666 return success(); 667 } 668 669 // Fallback cases for terminal dim/sym/cst that are not part of a binary op ( 670 // i.e. single term). Accumulate the AffineExpr into the existing one. 671 static void extractStridesFromTerm(AffineExpr e, 672 AffineExpr multiplicativeFactor, 673 MutableArrayRef<AffineExpr> strides, 674 AffineExpr &offset) { 675 if (auto dim = e.dyn_cast<AffineDimExpr>()) 676 strides[dim.getPosition()] = 677 strides[dim.getPosition()] + multiplicativeFactor; 678 else 679 offset = offset + e * multiplicativeFactor; 680 } 681 682 /// Takes a single AffineExpr `e` and populates the `strides` array with the 683 /// strides expressions for each dim position. 684 /// The convention is that the strides for dimensions d0, .. dn appear in 685 /// order to make indexing intuitive into the result. 686 static LogicalResult extractStrides(AffineExpr e, 687 AffineExpr multiplicativeFactor, 688 MutableArrayRef<AffineExpr> strides, 689 AffineExpr &offset) { 690 auto bin = e.dyn_cast<AffineBinaryOpExpr>(); 691 if (!bin) { 692 extractStridesFromTerm(e, multiplicativeFactor, strides, offset); 693 return success(); 694 } 695 696 if (bin.getKind() == AffineExprKind::CeilDiv || 697 bin.getKind() == AffineExprKind::FloorDiv || 698 bin.getKind() == AffineExprKind::Mod) 699 return failure(); 700 701 if (bin.getKind() == AffineExprKind::Mul) { 702 auto dim = bin.getLHS().dyn_cast<AffineDimExpr>(); 703 if (dim) { 704 strides[dim.getPosition()] = 705 strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor; 706 return success(); 707 } 708 // LHS and RHS may both contain complex expressions of dims. Try one path 709 // and if it fails try the other. This is guaranteed to succeed because 710 // only one path may have a `dim`, otherwise this is not an AffineExpr in 711 // the first place. 712 if (bin.getLHS().isSymbolicOrConstant()) 713 return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(), 714 strides, offset); 715 return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(), 716 strides, offset); 717 } 718 719 if (bin.getKind() == AffineExprKind::Add) { 720 auto res1 = 721 extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset); 722 auto res2 = 723 extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset); 724 return success(succeeded(res1) && succeeded(res2)); 725 } 726 727 llvm_unreachable("unexpected binary operation"); 728 } 729 730 LogicalResult mlir::getStridesAndOffset(MemRefType t, 731 SmallVectorImpl<AffineExpr> &strides, 732 AffineExpr &offset) { 733 auto affineMaps = t.getAffineMaps(); 734 735 if (!affineMaps.empty() && affineMaps.back().getNumResults() != 1) 736 return failure(); 737 738 AffineMap m; 739 if (!affineMaps.empty()) { 740 m = affineMaps.back(); 741 for (size_t i = affineMaps.size() - 1; i > 0; --i) 742 m = m.compose(affineMaps[i - 1]); 743 assert(!m.isIdentity() && "unexpected identity map"); 744 } 745 746 auto zero = getAffineConstantExpr(0, t.getContext()); 747 auto one = getAffineConstantExpr(1, t.getContext()); 748 offset = zero; 749 strides.assign(t.getRank(), zero); 750 751 // Canonical case for empty map. 752 if (!m) { 753 // 0-D corner case, offset is already 0. 754 if (t.getRank() == 0) 755 return success(); 756 auto stridedExpr = 757 makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext()); 758 if (succeeded(extractStrides(stridedExpr, one, strides, offset))) 759 return success(); 760 assert(false && "unexpected failure: extract strides in canonical layout"); 761 } 762 763 // Non-canonical case requires more work. 764 auto stridedExpr = 765 simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols()); 766 if (failed(extractStrides(stridedExpr, one, strides, offset))) { 767 offset = AffineExpr(); 768 strides.clear(); 769 return failure(); 770 } 771 772 // Simplify results to allow folding to constants and simple checks. 773 unsigned numDims = m.getNumDims(); 774 unsigned numSymbols = m.getNumSymbols(); 775 offset = simplifyAffineExpr(offset, numDims, numSymbols); 776 for (auto &stride : strides) 777 stride = simplifyAffineExpr(stride, numDims, numSymbols); 778 779 /// In practice, a strided memref must be internally non-aliasing. Test 780 /// against 0 as a proxy. 781 /// TODO: static cases can have more advanced checks. 782 /// TODO: dynamic cases would require a way to compare symbolic 783 /// expressions and would probably need an affine set context propagated 784 /// everywhere. 785 if (llvm::any_of(strides, [](AffineExpr e) { 786 return e == getAffineConstantExpr(0, e.getContext()); 787 })) { 788 offset = AffineExpr(); 789 strides.clear(); 790 return failure(); 791 } 792 793 return success(); 794 } 795 796 LogicalResult mlir::getStridesAndOffset(MemRefType t, 797 SmallVectorImpl<int64_t> &strides, 798 int64_t &offset) { 799 AffineExpr offsetExpr; 800 SmallVector<AffineExpr, 4> strideExprs; 801 if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr))) 802 return failure(); 803 if (auto cst = offsetExpr.dyn_cast<AffineConstantExpr>()) 804 offset = cst.getValue(); 805 else 806 offset = ShapedType::kDynamicStrideOrOffset; 807 for (auto e : strideExprs) { 808 if (auto c = e.dyn_cast<AffineConstantExpr>()) 809 strides.push_back(c.getValue()); 810 else 811 strides.push_back(ShapedType::kDynamicStrideOrOffset); 812 } 813 return success(); 814 } 815 816 void UnrankedMemRefType::walkImmediateSubElements( 817 function_ref<void(Attribute)> walkAttrsFn, 818 function_ref<void(Type)> walkTypesFn) const { 819 walkTypesFn(getElementType()); 820 walkAttrsFn(getMemorySpace()); 821 } 822 823 //===----------------------------------------------------------------------===// 824 /// TupleType 825 //===----------------------------------------------------------------------===// 826 827 /// Return the elements types for this tuple. 828 ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); } 829 830 /// Accumulate the types contained in this tuple and tuples nested within it. 831 /// Note that this only flattens nested tuples, not any other container type, 832 /// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to 833 /// (i32, tensor<i32>, f32, i64) 834 void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) { 835 for (Type type : getTypes()) { 836 if (auto nestedTuple = type.dyn_cast<TupleType>()) 837 nestedTuple.getFlattenedTypes(types); 838 else 839 types.push_back(type); 840 } 841 } 842 843 /// Return the number of element types. 844 size_t TupleType::size() const { return getImpl()->size(); } 845 846 void TupleType::walkImmediateSubElements( 847 function_ref<void(Attribute)> walkAttrsFn, 848 function_ref<void(Type)> walkTypesFn) const { 849 for (Type type : getTypes()) 850 walkTypesFn(type); 851 } 852 853 //===----------------------------------------------------------------------===// 854 // Type Utilities 855 //===----------------------------------------------------------------------===// 856 857 AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides, 858 int64_t offset, 859 MLIRContext *context) { 860 AffineExpr expr; 861 unsigned nSymbols = 0; 862 863 // AffineExpr for offset. 864 // Static case. 865 if (offset != MemRefType::getDynamicStrideOrOffset()) { 866 auto cst = getAffineConstantExpr(offset, context); 867 expr = cst; 868 } else { 869 // Dynamic case, new symbol for the offset. 870 auto sym = getAffineSymbolExpr(nSymbols++, context); 871 expr = sym; 872 } 873 874 // AffineExpr for strides. 875 for (auto en : llvm::enumerate(strides)) { 876 auto dim = en.index(); 877 auto stride = en.value(); 878 assert(stride != 0 && "Invalid stride specification"); 879 auto d = getAffineDimExpr(dim, context); 880 AffineExpr mult; 881 // Static case. 882 if (stride != MemRefType::getDynamicStrideOrOffset()) 883 mult = getAffineConstantExpr(stride, context); 884 else 885 // Dynamic case, new symbol for each new stride. 886 mult = getAffineSymbolExpr(nSymbols++, context); 887 expr = expr + d * mult; 888 } 889 890 return AffineMap::get(strides.size(), nSymbols, expr); 891 } 892 893 /// Return a version of `t` with identity layout if it can be determined 894 /// statically that the layout is the canonical contiguous strided layout. 895 /// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of 896 /// `t` with simplified layout. 897 /// If `t` has multiple layout maps or a multi-result layout, just return `t`. 898 MemRefType mlir::canonicalizeStridedLayout(MemRefType t) { 899 auto affineMaps = t.getAffineMaps(); 900 // Already in canonical form. 901 if (affineMaps.empty()) 902 return t; 903 904 // Can't reduce to canonical identity form, return in canonical form. 905 if (affineMaps.size() > 1 || affineMaps[0].getNumResults() > 1) 906 return t; 907 908 // Corner-case for 0-D affine maps. 909 auto m = affineMaps[0]; 910 if (m.getNumDims() == 0 && m.getNumSymbols() == 0) { 911 if (auto cst = m.getResult(0).dyn_cast<AffineConstantExpr>()) 912 if (cst.getValue() == 0) 913 return MemRefType::Builder(t).setAffineMaps({}); 914 return t; 915 } 916 917 // 0-D corner case for empty shape that still have an affine map. Example: 918 // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a 1 element memref whose 919 // offset needs to remain, just return t. 920 if (t.getShape().empty()) 921 return t; 922 923 // If the canonical strided layout for the sizes of `t` is equal to the 924 // simplified layout of `t` we can just return an empty layout. Otherwise, 925 // just simplify the existing layout. 926 AffineExpr expr = 927 makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext()); 928 auto simplifiedLayoutExpr = 929 simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols()); 930 if (expr != simplifiedLayoutExpr) 931 return MemRefType::Builder(t).setAffineMaps({AffineMap::get( 932 m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)}); 933 return MemRefType::Builder(t).setAffineMaps({}); 934 } 935 936 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes, 937 ArrayRef<AffineExpr> exprs, 938 MLIRContext *context) { 939 assert(!sizes.empty() && !exprs.empty() && 940 "expected non-empty sizes and exprs"); 941 942 // Size 0 corner case is useful for canonicalizations. 943 if (llvm::is_contained(sizes, 0)) 944 return getAffineConstantExpr(0, context); 945 946 auto maps = AffineMap::inferFromExprList(exprs); 947 assert(!maps.empty() && "Expected one non-empty map"); 948 unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols(); 949 950 AffineExpr expr; 951 bool dynamicPoisonBit = false; 952 int64_t runningSize = 1; 953 for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) { 954 int64_t size = std::get<1>(en); 955 // Degenerate case, no size =-> no stride 956 if (size == 0) 957 continue; 958 AffineExpr dimExpr = std::get<0>(en); 959 AffineExpr stride = dynamicPoisonBit 960 ? getAffineSymbolExpr(nSymbols++, context) 961 : getAffineConstantExpr(runningSize, context); 962 expr = expr ? expr + dimExpr * stride : dimExpr * stride; 963 if (size > 0) { 964 runningSize *= size; 965 assert(runningSize > 0 && "integer overflow in size computation"); 966 } else { 967 dynamicPoisonBit = true; 968 } 969 } 970 return simplifyAffineExpr(expr, numDims, nSymbols); 971 } 972 973 /// Return a version of `t` with a layout that has all dynamic offset and 974 /// strides. This is used to erase the static layout. 975 MemRefType mlir::eraseStridedLayout(MemRefType t) { 976 auto val = ShapedType::kDynamicStrideOrOffset; 977 return MemRefType::Builder(t).setAffineMaps(makeStridedLinearLayoutMap( 978 SmallVector<int64_t, 4>(t.getRank(), val), val, t.getContext())); 979 } 980 981 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes, 982 MLIRContext *context) { 983 SmallVector<AffineExpr, 4> exprs; 984 exprs.reserve(sizes.size()); 985 for (auto dim : llvm::seq<unsigned>(0, sizes.size())) 986 exprs.push_back(getAffineDimExpr(dim, context)); 987 return makeCanonicalStridedLayoutExpr(sizes, exprs, context); 988 } 989 990 /// Return true if the layout for `t` is compatible with strided semantics. 991 bool mlir::isStrided(MemRefType t) { 992 int64_t offset; 993 SmallVector<int64_t, 4> strides; 994 auto res = getStridesAndOffset(t, strides, offset); 995 return succeeded(res); 996 } 997 998 /// Return the layout map in strided linear layout AffineMap form. 999 /// Return null if the layout is not compatible with a strided layout. 1000 AffineMap mlir::getStridedLinearLayoutMap(MemRefType t) { 1001 int64_t offset; 1002 SmallVector<int64_t, 4> strides; 1003 if (failed(getStridesAndOffset(t, strides, offset))) 1004 return AffineMap(); 1005 return makeStridedLinearLayoutMap(strides, offset, t.getContext()); 1006 } 1007