//===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/BuiltinTypes.h"
#include "TypeDetail.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TensorEncoding.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/Twine.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::detail;

//===----------------------------------------------------------------------===//
/// Tablegen Type Definitions
//===----------------------------------------------------------------------===//

// Pull in the tablegen-generated class definitions of the builtin types.
#define GET_TYPEDEF_CLASSES
#include "mlir/IR/BuiltinTypes.cpp.inc"

//===----------------------------------------------------------------------===//
/// Tablegen Interface Definitions
//===----------------------------------------------------------------------===//

#include "mlir/IR/BuiltinTypeInterfaces.cpp.inc"

//===----------------------------------------------------------------------===//
// BuiltinDialect
//===----------------------------------------------------------------------===//

/// Registers the builtin types with the builtin dialect.
void BuiltinDialect::registerTypes() {
  // With GET_TYPEDEF_LIST defined, the generated include is expected to
  // expand to the list of builtin type classes, which becomes the template
  // argument list of `addTypes<...>`.
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/IR/BuiltinTypes.cpp.inc"
      >();
}

//===----------------------------------------------------------------------===//
/// ComplexType
//===----------------------------------------------------------------------===// 55 56 /// Verify the construction of an integer type. 57 LogicalResult ComplexType::verify(function_ref<InFlightDiagnostic()> emitError, 58 Type elementType) { 59 if (!elementType.isIntOrFloat()) 60 return emitError() << "invalid element type for complex"; 61 return success(); 62 } 63 64 //===----------------------------------------------------------------------===// 65 // Integer Type 66 //===----------------------------------------------------------------------===// 67 68 // static constexpr must have a definition (until in C++17 and inline variable). 69 constexpr unsigned IntegerType::kMaxWidth; 70 71 /// Verify the construction of an integer type. 72 LogicalResult IntegerType::verify(function_ref<InFlightDiagnostic()> emitError, 73 unsigned width, 74 SignednessSemantics signedness) { 75 if (width > IntegerType::kMaxWidth) { 76 return emitError() << "integer bitwidth is limited to " 77 << IntegerType::kMaxWidth << " bits"; 78 } 79 return success(); 80 } 81 82 unsigned IntegerType::getWidth() const { return getImpl()->width; } 83 84 IntegerType::SignednessSemantics IntegerType::getSignedness() const { 85 return getImpl()->signedness; 86 } 87 88 IntegerType IntegerType::scaleElementBitwidth(unsigned scale) { 89 if (!scale) 90 return IntegerType(); 91 return IntegerType::get(getContext(), scale * getWidth(), getSignedness()); 92 } 93 94 //===----------------------------------------------------------------------===// 95 // Float Type 96 //===----------------------------------------------------------------------===// 97 98 unsigned FloatType::getWidth() { 99 if (isa<Float16Type, BFloat16Type>()) 100 return 16; 101 if (isa<Float32Type>()) 102 return 32; 103 if (isa<Float64Type>()) 104 return 64; 105 if (isa<Float80Type>()) 106 return 80; 107 if (isa<Float128Type>()) 108 return 128; 109 llvm_unreachable("unexpected float type"); 110 } 111 112 /// Returns the floating semantics for the 
given type. 113 const llvm::fltSemantics &FloatType::getFloatSemantics() { 114 if (isa<BFloat16Type>()) 115 return APFloat::BFloat(); 116 if (isa<Float16Type>()) 117 return APFloat::IEEEhalf(); 118 if (isa<Float32Type>()) 119 return APFloat::IEEEsingle(); 120 if (isa<Float64Type>()) 121 return APFloat::IEEEdouble(); 122 if (isa<Float80Type>()) 123 return APFloat::x87DoubleExtended(); 124 if (isa<Float128Type>()) 125 return APFloat::IEEEquad(); 126 llvm_unreachable("non-floating point type used"); 127 } 128 129 FloatType FloatType::scaleElementBitwidth(unsigned scale) { 130 if (!scale) 131 return FloatType(); 132 MLIRContext *ctx = getContext(); 133 if (isF16() || isBF16()) { 134 if (scale == 2) 135 return FloatType::getF32(ctx); 136 if (scale == 4) 137 return FloatType::getF64(ctx); 138 } 139 if (isF32()) 140 if (scale == 2) 141 return FloatType::getF64(ctx); 142 return FloatType(); 143 } 144 145 //===----------------------------------------------------------------------===// 146 // FunctionType 147 //===----------------------------------------------------------------------===// 148 149 unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; } 150 151 ArrayRef<Type> FunctionType::getInputs() const { 152 return getImpl()->getInputs(); 153 } 154 155 unsigned FunctionType::getNumResults() const { return getImpl()->numResults; } 156 157 ArrayRef<Type> FunctionType::getResults() const { 158 return getImpl()->getResults(); 159 } 160 161 /// Helper to call a callback once on each index in the range 162 /// [0, `totalIndices`), *except* for the indices given in `indices`. 163 /// `indices` is allowed to have duplicates and can be in any order. 
164 inline void iterateIndicesExcept(unsigned totalIndices, 165 ArrayRef<unsigned> indices, 166 function_ref<void(unsigned)> callback) { 167 llvm::BitVector skipIndices(totalIndices); 168 for (unsigned i : indices) 169 skipIndices.set(i); 170 171 for (unsigned i = 0; i < totalIndices; ++i) 172 if (!skipIndices.test(i)) 173 callback(i); 174 } 175 176 /// Returns a new function type with the specified arguments and results 177 /// inserted. 178 FunctionType FunctionType::getWithArgsAndResults( 179 ArrayRef<unsigned> argIndices, TypeRange argTypes, 180 ArrayRef<unsigned> resultIndices, TypeRange resultTypes) { 181 assert(argIndices.size() == argTypes.size()); 182 assert(resultIndices.size() == resultTypes.size()); 183 184 ArrayRef<Type> newInputTypes = getInputs(); 185 SmallVector<Type, 4> newInputTypesBuffer; 186 if (!argIndices.empty()) { 187 const auto *fromIt = newInputTypes.begin(); 188 for (auto it : llvm::zip(argIndices, argTypes)) { 189 const auto *toIt = newInputTypes.begin() + std::get<0>(it); 190 newInputTypesBuffer.append(fromIt, toIt); 191 newInputTypesBuffer.push_back(std::get<1>(it)); 192 fromIt = toIt; 193 } 194 newInputTypesBuffer.append(fromIt, newInputTypes.end()); 195 newInputTypes = newInputTypesBuffer; 196 } 197 198 ArrayRef<Type> newResultTypes = getResults(); 199 SmallVector<Type, 4> newResultTypesBuffer; 200 if (!resultIndices.empty()) { 201 const auto *fromIt = newResultTypes.begin(); 202 for (auto it : llvm::zip(resultIndices, resultTypes)) { 203 const auto *toIt = newResultTypes.begin() + std::get<0>(it); 204 newResultTypesBuffer.append(fromIt, toIt); 205 newResultTypesBuffer.push_back(std::get<1>(it)); 206 fromIt = toIt; 207 } 208 newResultTypesBuffer.append(fromIt, newResultTypes.end()); 209 newResultTypes = newResultTypesBuffer; 210 } 211 212 return FunctionType::get(getContext(), newInputTypes, newResultTypes); 213 } 214 215 /// Returns a new function type without the specified arguments and results. 
216 FunctionType 217 FunctionType::getWithoutArgsAndResults(ArrayRef<unsigned> argIndices, 218 ArrayRef<unsigned> resultIndices) { 219 ArrayRef<Type> newInputTypes = getInputs(); 220 SmallVector<Type, 4> newInputTypesBuffer; 221 if (!argIndices.empty()) { 222 unsigned originalNumArgs = getNumInputs(); 223 iterateIndicesExcept(originalNumArgs, argIndices, [&](unsigned i) { 224 newInputTypesBuffer.emplace_back(getInput(i)); 225 }); 226 newInputTypes = newInputTypesBuffer; 227 } 228 229 ArrayRef<Type> newResultTypes = getResults(); 230 SmallVector<Type, 4> newResultTypesBuffer; 231 if (!resultIndices.empty()) { 232 unsigned originalNumResults = getNumResults(); 233 iterateIndicesExcept(originalNumResults, resultIndices, [&](unsigned i) { 234 newResultTypesBuffer.emplace_back(getResult(i)); 235 }); 236 newResultTypes = newResultTypesBuffer; 237 } 238 239 return get(getContext(), newInputTypes, newResultTypes); 240 } 241 242 void FunctionType::walkImmediateSubElements( 243 function_ref<void(Attribute)> walkAttrsFn, 244 function_ref<void(Type)> walkTypesFn) const { 245 for (Type type : llvm::concat<const Type>(getInputs(), getResults())) 246 walkTypesFn(type); 247 } 248 249 //===----------------------------------------------------------------------===// 250 // OpaqueType 251 //===----------------------------------------------------------------------===// 252 253 /// Verify the construction of an opaque type. 254 LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError, 255 StringAttr dialect, StringRef typeData) { 256 if (!Dialect::isValidNamespace(dialect.strref())) 257 return emitError() << "invalid dialect namespace '" << dialect << "'"; 258 259 // Check that the dialect is actually registered. 260 MLIRContext *context = dialect.getContext(); 261 if (!context->allowsUnregisteredDialects() && 262 !context->getLoadedDialect(dialect.strref())) { 263 return emitError() 264 << "`!" 
<< dialect << "<\"" << typeData << "\">" 265 << "` type created with unregistered dialect. If this is " 266 "intended, please call allowUnregisteredDialects() on the " 267 "MLIRContext, or use -allow-unregistered-dialect with " 268 "the MLIR opt tool used"; 269 } 270 271 return success(); 272 } 273 274 //===----------------------------------------------------------------------===// 275 // ShapedType 276 //===----------------------------------------------------------------------===// 277 constexpr int64_t ShapedType::kDynamicSize; 278 constexpr int64_t ShapedType::kDynamicStrideOrOffset; 279 280 ShapedType ShapedType::clone(ArrayRef<int64_t> shape, Type elementType) { 281 if (auto other = dyn_cast<MemRefType>()) { 282 MemRefType::Builder b(other); 283 b.setShape(shape); 284 b.setElementType(elementType); 285 return b; 286 } 287 288 if (auto other = dyn_cast<UnrankedMemRefType>()) { 289 MemRefType::Builder b(shape, elementType); 290 b.setMemorySpace(other.getMemorySpace()); 291 return b; 292 } 293 294 if (isa<TensorType>()) 295 return RankedTensorType::get(shape, elementType); 296 297 if (isa<VectorType>()) 298 return VectorType::get(shape, elementType); 299 300 llvm_unreachable("Unhandled ShapedType clone case"); 301 } 302 303 ShapedType ShapedType::clone(ArrayRef<int64_t> shape) { 304 if (auto other = dyn_cast<MemRefType>()) { 305 MemRefType::Builder b(other); 306 b.setShape(shape); 307 return b; 308 } 309 310 if (auto other = dyn_cast<UnrankedMemRefType>()) { 311 MemRefType::Builder b(shape, other.getElementType()); 312 b.setShape(shape); 313 b.setMemorySpace(other.getMemorySpace()); 314 return b; 315 } 316 317 if (isa<TensorType>()) 318 return RankedTensorType::get(shape, getElementType()); 319 320 if (isa<VectorType>()) 321 return VectorType::get(shape, getElementType()); 322 323 llvm_unreachable("Unhandled ShapedType clone case"); 324 } 325 326 ShapedType ShapedType::clone(Type elementType) { 327 if (auto other = dyn_cast<MemRefType>()) { 328 
MemRefType::Builder b(other); 329 b.setElementType(elementType); 330 return b; 331 } 332 333 if (auto other = dyn_cast<UnrankedMemRefType>()) { 334 return UnrankedMemRefType::get(elementType, other.getMemorySpace()); 335 } 336 337 if (isa<TensorType>()) { 338 if (hasRank()) 339 return RankedTensorType::get(getShape(), elementType); 340 return UnrankedTensorType::get(elementType); 341 } 342 343 if (isa<VectorType>()) 344 return VectorType::get(getShape(), elementType); 345 346 llvm_unreachable("Unhandled ShapedType clone hit"); 347 } 348 349 Type ShapedType::getElementType() const { 350 return TypeSwitch<Type, Type>(*this) 351 .Case<VectorType, RankedTensorType, UnrankedTensorType, MemRefType, 352 UnrankedMemRefType>([](auto ty) { return ty.getElementType(); }); 353 } 354 355 unsigned ShapedType::getElementTypeBitWidth() const { 356 return getElementType().getIntOrFloatBitWidth(); 357 } 358 359 int64_t ShapedType::getNumElements() const { 360 assert(hasStaticShape() && "cannot get element count of dynamic shaped type"); 361 auto shape = getShape(); 362 int64_t num = 1; 363 for (auto dim : shape) { 364 num *= dim; 365 assert(num >= 0 && "integer overflow in element count computation"); 366 } 367 return num; 368 } 369 370 int64_t ShapedType::getRank() const { 371 assert(hasRank() && "cannot query rank of unranked shaped type"); 372 return getShape().size(); 373 } 374 375 bool ShapedType::hasRank() const { 376 return !isa<UnrankedMemRefType, UnrankedTensorType>(); 377 } 378 379 int64_t ShapedType::getDimSize(unsigned idx) const { 380 assert(idx < getRank() && "invalid index for shaped type"); 381 return getShape()[idx]; 382 } 383 384 bool ShapedType::isDynamicDim(unsigned idx) const { 385 assert(idx < getRank() && "invalid index for shaped type"); 386 return isDynamic(getShape()[idx]); 387 } 388 389 unsigned ShapedType::getDynamicDimIndex(unsigned index) const { 390 assert(index < getRank() && "invalid index"); 391 assert(ShapedType::isDynamic(getDimSize(index)) && 
"invalid index"); 392 return llvm::count_if(getShape().take_front(index), ShapedType::isDynamic); 393 } 394 395 /// Get the number of bits require to store a value of the given shaped type. 396 /// Compute the value recursively since tensors are allowed to have vectors as 397 /// elements. 398 int64_t ShapedType::getSizeInBits() const { 399 assert(hasStaticShape() && 400 "cannot get the bit size of an aggregate with a dynamic shape"); 401 402 auto elementType = getElementType(); 403 if (elementType.isIntOrFloat()) 404 return elementType.getIntOrFloatBitWidth() * getNumElements(); 405 406 if (auto complexType = elementType.dyn_cast<ComplexType>()) { 407 elementType = complexType.getElementType(); 408 return elementType.getIntOrFloatBitWidth() * getNumElements() * 2; 409 } 410 411 // Tensors can have vectors and other tensors as elements, other shaped types 412 // cannot. 413 assert(isa<TensorType>() && "unsupported element type"); 414 assert((elementType.isa<VectorType, TensorType>()) && 415 "unsupported tensor element type"); 416 return getNumElements() * elementType.cast<ShapedType>().getSizeInBits(); 417 } 418 419 ArrayRef<int64_t> ShapedType::getShape() const { 420 if (auto vectorType = dyn_cast<VectorType>()) 421 return vectorType.getShape(); 422 if (auto tensorType = dyn_cast<RankedTensorType>()) 423 return tensorType.getShape(); 424 return cast<MemRefType>().getShape(); 425 } 426 427 int64_t ShapedType::getNumDynamicDims() const { 428 return llvm::count_if(getShape(), isDynamic); 429 } 430 431 bool ShapedType::hasStaticShape() const { 432 return hasRank() && llvm::none_of(getShape(), isDynamic); 433 } 434 435 bool ShapedType::hasStaticShape(ArrayRef<int64_t> shape) const { 436 return hasStaticShape() && getShape() == shape; 437 } 438 439 //===----------------------------------------------------------------------===// 440 // VectorType 441 //===----------------------------------------------------------------------===// 442 443 LogicalResult 
VectorType::verify(function_ref<InFlightDiagnostic()> emitError, 444 ArrayRef<int64_t> shape, Type elementType) { 445 if (!isValidElementType(elementType)) 446 return emitError() 447 << "vector elements must be int/index/float type but got " 448 << elementType; 449 450 if (any_of(shape, [](int64_t i) { return i <= 0; })) 451 return emitError() 452 << "vector types must have positive constant sizes but got " 453 << shape; 454 455 return success(); 456 } 457 458 VectorType VectorType::scaleElementBitwidth(unsigned scale) { 459 if (!scale) 460 return VectorType(); 461 if (auto et = getElementType().dyn_cast<IntegerType>()) 462 if (auto scaledEt = et.scaleElementBitwidth(scale)) 463 return VectorType::get(getShape(), scaledEt); 464 if (auto et = getElementType().dyn_cast<FloatType>()) 465 if (auto scaledEt = et.scaleElementBitwidth(scale)) 466 return VectorType::get(getShape(), scaledEt); 467 return VectorType(); 468 } 469 470 void VectorType::walkImmediateSubElements( 471 function_ref<void(Attribute)> walkAttrsFn, 472 function_ref<void(Type)> walkTypesFn) const { 473 walkTypesFn(getElementType()); 474 } 475 476 //===----------------------------------------------------------------------===// 477 // TensorType 478 //===----------------------------------------------------------------------===// 479 480 // Check if "elementType" can be an element type of a tensor. 481 static LogicalResult 482 checkTensorElementType(function_ref<InFlightDiagnostic()> emitError, 483 Type elementType) { 484 if (!TensorType::isValidElementType(elementType)) 485 return emitError() << "invalid tensor element type: " << elementType; 486 return success(); 487 } 488 489 /// Return true if the specified element type is ok in a tensor. 490 bool TensorType::isValidElementType(Type type) { 491 // Note: Non standard/builtin types are allowed to exist within tensor 492 // types. Dialects are expected to verify that tensor types have a valid 493 // element type within that dialect. 
494 return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType, 495 IndexType>() || 496 !llvm::isa<BuiltinDialect>(type.getDialect()); 497 } 498 499 //===----------------------------------------------------------------------===// 500 // RankedTensorType 501 //===----------------------------------------------------------------------===// 502 503 LogicalResult 504 RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError, 505 ArrayRef<int64_t> shape, Type elementType, 506 Attribute encoding) { 507 for (int64_t s : shape) 508 if (s < -1) 509 return emitError() << "invalid tensor dimension size"; 510 if (auto v = encoding.dyn_cast_or_null<VerifiableTensorEncoding>()) 511 if (failed(v.verifyEncoding(shape, elementType, emitError))) 512 return failure(); 513 return checkTensorElementType(emitError, elementType); 514 } 515 516 void RankedTensorType::walkImmediateSubElements( 517 function_ref<void(Attribute)> walkAttrsFn, 518 function_ref<void(Type)> walkTypesFn) const { 519 walkTypesFn(getElementType()); 520 if (Attribute encoding = getEncoding()) 521 walkAttrsFn(encoding); 522 } 523 524 //===----------------------------------------------------------------------===// 525 // UnrankedTensorType 526 //===----------------------------------------------------------------------===// 527 528 LogicalResult 529 UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError, 530 Type elementType) { 531 return checkTensorElementType(emitError, elementType); 532 } 533 534 void UnrankedTensorType::walkImmediateSubElements( 535 function_ref<void(Attribute)> walkAttrsFn, 536 function_ref<void(Type)> walkTypesFn) const { 537 walkTypesFn(getElementType()); 538 } 539 540 //===----------------------------------------------------------------------===// 541 // BaseMemRefType 542 //===----------------------------------------------------------------------===// 543 544 Attribute BaseMemRefType::getMemorySpace() const { 545 if (auto rankedMemRefTy = 
dyn_cast<MemRefType>()) 546 return rankedMemRefTy.getMemorySpace(); 547 return cast<UnrankedMemRefType>().getMemorySpace(); 548 } 549 550 unsigned BaseMemRefType::getMemorySpaceAsInt() const { 551 if (auto rankedMemRefTy = dyn_cast<MemRefType>()) 552 return rankedMemRefTy.getMemorySpaceAsInt(); 553 return cast<UnrankedMemRefType>().getMemorySpaceAsInt(); 554 } 555 556 //===----------------------------------------------------------------------===// 557 // MemRefType 558 //===----------------------------------------------------------------------===// 559 560 /// Given an `originalShape` and a `reducedShape` assumed to be a subset of 561 /// `originalShape` with some `1` entries erased, return the set of indices 562 /// that specifies which of the entries of `originalShape` are dropped to obtain 563 /// `reducedShape`. The returned mask can be applied as a projection to 564 /// `originalShape` to obtain the `reducedShape`. This mask is useful to track 565 /// which dimensions must be kept when e.g. compute MemRef strides under 566 /// rank-reducing operations. Return None if reducedShape cannot be obtained 567 /// by dropping only `1` entries in `originalShape`. 568 llvm::Optional<llvm::SmallDenseSet<unsigned>> 569 mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape, 570 ArrayRef<int64_t> reducedShape) { 571 size_t originalRank = originalShape.size(), reducedRank = reducedShape.size(); 572 llvm::SmallDenseSet<unsigned> unusedDims; 573 unsigned reducedIdx = 0; 574 for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) { 575 // Greedily insert `originalIdx` if match. 576 if (reducedIdx < reducedRank && 577 originalShape[originalIdx] == reducedShape[reducedIdx]) { 578 reducedIdx++; 579 continue; 580 } 581 582 unusedDims.insert(originalIdx); 583 // If no match on `originalIdx`, the `originalShape` at this dimension 584 // must be 1, otherwise we bail. 
585 if (originalShape[originalIdx] != 1) 586 return llvm::None; 587 } 588 // The whole reducedShape must be scanned, otherwise we bail. 589 if (reducedIdx != reducedRank) 590 return llvm::None; 591 return unusedDims; 592 } 593 594 SliceVerificationResult 595 mlir::isRankReducedType(ShapedType originalType, 596 ShapedType candidateReducedType) { 597 if (originalType == candidateReducedType) 598 return SliceVerificationResult::Success; 599 600 ShapedType originalShapedType = originalType.cast<ShapedType>(); 601 ShapedType candidateReducedShapedType = 602 candidateReducedType.cast<ShapedType>(); 603 604 // Rank and size logic is valid for all ShapedTypes. 605 ArrayRef<int64_t> originalShape = originalShapedType.getShape(); 606 ArrayRef<int64_t> candidateReducedShape = 607 candidateReducedShapedType.getShape(); 608 unsigned originalRank = originalShape.size(), 609 candidateReducedRank = candidateReducedShape.size(); 610 if (candidateReducedRank > originalRank) 611 return SliceVerificationResult::RankTooLarge; 612 613 auto optionalUnusedDimsMask = 614 computeRankReductionMask(originalShape, candidateReducedShape); 615 616 // Sizes cannot be matched in case empty vector is returned. 617 if (!optionalUnusedDimsMask.hasValue()) 618 return SliceVerificationResult::SizeMismatch; 619 620 if (originalShapedType.getElementType() != 621 candidateReducedShapedType.getElementType()) 622 return SliceVerificationResult::ElemTypeMismatch; 623 624 return SliceVerificationResult::Success; 625 } 626 627 bool mlir::detail::isSupportedMemorySpace(Attribute memorySpace) { 628 // Empty attribute is allowed as default memory space. 629 if (!memorySpace) 630 return true; 631 632 // Supported built-in attributes. 633 if (memorySpace.isa<IntegerAttr, StringAttr, DictionaryAttr>()) 634 return true; 635 636 // Allow custom dialect attributes. 
637 if (!isa<BuiltinDialect>(memorySpace.getDialect())) 638 return true; 639 640 return false; 641 } 642 643 Attribute mlir::detail::wrapIntegerMemorySpace(unsigned memorySpace, 644 MLIRContext *ctx) { 645 if (memorySpace == 0) 646 return nullptr; 647 648 return IntegerAttr::get(IntegerType::get(ctx, 64), memorySpace); 649 } 650 651 Attribute mlir::detail::skipDefaultMemorySpace(Attribute memorySpace) { 652 IntegerAttr intMemorySpace = memorySpace.dyn_cast_or_null<IntegerAttr>(); 653 if (intMemorySpace && intMemorySpace.getValue() == 0) 654 return nullptr; 655 656 return memorySpace; 657 } 658 659 unsigned mlir::detail::getMemorySpaceAsInt(Attribute memorySpace) { 660 if (!memorySpace) 661 return 0; 662 663 assert(memorySpace.isa<IntegerAttr>() && 664 "Using `getMemorySpaceInteger` with non-Integer attribute"); 665 666 return static_cast<unsigned>(memorySpace.cast<IntegerAttr>().getInt()); 667 } 668 669 MemRefType::Builder & 670 MemRefType::Builder::setMemorySpace(unsigned newMemorySpace) { 671 memorySpace = 672 wrapIntegerMemorySpace(newMemorySpace, elementType.getContext()); 673 return *this; 674 } 675 676 unsigned MemRefType::getMemorySpaceAsInt() const { 677 return detail::getMemorySpaceAsInt(getMemorySpace()); 678 } 679 680 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType, 681 MemRefLayoutAttrInterface layout, 682 Attribute memorySpace) { 683 // Use default layout for empty attribute. 684 if (!layout) 685 layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap( 686 shape.size(), elementType.getContext())); 687 688 // Drop default memory space value and replace it with empty attribute. 
689 memorySpace = skipDefaultMemorySpace(memorySpace); 690 691 return Base::get(elementType.getContext(), shape, elementType, layout, 692 memorySpace); 693 } 694 695 MemRefType MemRefType::getChecked( 696 function_ref<InFlightDiagnostic()> emitErrorFn, ArrayRef<int64_t> shape, 697 Type elementType, MemRefLayoutAttrInterface layout, Attribute memorySpace) { 698 699 // Use default layout for empty attribute. 700 if (!layout) 701 layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap( 702 shape.size(), elementType.getContext())); 703 704 // Drop default memory space value and replace it with empty attribute. 705 memorySpace = skipDefaultMemorySpace(memorySpace); 706 707 return Base::getChecked(emitErrorFn, elementType.getContext(), shape, 708 elementType, layout, memorySpace); 709 } 710 711 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType, 712 AffineMap map, Attribute memorySpace) { 713 714 // Use default layout for empty map. 715 if (!map) 716 map = AffineMap::getMultiDimIdentityMap(shape.size(), 717 elementType.getContext()); 718 719 // Wrap AffineMap into Attribute. 720 Attribute layout = AffineMapAttr::get(map); 721 722 // Drop default memory space value and replace it with empty attribute. 723 memorySpace = skipDefaultMemorySpace(memorySpace); 724 725 return Base::get(elementType.getContext(), shape, elementType, layout, 726 memorySpace); 727 } 728 729 MemRefType 730 MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn, 731 ArrayRef<int64_t> shape, Type elementType, AffineMap map, 732 Attribute memorySpace) { 733 734 // Use default layout for empty map. 735 if (!map) 736 map = AffineMap::getMultiDimIdentityMap(shape.size(), 737 elementType.getContext()); 738 739 // Wrap AffineMap into Attribute. 740 Attribute layout = AffineMapAttr::get(map); 741 742 // Drop default memory space value and replace it with empty attribute. 
743 memorySpace = skipDefaultMemorySpace(memorySpace); 744 745 return Base::getChecked(emitErrorFn, elementType.getContext(), shape, 746 elementType, layout, memorySpace); 747 } 748 749 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType, 750 AffineMap map, unsigned memorySpaceInd) { 751 752 // Use default layout for empty map. 753 if (!map) 754 map = AffineMap::getMultiDimIdentityMap(shape.size(), 755 elementType.getContext()); 756 757 // Wrap AffineMap into Attribute. 758 Attribute layout = AffineMapAttr::get(map); 759 760 // Convert deprecated integer-like memory space to Attribute. 761 Attribute memorySpace = 762 wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext()); 763 764 return Base::get(elementType.getContext(), shape, elementType, layout, 765 memorySpace); 766 } 767 768 MemRefType 769 MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn, 770 ArrayRef<int64_t> shape, Type elementType, AffineMap map, 771 unsigned memorySpaceInd) { 772 773 // Use default layout for empty map. 774 if (!map) 775 map = AffineMap::getMultiDimIdentityMap(shape.size(), 776 elementType.getContext()); 777 778 // Wrap AffineMap into Attribute. 779 Attribute layout = AffineMapAttr::get(map); 780 781 // Convert deprecated integer-like memory space to Attribute. 782 Attribute memorySpace = 783 wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext()); 784 785 return Base::getChecked(emitErrorFn, elementType.getContext(), shape, 786 elementType, layout, memorySpace); 787 } 788 789 LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError, 790 ArrayRef<int64_t> shape, Type elementType, 791 MemRefLayoutAttrInterface layout, 792 Attribute memorySpace) { 793 if (!BaseMemRefType::isValidElementType(elementType)) 794 return emitError() << "invalid memref element type"; 795 796 // Negative sizes are not allowed except for `-1` that means dynamic size. 
797 for (int64_t s : shape) 798 if (s < -1) 799 return emitError() << "invalid memref size"; 800 801 assert(layout && "missing layout specification"); 802 if (failed(layout.verifyLayout(shape, emitError))) 803 return failure(); 804 805 if (!isSupportedMemorySpace(memorySpace)) 806 return emitError() << "unsupported memory space Attribute"; 807 808 return success(); 809 } 810 811 void MemRefType::walkImmediateSubElements( 812 function_ref<void(Attribute)> walkAttrsFn, 813 function_ref<void(Type)> walkTypesFn) const { 814 walkTypesFn(getElementType()); 815 if (!getLayout().isIdentity()) 816 walkAttrsFn(getLayout()); 817 walkAttrsFn(getMemorySpace()); 818 } 819 820 //===----------------------------------------------------------------------===// 821 // UnrankedMemRefType 822 //===----------------------------------------------------------------------===// 823 824 unsigned UnrankedMemRefType::getMemorySpaceAsInt() const { 825 return detail::getMemorySpaceAsInt(getMemorySpace()); 826 } 827 828 LogicalResult 829 UnrankedMemRefType::verify(function_ref<InFlightDiagnostic()> emitError, 830 Type elementType, Attribute memorySpace) { 831 if (!BaseMemRefType::isValidElementType(elementType)) 832 return emitError() << "invalid memref element type"; 833 834 if (!isSupportedMemorySpace(memorySpace)) 835 return emitError() << "unsupported memory space Attribute"; 836 837 return success(); 838 } 839 840 // Fallback cases for terminal dim/sym/cst that are not part of a binary op ( 841 // i.e. single term). Accumulate the AffineExpr into the existing one. 
static void extractStridesFromTerm(AffineExpr e,
                                   AffineExpr multiplicativeFactor,
                                   MutableArrayRef<AffineExpr> strides,
                                   AffineExpr &offset) {
  // A bare dimension adds `multiplicativeFactor` to that dimension's stride;
  // any other terminal (symbol or constant) is scaled and added to the offset.
  if (auto dim = e.dyn_cast<AffineDimExpr>())
    strides[dim.getPosition()] =
        strides[dim.getPosition()] + multiplicativeFactor;
  else
    offset = offset + e * multiplicativeFactor;
}

