//===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/BuiltinTypes.h"
#include "TypeDetail.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/TensorEncoding.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/Twine.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::detail;

//===----------------------------------------------------------------------===//
// Tablegen Type Definitions
//===----------------------------------------------------------------------===//

#define GET_TYPEDEF_CLASSES
#include "mlir/IR/BuiltinTypes.cpp.inc"

//===----------------------------------------------------------------------===//
// Tablegen Interface Definitions
//===----------------------------------------------------------------------===//

#include "mlir/IR/BuiltinTypeInterfaces.cpp.inc"

//===----------------------------------------------------------------------===//
// BuiltinDialect
//===----------------------------------------------------------------------===//

void BuiltinDialect::registerTypes() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/IR/BuiltinTypes.cpp.inc"
      >();
}

//===----------------------------------------------------------------------===//
// ComplexType
//===----------------------------------------------------------------------===//

/// Verify the construction of a complex type.
LogicalResult ComplexType::verify(function_ref<InFlightDiagnostic()> emitError,
                                  Type elementType) {
  if (!elementType.isIntOrFloat())
    return emitError() << "invalid element type for complex";
  return success();
}

//===----------------------------------------------------------------------===//
// Integer Type
//===----------------------------------------------------------------------===//

// A static constexpr member needs an out-of-line definition until C++17
// introduced inline variables.
constexpr unsigned IntegerType::kMaxWidth;
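
// Example (illustrative sketch, not part of the upstream file; assumes an
// MLIRContext `ctx` is available):
//   IntegerType i8 = IntegerType::get(ctx, 8);   // signless 'i8'
//   IntegerType si32 = IntegerType::get(ctx, 32, IntegerType::Signed);
//   assert(si32.getWidth() == 32);
//   assert(si32.getSignedness() == IntegerType::Signed);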
/// Verify the construction of an integer type.
LogicalResult IntegerType::verify(function_ref<InFlightDiagnostic()> emitError,
                                  unsigned width,
                                  SignednessSemantics signedness) {
  if (width > IntegerType::kMaxWidth) {
    return emitError() << "integer bitwidth is limited to "
                       << IntegerType::kMaxWidth << " bits";
  }
  return success();
}

unsigned IntegerType::getWidth() const { return getImpl()->width; }

IntegerType::SignednessSemantics IntegerType::getSignedness() const {
  return getImpl()->signedness;
}

IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
  if (!scale)
    return IntegerType();
  return IntegerType::get(getContext(), scale * getWidth(), getSignedness());
}

//===----------------------------------------------------------------------===//
// Float Type
//===----------------------------------------------------------------------===//

unsigned FloatType::getWidth() {
  if (isa<Float16Type, BFloat16Type>())
    return 16;
  if (isa<Float32Type>())
    return 32;
  if (isa<Float64Type>())
    return 64;
  if (isa<Float80Type>())
    return 80;
  if (isa<Float128Type>())
    return 128;
  llvm_unreachable("unexpected float type");
}

/// Returns the floating semantics for the given type.
const llvm::fltSemantics &FloatType::getFloatSemantics() {
  if (isa<BFloat16Type>())
    return APFloat::BFloat();
  if (isa<Float16Type>())
    return APFloat::IEEEhalf();
  if (isa<Float32Type>())
    return APFloat::IEEEsingle();
  if (isa<Float64Type>())
    return APFloat::IEEEdouble();
  if (isa<Float80Type>())
    return APFloat::x87DoubleExtended();
  if (isa<Float128Type>())
    return APFloat::IEEEquad();
  llvm_unreachable("non-floating point type used");
}

FloatType FloatType::scaleElementBitwidth(unsigned scale) {
  if (!scale)
    return FloatType();
  MLIRContext *ctx = getContext();
  if (isF16() || isBF16()) {
    if (scale == 2)
      return FloatType::getF32(ctx);
    if (scale == 4)
      return FloatType::getF64(ctx);
  }
  if (isF32())
    if (scale == 2)
      return FloatType::getF64(ctx);
  return FloatType();
}

//===----------------------------------------------------------------------===//
// FunctionType
//===----------------------------------------------------------------------===//

unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; }

ArrayRef<Type> FunctionType::getInputs() const {
  return getImpl()->getInputs();
}

unsigned FunctionType::getNumResults() const { return getImpl()->numResults; }

ArrayRef<Type> FunctionType::getResults() const {
  return getImpl()->getResults();
}

/// Helper to call a callback once on each index in the range
/// [0, `totalIndices`), *except* for the indices given in `indices`.
/// `indices` is allowed to have duplicates and can be in any order.
inline void iterateIndicesExcept(unsigned totalIndices,
                                 ArrayRef<unsigned> indices,
                                 function_ref<void(unsigned)> callback) {
  llvm::BitVector skipIndices(totalIndices);
  for (unsigned i : indices)
    skipIndices.set(i);

  for (unsigned i = 0; i < totalIndices; ++i)
    if (!skipIndices.test(i))
      callback(i);
}
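
// For example, iterateIndicesExcept(5, /*indices=*/{1, 3}, cb) invokes cb(0),
// cb(2), and cb(4).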
/// Returns a new function type with the specified arguments and results
/// inserted.
FunctionType FunctionType::getWithArgsAndResults(
    ArrayRef<unsigned> argIndices, TypeRange argTypes,
    ArrayRef<unsigned> resultIndices, TypeRange resultTypes) {
  assert(argIndices.size() == argTypes.size());
  assert(resultIndices.size() == resultTypes.size());

  ArrayRef<Type> newInputTypes = getInputs();
  SmallVector<Type, 4> newInputTypesBuffer;
  if (!argIndices.empty()) {
    const auto *fromIt = newInputTypes.begin();
    for (auto it : llvm::zip(argIndices, argTypes)) {
      const auto *toIt = newInputTypes.begin() + std::get<0>(it);
      newInputTypesBuffer.append(fromIt, toIt);
      newInputTypesBuffer.push_back(std::get<1>(it));
      fromIt = toIt;
    }
    newInputTypesBuffer.append(fromIt, newInputTypes.end());
    newInputTypes = newInputTypesBuffer;
  }

  ArrayRef<Type> newResultTypes = getResults();
  SmallVector<Type, 4> newResultTypesBuffer;
  if (!resultIndices.empty()) {
    const auto *fromIt = newResultTypes.begin();
    for (auto it : llvm::zip(resultIndices, resultTypes)) {
      const auto *toIt = newResultTypes.begin() + std::get<0>(it);
      newResultTypesBuffer.append(fromIt, toIt);
      newResultTypesBuffer.push_back(std::get<1>(it));
      fromIt = toIt;
    }
    newResultTypesBuffer.append(fromIt, newResultTypes.end());
    newResultTypes = newResultTypesBuffer;
  }

  return FunctionType::get(getContext(), newInputTypes, newResultTypes);
}

/// Returns a new function type without the specified arguments and results.
FunctionType
FunctionType::getWithoutArgsAndResults(ArrayRef<unsigned> argIndices,
                                       ArrayRef<unsigned> resultIndices) {
  ArrayRef<Type> newInputTypes = getInputs();
  SmallVector<Type, 4> newInputTypesBuffer;
  if (!argIndices.empty()) {
    unsigned originalNumArgs = getNumInputs();
    iterateIndicesExcept(originalNumArgs, argIndices, [&](unsigned i) {
      newInputTypesBuffer.emplace_back(getInput(i));
    });
    newInputTypes = newInputTypesBuffer;
  }

  ArrayRef<Type> newResultTypes = getResults();
  SmallVector<Type, 4> newResultTypesBuffer;
  if (!resultIndices.empty()) {
    unsigned originalNumResults = getNumResults();
    iterateIndicesExcept(originalNumResults, resultIndices, [&](unsigned i) {
      newResultTypesBuffer.emplace_back(getResult(i));
    });
    newResultTypes = newResultTypesBuffer;
  }

  return get(getContext(), newInputTypes, newResultTypes);
}

void FunctionType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  for (Type type : llvm::concat<const Type>(getInputs(), getResults()))
    walkTypesFn(type);
}
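
// Example (illustrative sketch, not part of the upstream file; assumes an
// MLIRContext `ctx` with types `i32` and `f32`):
//   FunctionType fn = FunctionType::get(ctx, {i32}, {f32}); // (i32) -> f32
//   // Insert a new leading argument: (f32, i32) -> f32.
//   FunctionType fn2 = fn.getWithArgsAndResults(/*argIndices=*/{0}, {f32},
//                                               /*resultIndices=*/{}, {});
//   // Drop the remaining i32 argument again: (f32) -> f32.
//   FunctionType fn3 = fn2.getWithoutArgsAndResults(/*argIndices=*/{1}, {});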
//===----------------------------------------------------------------------===//
// OpaqueType
//===----------------------------------------------------------------------===//

/// Verify the construction of an opaque type.
LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError,
                                 StringAttr dialect, StringRef typeData) {
  if (!Dialect::isValidNamespace(dialect.strref()))
    return emitError() << "invalid dialect namespace '" << dialect << "'";

  // Check that the dialect is actually registered.
  MLIRContext *context = dialect.getContext();
  if (!context->allowsUnregisteredDialects() &&
      !context->getLoadedDialect(dialect.strref())) {
    return emitError()
           << "`!" << dialect << "<\"" << typeData << "\">"
           << "` type created with unregistered dialect. If this is "
              "intended, please call allowUnregisteredDialects() on the "
              "MLIRContext, or use -allow-unregistered-dialect with "
              "the MLIR opt tool";
  }

  return success();
}

//===----------------------------------------------------------------------===//
// ShapedType
//===----------------------------------------------------------------------===//
constexpr int64_t ShapedType::kDynamicSize;
constexpr int64_t ShapedType::kDynamicStrideOrOffset;

ShapedType ShapedType::clone(ArrayRef<int64_t> shape, Type elementType) {
  if (auto other = dyn_cast<MemRefType>()) {
    MemRefType::Builder b(other);
    b.setShape(shape);
    b.setElementType(elementType);
    return b;
  }

  if (auto other = dyn_cast<UnrankedMemRefType>()) {
    MemRefType::Builder b(shape, elementType);
    b.setMemorySpace(other.getMemorySpace());
    return b;
  }

  if (isa<TensorType>())
    return RankedTensorType::get(shape, elementType);

  if (isa<VectorType>())
    return VectorType::get(shape, elementType);

  llvm_unreachable("Unhandled ShapedType clone case");
}

ShapedType ShapedType::clone(ArrayRef<int64_t> shape) {
  if (auto other = dyn_cast<MemRefType>()) {
    MemRefType::Builder b(other);
    b.setShape(shape);
    return b;
  }

  if (auto other = dyn_cast<UnrankedMemRefType>()) {
    MemRefType::Builder b(shape, other.getElementType());
    b.setMemorySpace(other.getMemorySpace());
    return b;
  }

  if (isa<TensorType>())
    return RankedTensorType::get(shape, getElementType());

  if (isa<VectorType>())
    return VectorType::get(shape, getElementType());

  llvm_unreachable("Unhandled ShapedType clone case");
}

ShapedType ShapedType::clone(Type elementType) {
  if (auto other = dyn_cast<MemRefType>()) {
    MemRefType::Builder b(other);
    b.setElementType(elementType);
    return b;
  }

  if (auto other = dyn_cast<UnrankedMemRefType>())
    return UnrankedMemRefType::get(elementType, other.getMemorySpace());

  if (isa<TensorType>()) {
    if (hasRank())
      return RankedTensorType::get(getShape(), elementType);
    return UnrankedTensorType::get(elementType);
  }

  if (isa<VectorType>())
    return VectorType::get(getShape(), elementType);

  llvm_unreachable("Unhandled ShapedType clone case");
}

Type ShapedType::getElementType() const {
  return TypeSwitch<Type, Type>(*this)
      .Case<VectorType, RankedTensorType, UnrankedTensorType, MemRefType,
            UnrankedMemRefType>([](auto ty) { return ty.getElementType(); });
}

unsigned ShapedType::getElementTypeBitWidth() const {
  return getElementType().getIntOrFloatBitWidth();
}

int64_t ShapedType::getNumElements() const {
  assert(hasStaticShape() && "cannot get element count of dynamic shaped type");
  auto shape = getShape();
  int64_t num = 1;
  for (auto dim : shape) {
    num *= dim;
    assert(num >= 0 && "integer overflow in element count computation");
  }
  return num;
}

int64_t ShapedType::getRank() const {
  assert(hasRank() && "cannot query rank of unranked shaped type");
  return getShape().size();
}

bool ShapedType::hasRank() const {
  return !isa<UnrankedMemRefType, UnrankedTensorType>();
}

int64_t ShapedType::getDimSize(unsigned idx) const {
  assert(idx < getRank() && "invalid index for shaped type");
  return getShape()[idx];
}
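
// Example (illustrative, not part of the upstream file): for
// `tensor<?x8xf32>`, i.e. RankedTensorType::get({ShapedType::kDynamicSize, 8},
// f32), getRank() is 2, getDimSize(1) is 8, and dimension 0 is dynamic.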
bool ShapedType::isDynamicDim(unsigned idx) const {
  assert(idx < getRank() && "invalid index for shaped type");
  return isDynamic(getShape()[idx]);
}

unsigned ShapedType::getDynamicDimIndex(unsigned index) const {
  assert(index < getRank() && "invalid index");
  assert(ShapedType::isDynamic(getDimSize(index)) && "invalid index");
  return llvm::count_if(getShape().take_front(index), ShapedType::isDynamic);
}

/// Get the number of bits required to store a value of the given shaped type.
/// Compute the value recursively since tensors are allowed to have vectors as
/// elements.
int64_t ShapedType::getSizeInBits() const {
  assert(hasStaticShape() &&
         "cannot get the bit size of an aggregate with a dynamic shape");

  auto elementType = getElementType();
  if (elementType.isIntOrFloat())
    return elementType.getIntOrFloatBitWidth() * getNumElements();

  if (auto complexType = elementType.dyn_cast<ComplexType>()) {
    elementType = complexType.getElementType();
    return elementType.getIntOrFloatBitWidth() * getNumElements() * 2;
  }

  // Tensors can have vectors and other tensors as elements, other shaped types
  // cannot.
  assert(isa<TensorType>() && "unsupported element type");
  assert((elementType.isa<VectorType, TensorType>()) &&
         "unsupported tensor element type");
  return getNumElements() * elementType.cast<ShapedType>().getSizeInBits();
}

ArrayRef<int64_t> ShapedType::getShape() const {
  if (auto vectorType = dyn_cast<VectorType>())
    return vectorType.getShape();
  if (auto tensorType = dyn_cast<RankedTensorType>())
    return tensorType.getShape();
  return cast<MemRefType>().getShape();
}

int64_t ShapedType::getNumDynamicDims() const {
  return llvm::count_if(getShape(), isDynamic);
}

bool ShapedType::hasStaticShape() const {
  return hasRank() && llvm::none_of(getShape(), isDynamic);
}

bool ShapedType::hasStaticShape(ArrayRef<int64_t> shape) const {
  return hasStaticShape() && getShape() == shape;
}

//===----------------------------------------------------------------------===//
// VectorType
//===----------------------------------------------------------------------===//

LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
                                 ArrayRef<int64_t> shape, Type elementType) {
  if (!isValidElementType(elementType))
    return emitError()
           << "vector elements must be int/index/float type but got "
           << elementType;

  if (any_of(shape, [](int64_t i) { return i <= 0; }))
    return emitError()
           << "vector types must have positive constant sizes but got "
           << shape;

  return success();
}

VectorType VectorType::scaleElementBitwidth(unsigned scale) {
  if (!scale)
    return VectorType();
  if (auto et = getElementType().dyn_cast<IntegerType>())
    if (auto scaledEt = et.scaleElementBitwidth(scale))
      return VectorType::get(getShape(), scaledEt);
  if (auto et = getElementType().dyn_cast<FloatType>())
    if (auto scaledEt = et.scaleElementBitwidth(scale))
      return VectorType::get(getShape(), scaledEt);
  return VectorType();
}

void VectorType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  walkTypesFn(getElementType());
}
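
// Example (illustrative, not part of the upstream file; assumes an MLIRContext
// `ctx`):
//   auto v = VectorType::get({4}, FloatType::getF16(ctx)); // vector<4xf16>
//   v.scaleElementBitwidth(2);  // yields vector<4xf32>
//   v.scaleElementBitwidth(8);  // yields a null VectorType: f16 only scales
//                               // by 2 or 4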
//===----------------------------------------------------------------------===//
// TensorType
//===----------------------------------------------------------------------===//

// Check if "elementType" can be an element type of a tensor.
static LogicalResult
checkTensorElementType(function_ref<InFlightDiagnostic()> emitError,
                       Type elementType) {
  if (!TensorType::isValidElementType(elementType))
    return emitError() << "invalid tensor element type: " << elementType;
  return success();
}

/// Return true if the specified element type is ok in a tensor.
bool TensorType::isValidElementType(Type type) {
  // Note: Non standard/builtin types are allowed to exist within tensor
  // types. Dialects are expected to verify that tensor types have a valid
  // element type within that dialect.
  return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
                  IndexType>() ||
         !llvm::isa<BuiltinDialect>(type.getDialect());
}

//===----------------------------------------------------------------------===//
// RankedTensorType
//===----------------------------------------------------------------------===//

LogicalResult
RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
                         ArrayRef<int64_t> shape, Type elementType,
                         Attribute encoding) {
  for (int64_t s : shape)
    if (s < -1)
      return emitError() << "invalid tensor dimension size";
  if (auto v = encoding.dyn_cast_or_null<VerifiableTensorEncoding>())
    if (failed(v.verifyEncoding(shape, elementType, emitError)))
      return failure();
  return checkTensorElementType(emitError, elementType);
}

void RankedTensorType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  walkTypesFn(getElementType());
  if (Attribute encoding = getEncoding())
    walkAttrsFn(encoding);
}

//===----------------------------------------------------------------------===//
// UnrankedTensorType
//===----------------------------------------------------------------------===//

LogicalResult
UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
                           Type elementType) {
  return checkTensorElementType(emitError, elementType);
}

void UnrankedTensorType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  walkTypesFn(getElementType());
}

//===----------------------------------------------------------------------===//
// BaseMemRefType
//===----------------------------------------------------------------------===//

Attribute BaseMemRefType::getMemorySpace() const {
  if (auto rankedMemRefTy = dyn_cast<MemRefType>())
    return rankedMemRefTy.getMemorySpace();
  return cast<UnrankedMemRefType>().getMemorySpace();
}

unsigned BaseMemRefType::getMemorySpaceAsInt() const {
  if (auto rankedMemRefTy = dyn_cast<MemRefType>())
    return rankedMemRefTy.getMemorySpaceAsInt();
  return cast<UnrankedMemRefType>().getMemorySpaceAsInt();
}
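
// Example (illustrative, not part of the upstream file): both
// `memref<4xf32, 2>` and `memref<*xf32, 2>` answer getMemorySpaceAsInt() == 2
// through this shared BaseMemRefType interface; the default (empty) memory
// space reads back as 0.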
//===----------------------------------------------------------------------===//
// MemRefType
//===----------------------------------------------------------------------===//

/// Given an `originalShape` and a `reducedShape` assumed to be a subset of
/// `originalShape` with some `1` entries erased, return the set of indices
/// that specifies which of the entries of `originalShape` are dropped to
/// obtain `reducedShape`. The returned mask can be applied as a projection to
/// `originalShape` to obtain the `reducedShape`. This mask is useful to track
/// which dimensions must be kept when e.g. computing MemRef strides under
/// rank-reducing operations. Return None if `reducedShape` cannot be obtained
/// by dropping only `1` entries in `originalShape`.
llvm::Optional<llvm::SmallDenseSet<unsigned>>
mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
                               ArrayRef<int64_t> reducedShape) {
  size_t originalRank = originalShape.size(), reducedRank = reducedShape.size();
  llvm::SmallDenseSet<unsigned> unusedDims;
  unsigned reducedIdx = 0;
  for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
    // Greedily insert `originalIdx` if match.
    if (reducedIdx < reducedRank &&
        originalShape[originalIdx] == reducedShape[reducedIdx]) {
      reducedIdx++;
      continue;
    }

    unusedDims.insert(originalIdx);
    // If no match on `originalIdx`, the `originalShape` at this dimension
    // must be 1, otherwise we bail.
    if (originalShape[originalIdx] != 1)
      return llvm::None;
  }
  // The whole reducedShape must be scanned, otherwise we bail.
  if (reducedIdx != reducedRank)
    return llvm::None;
  return unusedDims;
}

SliceVerificationResult
mlir::isRankReducedType(ShapedType originalType,
                        ShapedType candidateReducedType) {
  if (originalType == candidateReducedType)
    return SliceVerificationResult::Success;

  // Rank and size logic is valid for all ShapedTypes.
  ArrayRef<int64_t> originalShape = originalType.getShape();
  ArrayRef<int64_t> candidateReducedShape = candidateReducedType.getShape();
  unsigned originalRank = originalShape.size(),
           candidateReducedRank = candidateReducedShape.size();
  if (candidateReducedRank > originalRank)
    return SliceVerificationResult::RankTooLarge;

  auto optionalUnusedDimsMask =
      computeRankReductionMask(originalShape, candidateReducedShape);

  // Sizes cannot be matched if no mask is returned.
  if (!optionalUnusedDimsMask.hasValue())
    return SliceVerificationResult::SizeMismatch;

  if (originalType.getElementType() != candidateReducedType.getElementType())
    return SliceVerificationResult::ElemTypeMismatch;

  return SliceVerificationResult::Success;
}
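
// Example (illustrative, not part of the upstream file):
//   computeRankReductionMask({1, 4, 1, 8}, {4, 8}) returns the mask {0, 2};
//   computeRankReductionMask({4, 8}, {8}) returns llvm::None, because the
//   leading size-4 dimension is not a `1` and so cannot be dropped.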
bool mlir::detail::isSupportedMemorySpace(Attribute memorySpace) {
  // Empty attribute is allowed as default memory space.
  if (!memorySpace)
    return true;

  // Supported built-in attributes.
  if (memorySpace.isa<IntegerAttr, StringAttr, DictionaryAttr>())
    return true;

  // Allow custom dialect attributes.
  if (!isa<BuiltinDialect>(memorySpace.getDialect()))
    return true;

  return false;
}

Attribute mlir::detail::wrapIntegerMemorySpace(unsigned memorySpace,
                                               MLIRContext *ctx) {
  if (memorySpace == 0)
    return nullptr;

  return IntegerAttr::get(IntegerType::get(ctx, 64), memorySpace);
}

Attribute mlir::detail::skipDefaultMemorySpace(Attribute memorySpace) {
  IntegerAttr intMemorySpace = memorySpace.dyn_cast_or_null<IntegerAttr>();
  if (intMemorySpace && intMemorySpace.getValue() == 0)
    return nullptr;

  return memorySpace;
}

unsigned mlir::detail::getMemorySpaceAsInt(Attribute memorySpace) {
  if (!memorySpace)
    return 0;

  assert(memorySpace.isa<IntegerAttr>() &&
         "Using `getMemorySpaceAsInt` with non-Integer attribute");

  return static_cast<unsigned>(memorySpace.cast<IntegerAttr>().getInt());
}

MemRefType::Builder &
MemRefType::Builder::setMemorySpace(unsigned newMemorySpace) {
  memorySpace =
      wrapIntegerMemorySpace(newMemorySpace, elementType.getContext());
  return *this;
}

unsigned MemRefType::getMemorySpaceAsInt() const {
  return detail::getMemorySpaceAsInt(getMemorySpace());
}

MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
                           MemRefLayoutAttrInterface layout,
                           Attribute memorySpace) {
  // Use default layout for empty attribute.
  if (!layout)
    layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
        shape.size(), elementType.getContext()));

  // Drop default memory space value and replace it with empty attribute.
  memorySpace = skipDefaultMemorySpace(memorySpace);

  return Base::get(elementType.getContext(), shape, elementType, layout,
                   memorySpace);
}

MemRefType MemRefType::getChecked(
    function_ref<InFlightDiagnostic()> emitErrorFn, ArrayRef<int64_t> shape,
    Type elementType, MemRefLayoutAttrInterface layout, Attribute memorySpace) {

  // Use default layout for empty attribute.
  if (!layout)
    layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
        shape.size(), elementType.getContext()));

  // Drop default memory space value and replace it with empty attribute.
  memorySpace = skipDefaultMemorySpace(memorySpace);

  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
                          elementType, layout, memorySpace);
}

MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
                           AffineMap map, Attribute memorySpace) {

  // Use default layout for empty map.
  if (!map)
    map = AffineMap::getMultiDimIdentityMap(shape.size(),
                                            elementType.getContext());

  // Wrap AffineMap into Attribute.
  Attribute layout = AffineMapAttr::get(map);

  // Drop default memory space value and replace it with empty attribute.
  memorySpace = skipDefaultMemorySpace(memorySpace);

  return Base::get(elementType.getContext(), shape, elementType, layout,
                   memorySpace);
}
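
// Example (illustrative sketch, not part of the upstream file; assumes an
// MLIRContext `ctx` and a Float32Type `f32`):
//   // memref<4x8xf32> with identity layout and default memory space:
//   auto m = MemRefType::get({4, 8}, f32, AffineMap(), Attribute());
//   assert(m.getMemorySpaceAsInt() == 0);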
MemRefType
MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
                       ArrayRef<int64_t> shape, Type elementType, AffineMap map,
                       Attribute memorySpace) {

  // Use default layout for empty map.
  if (!map)
    map = AffineMap::getMultiDimIdentityMap(shape.size(),
                                            elementType.getContext());

  // Wrap AffineMap into Attribute.
  Attribute layout = AffineMapAttr::get(map);

  // Drop default memory space value and replace it with empty attribute.
  memorySpace = skipDefaultMemorySpace(memorySpace);

  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
                          elementType, layout, memorySpace);
}

MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
                           AffineMap map, unsigned memorySpaceInd) {

  // Use default layout for empty map.
  if (!map)
    map = AffineMap::getMultiDimIdentityMap(shape.size(),
                                            elementType.getContext());

  // Wrap AffineMap into Attribute.
  Attribute layout = AffineMapAttr::get(map);

  // Convert deprecated integer-like memory space to Attribute.
  Attribute memorySpace =
      wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());

  return Base::get(elementType.getContext(), shape, elementType, layout,
                   memorySpace);
}

MemRefType
MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
                       ArrayRef<int64_t> shape, Type elementType, AffineMap map,
                       unsigned memorySpaceInd) {

  // Use default layout for empty map.
  if (!map)
    map = AffineMap::getMultiDimIdentityMap(shape.size(),
                                            elementType.getContext());

  // Wrap AffineMap into Attribute.
  Attribute layout = AffineMapAttr::get(map);

  // Convert deprecated integer-like memory space to Attribute.
  Attribute memorySpace =
      wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());

  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
                          elementType, layout, memorySpace);
}

LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
                                 ArrayRef<int64_t> shape, Type elementType,
                                 MemRefLayoutAttrInterface layout,
                                 Attribute memorySpace) {
  if (!BaseMemRefType::isValidElementType(elementType))
    return emitError() << "invalid memref element type";

  // Negative sizes are not allowed except for `-1`, which means a dynamic
  // size.
  for (int64_t s : shape)
    if (s < -1)
      return emitError() << "invalid memref size";

  assert(layout && "missing layout specification");
  if (failed(layout.verifyLayout(shape, emitError)))
    return failure();

  if (!isSupportedMemorySpace(memorySpace))
    return emitError() << "unsupported memory space Attribute";

  return success();
}

void MemRefType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  walkTypesFn(getElementType());
  if (!getLayout().isIdentity())
    walkAttrsFn(getLayout());
  walkAttrsFn(getMemorySpace());
}

//===----------------------------------------------------------------------===//
// UnrankedMemRefType
//===----------------------------------------------------------------------===//

unsigned UnrankedMemRefType::getMemorySpaceAsInt() const {
  return detail::getMemorySpaceAsInt(getMemorySpace());
}

LogicalResult
UnrankedMemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
                           Type elementType, Attribute memorySpace) {
  if (!BaseMemRefType::isValidElementType(elementType))
    return emitError() << "invalid memref element type";

  if (!isSupportedMemorySpace(memorySpace))
    return emitError() << "unsupported memory space Attribute";

  return success();
}
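
// Example of the strided layouts the helpers below decompose (illustrative,
// not part of the upstream file): the layout map
//   affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>
// corresponds to strides [32, 1] with a symbolic offset s0.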
// Fallback cases for terminal dim/sym/cst that are not part of a binary op
// (i.e. a single term). Accumulate the AffineExpr into the existing one.
static void extractStridesFromTerm(AffineExpr e,
                                   AffineExpr multiplicativeFactor,
                                   MutableArrayRef<AffineExpr> strides,
                                   AffineExpr &offset) {
  if (auto dim = e.dyn_cast<AffineDimExpr>())
    strides[dim.getPosition()] =
        strides[dim.getPosition()] + multiplicativeFactor;
  else
    offset = offset + e * multiplicativeFactor;
}

/// Takes a single AffineExpr `e` and populates the `strides` array with the
/// stride expressions for each dim position.
/// The convention is that the strides for dimensions d0, .. dn appear in
/// order, to make indexing into the result intuitive.
static LogicalResult extractStrides(AffineExpr e,
                                    AffineExpr multiplicativeFactor,
                                    MutableArrayRef<AffineExpr> strides,
                                    AffineExpr &offset) {
  auto bin = e.dyn_cast<AffineBinaryOpExpr>();
  if (!bin) {
    extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
    return success();
  }

  if (bin.getKind() == AffineExprKind::CeilDiv ||
      bin.getKind() == AffineExprKind::FloorDiv ||
      bin.getKind() == AffineExprKind::Mod)
    return failure();

  if (bin.getKind() == AffineExprKind::Mul) {
    auto dim = bin.getLHS().dyn_cast<AffineDimExpr>();
    if (dim) {
      strides[dim.getPosition()] =
          strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
      return success();
    }
    // LHS and RHS may both contain complex expressions of dims. Try one path
    // and if it fails try the other. This is guaranteed to succeed because
    // only one path may have a `dim`, otherwise this is not an AffineExpr in
    // the first place.
    if (bin.getLHS().isSymbolicOrConstant())
      return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
                            strides, offset);
    return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
                          strides, offset);
  }

  if (bin.getKind() == AffineExprKind::Add) {
    auto res1 =
        extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
    auto res2 =
        extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
    return success(succeeded(res1) && succeeded(res2));
  }

  llvm_unreachable("unexpected binary operation");
}
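
// Example (illustrative, not part of the upstream file): running
// extractStrides on (d0, d1)[s0] -> (d0 * s0 + d1) yields strides [s0, 1] and
// offset 0, while (d0) -> (d0 floordiv 4) is rejected: floordiv/ceildiv/mod
// have no strided interpretation.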
LogicalResult mlir::getStridesAndOffset(MemRefType t,
                                        SmallVectorImpl<AffineExpr> &strides,
                                        AffineExpr &offset) {
  AffineMap m = t.getLayout().getAffineMap();

  if (m.getNumResults() != 1 && !m.isIdentity())
    return failure();

  auto zero = getAffineConstantExpr(0, t.getContext());
  auto one = getAffineConstantExpr(1, t.getContext());
  offset = zero;
  strides.assign(t.getRank(), zero);

  // Canonical case for empty map.
  if (m.isIdentity()) {
    // 0-D corner case, offset is already 0.
    if (t.getRank() == 0)
      return success();
    auto stridedExpr =
        makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
    if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
      return success();
    assert(false && "unexpected failure: extract strides in canonical layout");
  }

  // Non-canonical case requires more work.
  auto stridedExpr =
      simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
  if (failed(extractStrides(stridedExpr, one, strides, offset))) {
    offset = AffineExpr();
    strides.clear();
    return failure();
  }

  // Simplify results to allow folding to constants and simple checks.
  unsigned numDims = m.getNumDims();
  unsigned numSymbols = m.getNumSymbols();
  offset = simplifyAffineExpr(offset, numDims, numSymbols);
  for (auto &stride : strides)
    stride = simplifyAffineExpr(stride, numDims, numSymbols);

  // In practice, a strided memref must be internally non-aliasing. Test
  // against 0 as a proxy.
  // TODO: static cases can have more advanced checks.
  // TODO: dynamic cases would require a way to compare symbolic
  // expressions and would probably need an affine set context propagated
  // everywhere.
  if (llvm::any_of(strides, [](AffineExpr e) {
        return e == getAffineConstantExpr(0, e.getContext());
      })) {
    offset = AffineExpr();
    strides.clear();
    return failure();
  }

  return success();
}

LogicalResult mlir::getStridesAndOffset(MemRefType t,
                                        SmallVectorImpl<int64_t> &strides,
                                        int64_t &offset) {
  AffineExpr offsetExpr;
  SmallVector<AffineExpr, 4> strideExprs;
  if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr)))
    return failure();
  if (auto cst = offsetExpr.dyn_cast<AffineConstantExpr>())
    offset = cst.getValue();
  else
    offset = ShapedType::kDynamicStrideOrOffset;
  for (auto e : strideExprs) {
    if (auto c = e.dyn_cast<AffineConstantExpr>())
      strides.push_back(c.getValue());
    else
      strides.push_back(ShapedType::kDynamicStrideOrOffset);
  }
  return success();
}

void UnrankedMemRefType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  walkTypesFn(getElementType());
  walkAttrsFn(getMemorySpace());
}

//===----------------------------------------------------------------------===//
// TupleType
//===----------------------------------------------------------------------===//

/// Return the element types for this tuple.
ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }

/// Accumulate the types contained in this tuple and tuples nested within it.
/// Note that this only flattens nested tuples, not any other container type,
/// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
/// (i32, tensor<i32>, f32, i64).
void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) {
  for (Type type : getTypes()) {
    if (auto nestedTuple = type.dyn_cast<TupleType>())
      nestedTuple.getFlattenedTypes(types);
    else
      types.push_back(type);
  }
}

/// Return the number of element types.
size_t TupleType::size() const { return getImpl()->size(); }

void TupleType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  for (Type type : getTypes())
    walkTypesFn(type);
}
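
// Example (illustrative sketch, not part of the upstream file; assumes an
// MLIRContext `ctx` with types `i32` and `f32`):
//   TupleType t = TupleType::get(ctx, {i32, TupleType::get(ctx, {f32})});
//   SmallVector<Type> flat;
//   t.getFlattenedTypes(flat); // flat now holds {i32, f32}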
//===----------------------------------------------------------------------===//
// Type Utilities
//===----------------------------------------------------------------------===//

AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,
                                           int64_t offset,
                                           MLIRContext *context) {
  AffineExpr expr;
  unsigned nSymbols = 0;

  // AffineExpr for offset.
  // Static case.
  if (offset != MemRefType::getDynamicStrideOrOffset()) {
    auto cst = getAffineConstantExpr(offset, context);
    expr = cst;
  } else {
    // Dynamic case, new symbol for the offset.
    auto sym = getAffineSymbolExpr(nSymbols++, context);
    expr = sym;
  }

  // AffineExpr for strides.
  for (auto en : llvm::enumerate(strides)) {
    auto dim = en.index();
    auto stride = en.value();
    assert(stride != 0 && "Invalid stride specification");
    auto d = getAffineDimExpr(dim, context);
    AffineExpr mult;
    // Static case.
    if (stride != MemRefType::getDynamicStrideOrOffset())
      mult = getAffineConstantExpr(stride, context);
    else
      // Dynamic case, new symbol for each new stride.
      mult = getAffineSymbolExpr(nSymbols++, context);
    expr = expr + d * mult;
  }

  return AffineMap::get(strides.size(), nSymbols, expr);
}

/// Return a version of `t` with identity layout if it can be determined
/// statically that the layout is the canonical contiguous strided layout.
/// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
/// `t` with simplified layout.
/// If `t` has multiple layout maps or a multi-result layout, just return `t`.
MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
  AffineMap m = t.getLayout().getAffineMap();

  // Already in canonical form.
  if (m.isIdentity())
    return t;

  // Can't reduce to canonical identity form, return in canonical form.
  if (m.getNumResults() > 1)
    return t;

  // Corner-case for 0-D affine maps.
  if (m.getNumDims() == 0 && m.getNumSymbols() == 0) {
    if (auto cst = m.getResult(0).dyn_cast<AffineConstantExpr>())
      if (cst.getValue() == 0)
        return MemRefType::Builder(t).setLayout({});
    return t;
  }

  // 0-D corner case for an empty shape that still has an affine map. Example:
  // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a 1-element memref
  // whose offset needs to remain, so just return t.
  if (t.getShape().empty())
    return t;

  // If the canonical strided layout for the sizes of `t` is equal to the
  // simplified layout of `t` we can just return an empty layout. Otherwise,
  // just simplify the existing layout.
  AffineExpr expr =
      makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
  auto simplifiedLayoutExpr =
      simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
  if (expr != simplifiedLayoutExpr)
    return MemRefType::Builder(t).setLayout(AffineMapAttr::get(AffineMap::get(
        m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)));
  return MemRefType::Builder(t).setLayout({});
}
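
// Example (illustrative, not part of the upstream file; assumes an MLIRContext
// `ctx`):
//   makeStridedLinearLayoutMap({32, 1}, /*offset=*/4, ctx) produces
//     affine_map<(d0, d1) -> (d0 * 32 + d1 + 4)>,
//   while a dynamic stride or offset introduces a symbol, e.g. strides
//   {dynamic, 1} with a dynamic offset produce
//     affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + d1 + s0)>.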
AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                                ArrayRef<AffineExpr> exprs,
                                                MLIRContext *context) {
  assert(!sizes.empty() && !exprs.empty() &&
         "expected non-empty sizes and exprs");

  // Size 0 corner case is useful for canonicalizations.
  if (llvm::is_contained(sizes, 0))
    return getAffineConstantExpr(0, context);

  auto maps = AffineMap::inferFromExprList(exprs);
  assert(!maps.empty() && "Expected one non-empty map");
  unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();

  AffineExpr expr;
  bool dynamicPoisonBit = false;
  int64_t runningSize = 1;
  for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
    int64_t size = std::get<1>(en);
    // Degenerate case, no size -> no stride.
    if (size == 0)
      continue;
    AffineExpr dimExpr = std::get<0>(en);
    AffineExpr stride = dynamicPoisonBit
                            ? getAffineSymbolExpr(nSymbols++, context)
                            : getAffineConstantExpr(runningSize, context);
    expr = expr ? expr + dimExpr * stride : dimExpr * stride;
    if (size > 0) {
      runningSize *= size;
      assert(runningSize > 0 && "integer overflow in size computation");
    } else {
      dynamicPoisonBit = true;
    }
  }
  return simplifyAffineExpr(expr, numDims, nSymbols);
}

/// Return a version of `t` with a layout that has all dynamic offset and
/// strides. This is used to erase the static layout.
MemRefType mlir::eraseStridedLayout(MemRefType t) {
  auto val = ShapedType::kDynamicStrideOrOffset;
  return MemRefType::Builder(t).setLayout(
      AffineMapAttr::get(makeStridedLinearLayoutMap(
          SmallVector<int64_t, 4>(t.getRank(), val), val, t.getContext())));
}

AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                                MLIRContext *context) {
  SmallVector<AffineExpr, 4> exprs;
  exprs.reserve(sizes.size());
  for (auto dim : llvm::seq<unsigned>(0, sizes.size()))
    exprs.push_back(getAffineDimExpr(dim, context));
  return makeCanonicalStridedLayoutExpr(sizes, exprs, context);
}

/// Return true if the layout for `t` is compatible with strided semantics.
bool mlir::isStrided(MemRefType t) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto res = getStridesAndOffset(t, strides, offset);
  return succeeded(res);
}

/// Return the layout map in strided linear layout AffineMap form.
/// Return null if the layout is not compatible with a strided layout.
AffineMap mlir::getStridedLinearLayoutMap(MemRefType t) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  if (failed(getStridesAndOffset(t, strides, offset)))
    return AffineMap();
  return makeStridedLinearLayoutMap(strides, offset, t.getContext());
}
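
// Example (illustrative, not part of the upstream file): a memref<4x8xf32>
// with identity layout is strided with strides [8, 1] and offset 0, so
// isStrided() is true; a layout such as (d0, d1) -> (d1, d0) has multiple
// results and fails getStridesAndOffset, so isStrided() is false.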