/// Takes a single AffineExpr `e` and populates the `strides` array with the
/// strides expressions for each dim position.
/// The convention is that the strides for dimensions d0, .. dn appear in
/// order to make indexing intuitive into the result.
/// Every contribution is scaled by `multiplicativeFactor` before being
/// accumulated into `strides` or `offset`. Returns failure if `e` contains a
/// ceildiv, floordiv or mod, which this function does not decompose.
static LogicalResult extractStrides(AffineExpr e,
                                    AffineExpr multiplicativeFactor,
                                    MutableArrayRef<AffineExpr> strides,
                                    AffineExpr &offset) {
  auto bin = e.dyn_cast<AffineBinaryOpExpr>();
  if (!bin) {
    extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
    return success();
  }

  if (bin.getKind() == AffineExprKind::CeilDiv ||
      bin.getKind() == AffineExprKind::FloorDiv ||
      bin.getKind() == AffineExprKind::Mod)
    return failure();

  if (bin.getKind() == AffineExprKind::Mul) {
    // `dim * rhs`: `rhs` (scaled) is a stride contribution for `dim`.
    auto dim = bin.getLHS().dyn_cast<AffineDimExpr>();
    if (dim) {
      strides[dim.getPosition()] =
          strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
      return success();
    }
    // LHS and RHS may both be compound expressions. At most one of them can
    // depend on dims (otherwise this would not be an affine product), so the
    // symbolic-or-constant side is folded into the factor and the recursion
    // continues into the other side only.
    if (bin.getLHS().isSymbolicOrConstant())
      return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
                            strides, offset);
    return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
                          strides, offset);
  }

  if (bin.getKind() == AffineExprKind::Add) {
    // Both operands are always visited (no short-circuit), so contributions
    // from the RHS are accumulated even when the LHS fails.
    auto res1 =
        extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
    auto res2 =
        extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
    return success(succeeded(res1) && succeeded(res2));
  }

  llvm_unreachable("unexpected binary operation");
}

/// Computes the stride and offset expressions of memref `t`. On failure,
/// `offset` is reset to a null expression and `strides` is cleared, except
/// for the early bail-out on a multi-result non-identity layout, which
/// returns before touching either output.
LogicalResult mlir::getStridesAndOffset(MemRefType t,
                                        SmallVectorImpl<AffineExpr> &strides,
                                        AffineExpr &offset) {
  AffineMap m = t.getLayout().getAffineMap();

  // Only a single-result layout map, or the identity, is handled.
  if (m.getNumResults() != 1 && !m.isIdentity())
    return failure();

  auto zero = getAffineConstantExpr(0, t.getContext());
  auto one = getAffineConstantExpr(1, t.getContext());
  offset = zero;
  strides.assign(t.getRank(), zero);

  // Canonical case for empty map.
  if (m.isIdentity()) {
    // 0-D corner case, offset is already 0.
    if (t.getRank() == 0)
      return success();
    auto stridedExpr =
        makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
    if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
      return success();
    // NOTE: with assertions disabled, control falls through to the
    // non-canonical path below.
    assert(false && "unexpected failure: extract strides in canonical layout");
  }

  // Non-canonical case requires more work.
  auto stridedExpr =
      simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
  if (failed(extractStrides(stridedExpr, one, strides, offset))) {
    offset = AffineExpr();
    strides.clear();
    return failure();
  }

  // Simplify results to allow folding to constants and simple checks.
  unsigned numDims = m.getNumDims();
  unsigned numSymbols = m.getNumSymbols();
  offset = simplifyAffineExpr(offset, numDims, numSymbols);
  for (auto &stride : strides)
    stride = simplifyAffineExpr(stride, numDims, numSymbols);

  /// In practice, a strided memref must be internally non-aliasing. Test
  /// against 0 as a proxy.
  /// TODO: static cases can have more advanced checks.
  /// TODO: dynamic cases would require a way to compare symbolic
  /// expressions and would probably need an affine set context propagated
  /// everywhere.
  if (llvm::any_of(strides, [](AffineExpr e) {
        return e == getAffineConstantExpr(0, e.getContext());
      })) {
    offset = AffineExpr();
    strides.clear();
    return failure();
  }

  return success();
}

/// Integer variant of `getStridesAndOffset`: constant offset/stride
/// expressions are returned as their values, and any non-constant one as
/// `ShapedType::kDynamicStrideOrOffset`. Strides are appended to `strides`
/// with push_back; the vector is not cleared first.
LogicalResult mlir::getStridesAndOffset(MemRefType t,
                                        SmallVectorImpl<int64_t> &strides,
                                        int64_t &offset) {
  AffineExpr offsetExpr;
  SmallVector<AffineExpr, 4> strideExprs;
  if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr)))
    return failure();
  if (auto cst = offsetExpr.dyn_cast<AffineConstantExpr>())
    offset = cst.getValue();
  else
    offset = ShapedType::kDynamicStrideOrOffset;
  for (auto e : strideExprs) {
    if (auto c = e.dyn_cast<AffineConstantExpr>())
      strides.push_back(c.getValue());
    else
      strides.push_back(ShapedType::kDynamicStrideOrOffset);
  }
  return success();
}

/// Visits the element type, then the memory space attribute.
void UnrankedMemRefType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  walkTypesFn(getElementType());
  walkAttrsFn(getMemorySpace());
}

//===----------------------------------------------------------------------===//
/// TupleType
//===----------------------------------------------------------------------===//

/// Return the element types of this tuple.
991 ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); } 992 993 /// Accumulate the types contained in this tuple and tuples nested within it. 994 /// Note that this only flattens nested tuples, not any other container type, 995 /// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to 996 /// (i32, tensor<i32>, f32, i64) 997 void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) { 998 for (Type type : getTypes()) { 999 if (auto nestedTuple = type.dyn_cast<TupleType>()) 1000 nestedTuple.getFlattenedTypes(types); 1001 else 1002 types.push_back(type); 1003 } 1004 } 1005 1006 /// Return the number of element types. 1007 size_t TupleType::size() const { return getImpl()->size(); } 1008 1009 void TupleType::walkImmediateSubElements( 1010 function_ref<void(Attribute)> walkAttrsFn, 1011 function_ref<void(Type)> walkTypesFn) const { 1012 for (Type type : getTypes()) 1013 walkTypesFn(type); 1014 } 1015 1016 //===----------------------------------------------------------------------===// 1017 // Type Utilities 1018 //===----------------------------------------------------------------------===// 1019 1020 AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides, 1021 int64_t offset, 1022 MLIRContext *context) { 1023 AffineExpr expr; 1024 unsigned nSymbols = 0; 1025 1026 // AffineExpr for offset. 1027 // Static case. 1028 if (offset != MemRefType::getDynamicStrideOrOffset()) { 1029 auto cst = getAffineConstantExpr(offset, context); 1030 expr = cst; 1031 } else { 1032 // Dynamic case, new symbol for the offset. 1033 auto sym = getAffineSymbolExpr(nSymbols++, context); 1034 expr = sym; 1035 } 1036 1037 // AffineExpr for strides. 1038 for (auto en : llvm::enumerate(strides)) { 1039 auto dim = en.index(); 1040 auto stride = en.value(); 1041 assert(stride != 0 && "Invalid stride specification"); 1042 auto d = getAffineDimExpr(dim, context); 1043 AffineExpr mult; 1044 // Static case. 
1045 if (stride != MemRefType::getDynamicStrideOrOffset()) 1046 mult = getAffineConstantExpr(stride, context); 1047 else 1048 // Dynamic case, new symbol for each new stride. 1049 mult = getAffineSymbolExpr(nSymbols++, context); 1050 expr = expr + d * mult; 1051 } 1052 1053 return AffineMap::get(strides.size(), nSymbols, expr); 1054 } 1055 1056 /// Return a version of `t` with identity layout if it can be determined 1057 /// statically that the layout is the canonical contiguous strided layout. 1058 /// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of 1059 /// `t` with simplified layout. 1060 /// If `t` has multiple layout maps or a multi-result layout, just return `t`. 1061 MemRefType mlir::canonicalizeStridedLayout(MemRefType t) { 1062 AffineMap m = t.getLayout().getAffineMap(); 1063 1064 // Already in canonical form. 1065 if (m.isIdentity()) 1066 return t; 1067 1068 // Can't reduce to canonical identity form, return in canonical form. 1069 if (m.getNumResults() > 1) 1070 return t; 1071 1072 // Corner-case for 0-D affine maps. 1073 if (m.getNumDims() == 0 && m.getNumSymbols() == 0) { 1074 if (auto cst = m.getResult(0).dyn_cast<AffineConstantExpr>()) 1075 if (cst.getValue() == 0) 1076 return MemRefType::Builder(t).setLayout({}); 1077 return t; 1078 } 1079 1080 // 0-D corner case for empty shape that still have an affine map. Example: 1081 // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a 1 element memref whose 1082 // offset needs to remain, just return t. 1083 if (t.getShape().empty()) 1084 return t; 1085 1086 // If the canonical strided layout for the sizes of `t` is equal to the 1087 // simplified layout of `t` we can just return an empty layout. Otherwise, 1088 // just simplify the existing layout. 
1089 AffineExpr expr = 1090 makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext()); 1091 auto simplifiedLayoutExpr = 1092 simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols()); 1093 if (expr != simplifiedLayoutExpr) 1094 return MemRefType::Builder(t).setLayout(AffineMapAttr::get(AffineMap::get( 1095 m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr))); 1096 return MemRefType::Builder(t).setLayout({}); 1097 } 1098 1099 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes, 1100 ArrayRef<AffineExpr> exprs, 1101 MLIRContext *context) { 1102 assert(!sizes.empty() && !exprs.empty() && 1103 "expected non-empty sizes and exprs"); 1104 1105 // Size 0 corner case is useful for canonicalizations. 1106 if (llvm::is_contained(sizes, 0)) 1107 return getAffineConstantExpr(0, context); 1108 1109 auto maps = AffineMap::inferFromExprList(exprs); 1110 assert(!maps.empty() && "Expected one non-empty map"); 1111 unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols(); 1112 1113 AffineExpr expr; 1114 bool dynamicPoisonBit = false; 1115 int64_t runningSize = 1; 1116 for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) { 1117 int64_t size = std::get<1>(en); 1118 // Degenerate case, no size =-> no stride 1119 if (size == 0) 1120 continue; 1121 AffineExpr dimExpr = std::get<0>(en); 1122 AffineExpr stride = dynamicPoisonBit 1123 ? getAffineSymbolExpr(nSymbols++, context) 1124 : getAffineConstantExpr(runningSize, context); 1125 expr = expr ? expr + dimExpr * stride : dimExpr * stride; 1126 if (size > 0) { 1127 runningSize *= size; 1128 assert(runningSize > 0 && "integer overflow in size computation"); 1129 } else { 1130 dynamicPoisonBit = true; 1131 } 1132 } 1133 return simplifyAffineExpr(expr, numDims, nSymbols); 1134 } 1135 1136 /// Return a version of `t` with a layout that has all dynamic offset and 1137 /// strides. This is used to erase the static layout. 
1138 MemRefType mlir::eraseStridedLayout(MemRefType t) { 1139 auto val = ShapedType::kDynamicStrideOrOffset; 1140 return MemRefType::Builder(t).setLayout( 1141 AffineMapAttr::get(makeStridedLinearLayoutMap( 1142 SmallVector<int64_t, 4>(t.getRank(), val), val, t.getContext()))); 1143 } 1144 1145 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes, 1146 MLIRContext *context) { 1147 SmallVector<AffineExpr, 4> exprs; 1148 exprs.reserve(sizes.size()); 1149 for (auto dim : llvm::seq<unsigned>(0, sizes.size())) 1150 exprs.push_back(getAffineDimExpr(dim, context)); 1151 return makeCanonicalStridedLayoutExpr(sizes, exprs, context); 1152 } 1153 1154 /// Return true if the layout for `t` is compatible with strided semantics. 1155 bool mlir::isStrided(MemRefType t) { 1156 int64_t offset; 1157 SmallVector<int64_t, 4> strides; 1158 auto res = getStridesAndOffset(t, strides, offset); 1159 return succeeded(res); 1160 } 1161 1162 /// Return the layout map in strided linear layout AffineMap form. 1163 /// Return null if the layout is not compatible with a strided layout. 1164 AffineMap mlir::getStridedLinearLayoutMap(MemRefType t) { 1165 int64_t offset; 1166 SmallVector<int64_t, 4> strides; 1167 if (failed(getStridesAndOffset(t, strides, offset))) 1168 return AffineMap(); 1169 return makeStridedLinearLayoutMap(strides, offset, t.getContext()); 1170 } 1171 1172 /// Return the AffineExpr representation of the offset, assuming `memRefType` 1173 /// is a strided memref. 1174 static AffineExpr getOffsetExpr(MemRefType memrefType) { 1175 SmallVector<AffineExpr> strides; 1176 AffineExpr offset; 1177 if (failed(getStridesAndOffset(memrefType, strides, offset))) 1178 assert(false && "expected strided memref"); 1179 return offset; 1180 } 1181 1182 /// Helper to construct a contiguous MemRefType of `shape`, `elementType` and 1183 /// `offset` AffineExpr. 
1184 static MemRefType makeContiguousRowMajorMemRefType(MLIRContext *context, 1185 ArrayRef<int64_t> shape, 1186 Type elementType, 1187 AffineExpr offset) { 1188 AffineExpr canonical = makeCanonicalStridedLayoutExpr(shape, context); 1189 AffineExpr contiguousRowMajor = canonical + offset; 1190 AffineMap contiguousRowMajorMap = 1191 AffineMap::inferFromExprList({contiguousRowMajor})[0]; 1192 return MemRefType::get(shape, elementType, contiguousRowMajorMap); 1193 } 1194 1195 /// Helper determining if a memref is static-shape and contiguous-row-major 1196 /// layout, while still allowing for an arbitrary offset (any static or 1197 /// dynamic value). 1198 bool mlir::isStaticShapeAndContiguousRowMajor(MemRefType memrefType) { 1199 if (!memrefType.hasStaticShape()) 1200 return false; 1201 AffineExpr offset = getOffsetExpr(memrefType); 1202 MemRefType contiguousRowMajorMemRefType = makeContiguousRowMajorMemRefType( 1203 memrefType.getContext(), memrefType.getShape(), 1204 memrefType.getElementType(), offset); 1205 return canonicalizeStridedLayout(memrefType) == 1206 canonicalizeStridedLayout(contiguousRowMajorMemRefType); 1207 } 1208