//===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/BuiltinTypes.h"
#include "TypeDetail.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TensorEncoding.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/Twine.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::detail;

//===----------------------------------------------------------------------===//
/// Tablegen Type Definitions
//===----------------------------------------------------------------------===//

// Pull in the TableGen-generated type class definitions.
#define GET_TYPEDEF_CLASSES
#include "mlir/IR/BuiltinTypes.cpp.inc"

//===----------------------------------------------------------------------===//
/// Tablegen Interface Definitions
//===----------------------------------------------------------------------===//

// Pull in the TableGen-generated type interface definitions.
#include "mlir/IR/BuiltinTypeInterfaces.cpp.inc"

//===----------------------------------------------------------------------===//
// BuiltinDialect
//===----------------------------------------------------------------------===//

/// Register the builtin types with the builtin dialect. The template argument
/// list of `addTypes` is expanded from BuiltinTypes.cpp.inc, which emits a
/// comma-separated type list when GET_TYPEDEF_LIST is defined.
void BuiltinDialect::registerTypes() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/IR/BuiltinTypes.cpp.inc"
      >();
}

//===----------------------------------------------------------------------===//
/// ComplexType
//===----------------------------------------------------------------------===// 55 56 /// Verify the construction of an integer type. 57 LogicalResult ComplexType::verify(function_ref<InFlightDiagnostic()> emitError, 58 Type elementType) { 59 if (!elementType.isIntOrFloat()) 60 return emitError() << "invalid element type for complex"; 61 return success(); 62 } 63 64 //===----------------------------------------------------------------------===// 65 // Integer Type 66 //===----------------------------------------------------------------------===// 67 68 // static constexpr must have a definition (until in C++17 and inline variable). 69 constexpr unsigned IntegerType::kMaxWidth; 70 71 /// Verify the construction of an integer type. 72 LogicalResult IntegerType::verify(function_ref<InFlightDiagnostic()> emitError, 73 unsigned width, 74 SignednessSemantics signedness) { 75 if (width > IntegerType::kMaxWidth) { 76 return emitError() << "integer bitwidth is limited to " 77 << IntegerType::kMaxWidth << " bits"; 78 } 79 return success(); 80 } 81 82 unsigned IntegerType::getWidth() const { return getImpl()->width; } 83 84 IntegerType::SignednessSemantics IntegerType::getSignedness() const { 85 return getImpl()->signedness; 86 } 87 88 IntegerType IntegerType::scaleElementBitwidth(unsigned scale) { 89 if (!scale) 90 return IntegerType(); 91 return IntegerType::get(getContext(), scale * getWidth(), getSignedness()); 92 } 93 94 //===----------------------------------------------------------------------===// 95 // Float Type 96 //===----------------------------------------------------------------------===// 97 98 unsigned FloatType::getWidth() { 99 if (isa<Float16Type, BFloat16Type>()) 100 return 16; 101 if (isa<Float32Type>()) 102 return 32; 103 if (isa<Float64Type>()) 104 return 64; 105 if (isa<Float80Type>()) 106 return 80; 107 if (isa<Float128Type>()) 108 return 128; 109 llvm_unreachable("unexpected float type"); 110 } 111 112 /// Returns the floating semantics for the 
given type. 113 const llvm::fltSemantics &FloatType::getFloatSemantics() { 114 if (isa<BFloat16Type>()) 115 return APFloat::BFloat(); 116 if (isa<Float16Type>()) 117 return APFloat::IEEEhalf(); 118 if (isa<Float32Type>()) 119 return APFloat::IEEEsingle(); 120 if (isa<Float64Type>()) 121 return APFloat::IEEEdouble(); 122 if (isa<Float80Type>()) 123 return APFloat::x87DoubleExtended(); 124 if (isa<Float128Type>()) 125 return APFloat::IEEEquad(); 126 llvm_unreachable("non-floating point type used"); 127 } 128 129 FloatType FloatType::scaleElementBitwidth(unsigned scale) { 130 if (!scale) 131 return FloatType(); 132 MLIRContext *ctx = getContext(); 133 if (isF16() || isBF16()) { 134 if (scale == 2) 135 return FloatType::getF32(ctx); 136 if (scale == 4) 137 return FloatType::getF64(ctx); 138 } 139 if (isF32()) 140 if (scale == 2) 141 return FloatType::getF64(ctx); 142 return FloatType(); 143 } 144 145 //===----------------------------------------------------------------------===// 146 // FunctionType 147 //===----------------------------------------------------------------------===// 148 149 unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; } 150 151 ArrayRef<Type> FunctionType::getInputs() const { 152 return getImpl()->getInputs(); 153 } 154 155 unsigned FunctionType::getNumResults() const { return getImpl()->numResults; } 156 157 ArrayRef<Type> FunctionType::getResults() const { 158 return getImpl()->getResults(); 159 } 160 161 /// Helper to call a callback once on each index in the range 162 /// [0, `totalIndices`), *except* for the indices given in `indices`. 163 /// `indices` is allowed to have duplicates and can be in any order. 
164 inline void iterateIndicesExcept(unsigned totalIndices, 165 ArrayRef<unsigned> indices, 166 function_ref<void(unsigned)> callback) { 167 llvm::BitVector skipIndices(totalIndices); 168 for (unsigned i : indices) 169 skipIndices.set(i); 170 171 for (unsigned i = 0; i < totalIndices; ++i) 172 if (!skipIndices.test(i)) 173 callback(i); 174 } 175 176 /// Returns a new function type with the specified arguments and results 177 /// inserted. 178 FunctionType FunctionType::getWithArgsAndResults( 179 ArrayRef<unsigned> argIndices, TypeRange argTypes, 180 ArrayRef<unsigned> resultIndices, TypeRange resultTypes) { 181 assert(argIndices.size() == argTypes.size()); 182 assert(resultIndices.size() == resultTypes.size()); 183 184 ArrayRef<Type> newInputTypes = getInputs(); 185 SmallVector<Type, 4> newInputTypesBuffer; 186 if (!argIndices.empty()) { 187 const auto *fromIt = newInputTypes.begin(); 188 for (auto it : llvm::zip(argIndices, argTypes)) { 189 const auto *toIt = newInputTypes.begin() + std::get<0>(it); 190 newInputTypesBuffer.append(fromIt, toIt); 191 newInputTypesBuffer.push_back(std::get<1>(it)); 192 fromIt = toIt; 193 } 194 newInputTypesBuffer.append(fromIt, newInputTypes.end()); 195 newInputTypes = newInputTypesBuffer; 196 } 197 198 ArrayRef<Type> newResultTypes = getResults(); 199 SmallVector<Type, 4> newResultTypesBuffer; 200 if (!resultIndices.empty()) { 201 const auto *fromIt = newResultTypes.begin(); 202 for (auto it : llvm::zip(resultIndices, resultTypes)) { 203 const auto *toIt = newResultTypes.begin() + std::get<0>(it); 204 newResultTypesBuffer.append(fromIt, toIt); 205 newResultTypesBuffer.push_back(std::get<1>(it)); 206 fromIt = toIt; 207 } 208 newResultTypesBuffer.append(fromIt, newResultTypes.end()); 209 newResultTypes = newResultTypesBuffer; 210 } 211 212 return FunctionType::get(getContext(), newInputTypes, newResultTypes); 213 } 214 215 /// Returns a new function type without the specified arguments and results. 
216 FunctionType 217 FunctionType::getWithoutArgsAndResults(ArrayRef<unsigned> argIndices, 218 ArrayRef<unsigned> resultIndices) { 219 ArrayRef<Type> newInputTypes = getInputs(); 220 SmallVector<Type, 4> newInputTypesBuffer; 221 if (!argIndices.empty()) { 222 unsigned originalNumArgs = getNumInputs(); 223 iterateIndicesExcept(originalNumArgs, argIndices, [&](unsigned i) { 224 newInputTypesBuffer.emplace_back(getInput(i)); 225 }); 226 newInputTypes = newInputTypesBuffer; 227 } 228 229 ArrayRef<Type> newResultTypes = getResults(); 230 SmallVector<Type, 4> newResultTypesBuffer; 231 if (!resultIndices.empty()) { 232 unsigned originalNumResults = getNumResults(); 233 iterateIndicesExcept(originalNumResults, resultIndices, [&](unsigned i) { 234 newResultTypesBuffer.emplace_back(getResult(i)); 235 }); 236 newResultTypes = newResultTypesBuffer; 237 } 238 239 return get(getContext(), newInputTypes, newResultTypes); 240 } 241 242 void FunctionType::walkImmediateSubElements( 243 function_ref<void(Attribute)> walkAttrsFn, 244 function_ref<void(Type)> walkTypesFn) const { 245 for (Type type : llvm::concat<const Type>(getInputs(), getResults())) 246 walkTypesFn(type); 247 } 248 249 //===----------------------------------------------------------------------===// 250 // OpaqueType 251 //===----------------------------------------------------------------------===// 252 253 /// Verify the construction of an opaque type. 254 LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError, 255 StringAttr dialect, StringRef typeData) { 256 if (!Dialect::isValidNamespace(dialect.strref())) 257 return emitError() << "invalid dialect namespace '" << dialect << "'"; 258 259 // Check that the dialect is actually registered. 260 MLIRContext *context = dialect.getContext(); 261 if (!context->allowsUnregisteredDialects() && 262 !context->getLoadedDialect(dialect.strref())) { 263 return emitError() 264 << "`!" 
<< dialect << "<\"" << typeData << "\">" 265 << "` type created with unregistered dialect. If this is " 266 "intended, please call allowUnregisteredDialects() on the " 267 "MLIRContext, or use -allow-unregistered-dialect with " 268 "the MLIR opt tool used"; 269 } 270 271 return success(); 272 } 273 274 //===----------------------------------------------------------------------===// 275 // ShapedType 276 //===----------------------------------------------------------------------===// 277 constexpr int64_t ShapedType::kDynamicSize; 278 constexpr int64_t ShapedType::kDynamicStrideOrOffset; 279 280 ShapedType ShapedType::clone(ArrayRef<int64_t> shape, Type elementType) { 281 if (auto other = dyn_cast<MemRefType>()) { 282 MemRefType::Builder b(other); 283 b.setShape(shape); 284 b.setElementType(elementType); 285 return b; 286 } 287 288 if (auto other = dyn_cast<UnrankedMemRefType>()) { 289 MemRefType::Builder b(shape, elementType); 290 b.setMemorySpace(other.getMemorySpace()); 291 return b; 292 } 293 294 if (isa<TensorType>()) 295 return RankedTensorType::get(shape, elementType); 296 297 if (auto vecTy = dyn_cast<VectorType>()) 298 return VectorType::get(shape, elementType, vecTy.getNumScalableDims()); 299 300 llvm_unreachable("Unhandled ShapedType clone case"); 301 } 302 303 ShapedType ShapedType::clone(ArrayRef<int64_t> shape) { 304 if (auto other = dyn_cast<MemRefType>()) { 305 MemRefType::Builder b(other); 306 b.setShape(shape); 307 return b; 308 } 309 310 if (auto other = dyn_cast<UnrankedMemRefType>()) { 311 MemRefType::Builder b(shape, other.getElementType()); 312 b.setShape(shape); 313 b.setMemorySpace(other.getMemorySpace()); 314 return b; 315 } 316 317 if (isa<TensorType>()) 318 return RankedTensorType::get(shape, getElementType()); 319 320 if (auto vecTy = dyn_cast<VectorType>()) 321 return VectorType::get(shape, getElementType(), vecTy.getNumScalableDims()); 322 323 llvm_unreachable("Unhandled ShapedType clone case"); 324 } 325 326 ShapedType 
ShapedType::clone(Type elementType) { 327 if (auto other = dyn_cast<MemRefType>()) { 328 MemRefType::Builder b(other); 329 b.setElementType(elementType); 330 return b; 331 } 332 333 if (auto other = dyn_cast<UnrankedMemRefType>()) { 334 return UnrankedMemRefType::get(elementType, other.getMemorySpace()); 335 } 336 337 if (isa<TensorType>()) { 338 if (hasRank()) 339 return RankedTensorType::get(getShape(), elementType); 340 return UnrankedTensorType::get(elementType); 341 } 342 343 if (auto vecTy = dyn_cast<VectorType>()) 344 return VectorType::get(getShape(), elementType, vecTy.getNumScalableDims()); 345 346 llvm_unreachable("Unhandled ShapedType clone hit"); 347 } 348 349 Type ShapedType::getElementType() const { 350 return TypeSwitch<Type, Type>(*this) 351 .Case<VectorType, RankedTensorType, UnrankedTensorType, MemRefType, 352 UnrankedMemRefType>([](auto ty) { return ty.getElementType(); }); 353 } 354 355 unsigned ShapedType::getElementTypeBitWidth() const { 356 return getElementType().getIntOrFloatBitWidth(); 357 } 358 359 int64_t ShapedType::getNumElements() const { 360 assert(hasStaticShape() && "cannot get element count of dynamic shaped type"); 361 auto shape = getShape(); 362 int64_t num = 1; 363 for (auto dim : shape) { 364 num *= dim; 365 assert(num >= 0 && "integer overflow in element count computation"); 366 } 367 return num; 368 } 369 370 int64_t ShapedType::getRank() const { 371 assert(hasRank() && "cannot query rank of unranked shaped type"); 372 return getShape().size(); 373 } 374 375 bool ShapedType::hasRank() const { 376 return !isa<UnrankedMemRefType, UnrankedTensorType>(); 377 } 378 379 int64_t ShapedType::getDimSize(unsigned idx) const { 380 assert(idx < getRank() && "invalid index for shaped type"); 381 return getShape()[idx]; 382 } 383 384 bool ShapedType::isDynamicDim(unsigned idx) const { 385 assert(idx < getRank() && "invalid index for shaped type"); 386 return isDynamic(getShape()[idx]); 387 } 388 389 unsigned 
ShapedType::getDynamicDimIndex(unsigned index) const { 390 assert(index < getRank() && "invalid index"); 391 assert(ShapedType::isDynamic(getDimSize(index)) && "invalid index"); 392 return llvm::count_if(getShape().take_front(index), ShapedType::isDynamic); 393 } 394 395 /// Get the number of bits require to store a value of the given shaped type. 396 /// Compute the value recursively since tensors are allowed to have vectors as 397 /// elements. 398 int64_t ShapedType::getSizeInBits() const { 399 assert(hasStaticShape() && 400 "cannot get the bit size of an aggregate with a dynamic shape"); 401 402 auto elementType = getElementType(); 403 if (elementType.isIntOrFloat()) 404 return elementType.getIntOrFloatBitWidth() * getNumElements(); 405 406 if (auto complexType = elementType.dyn_cast<ComplexType>()) { 407 elementType = complexType.getElementType(); 408 return elementType.getIntOrFloatBitWidth() * getNumElements() * 2; 409 } 410 411 // Tensors can have vectors and other tensors as elements, other shaped types 412 // cannot. 
413 assert(isa<TensorType>() && "unsupported element type"); 414 assert((elementType.isa<VectorType, TensorType>()) && 415 "unsupported tensor element type"); 416 return getNumElements() * elementType.cast<ShapedType>().getSizeInBits(); 417 } 418 419 ArrayRef<int64_t> ShapedType::getShape() const { 420 if (auto vectorType = dyn_cast<VectorType>()) 421 return vectorType.getShape(); 422 if (auto tensorType = dyn_cast<RankedTensorType>()) 423 return tensorType.getShape(); 424 return cast<MemRefType>().getShape(); 425 } 426 427 int64_t ShapedType::getNumDynamicDims() const { 428 return llvm::count_if(getShape(), isDynamic); 429 } 430 431 bool ShapedType::hasStaticShape() const { 432 return hasRank() && llvm::none_of(getShape(), isDynamic); 433 } 434 435 bool ShapedType::hasStaticShape(ArrayRef<int64_t> shape) const { 436 return hasStaticShape() && getShape() == shape; 437 } 438 439 //===----------------------------------------------------------------------===// 440 // VectorType 441 //===----------------------------------------------------------------------===// 442 443 LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError, 444 ArrayRef<int64_t> shape, Type elementType, 445 unsigned numScalableDims) { 446 if (!isValidElementType(elementType)) 447 return emitError() 448 << "vector elements must be int/index/float type but got " 449 << elementType; 450 451 if (any_of(shape, [](int64_t i) { return i <= 0; })) 452 return emitError() 453 << "vector types must have positive constant sizes but got " 454 << shape; 455 456 return success(); 457 } 458 459 VectorType VectorType::scaleElementBitwidth(unsigned scale) { 460 if (!scale) 461 return VectorType(); 462 if (auto et = getElementType().dyn_cast<IntegerType>()) 463 if (auto scaledEt = et.scaleElementBitwidth(scale)) 464 return VectorType::get(getShape(), scaledEt, getNumScalableDims()); 465 if (auto et = getElementType().dyn_cast<FloatType>()) 466 if (auto scaledEt = et.scaleElementBitwidth(scale)) 
467 return VectorType::get(getShape(), scaledEt, getNumScalableDims()); 468 return VectorType(); 469 } 470 471 void VectorType::walkImmediateSubElements( 472 function_ref<void(Attribute)> walkAttrsFn, 473 function_ref<void(Type)> walkTypesFn) const { 474 walkTypesFn(getElementType()); 475 } 476 477 //===----------------------------------------------------------------------===// 478 // TensorType 479 //===----------------------------------------------------------------------===// 480 481 // Check if "elementType" can be an element type of a tensor. 482 static LogicalResult 483 checkTensorElementType(function_ref<InFlightDiagnostic()> emitError, 484 Type elementType) { 485 if (!TensorType::isValidElementType(elementType)) 486 return emitError() << "invalid tensor element type: " << elementType; 487 return success(); 488 } 489 490 /// Return true if the specified element type is ok in a tensor. 491 bool TensorType::isValidElementType(Type type) { 492 // Note: Non standard/builtin types are allowed to exist within tensor 493 // types. Dialects are expected to verify that tensor types have a valid 494 // element type within that dialect. 
495 return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType, 496 IndexType>() || 497 !llvm::isa<BuiltinDialect>(type.getDialect()); 498 } 499 500 //===----------------------------------------------------------------------===// 501 // RankedTensorType 502 //===----------------------------------------------------------------------===// 503 504 LogicalResult 505 RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError, 506 ArrayRef<int64_t> shape, Type elementType, 507 Attribute encoding) { 508 for (int64_t s : shape) 509 if (s < -1) 510 return emitError() << "invalid tensor dimension size"; 511 if (auto v = encoding.dyn_cast_or_null<VerifiableTensorEncoding>()) 512 if (failed(v.verifyEncoding(shape, elementType, emitError))) 513 return failure(); 514 return checkTensorElementType(emitError, elementType); 515 } 516 517 void RankedTensorType::walkImmediateSubElements( 518 function_ref<void(Attribute)> walkAttrsFn, 519 function_ref<void(Type)> walkTypesFn) const { 520 walkTypesFn(getElementType()); 521 if (Attribute encoding = getEncoding()) 522 walkAttrsFn(encoding); 523 } 524 525 //===----------------------------------------------------------------------===// 526 // UnrankedTensorType 527 //===----------------------------------------------------------------------===// 528 529 LogicalResult 530 UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError, 531 Type elementType) { 532 return checkTensorElementType(emitError, elementType); 533 } 534 535 void UnrankedTensorType::walkImmediateSubElements( 536 function_ref<void(Attribute)> walkAttrsFn, 537 function_ref<void(Type)> walkTypesFn) const { 538 walkTypesFn(getElementType()); 539 } 540 541 //===----------------------------------------------------------------------===// 542 // BaseMemRefType 543 //===----------------------------------------------------------------------===// 544 545 Attribute BaseMemRefType::getMemorySpace() const { 546 if (auto rankedMemRefTy = 
dyn_cast<MemRefType>()) 547 return rankedMemRefTy.getMemorySpace(); 548 return cast<UnrankedMemRefType>().getMemorySpace(); 549 } 550 551 unsigned BaseMemRefType::getMemorySpaceAsInt() const { 552 if (auto rankedMemRefTy = dyn_cast<MemRefType>()) 553 return rankedMemRefTy.getMemorySpaceAsInt(); 554 return cast<UnrankedMemRefType>().getMemorySpaceAsInt(); 555 } 556 557 //===----------------------------------------------------------------------===// 558 // MemRefType 559 //===----------------------------------------------------------------------===// 560 561 /// Given an `originalShape` and a `reducedShape` assumed to be a subset of 562 /// `originalShape` with some `1` entries erased, return the set of indices 563 /// that specifies which of the entries of `originalShape` are dropped to obtain 564 /// `reducedShape`. The returned mask can be applied as a projection to 565 /// `originalShape` to obtain the `reducedShape`. This mask is useful to track 566 /// which dimensions must be kept when e.g. compute MemRef strides under 567 /// rank-reducing operations. Return None if reducedShape cannot be obtained 568 /// by dropping only `1` entries in `originalShape`. 569 llvm::Optional<llvm::SmallDenseSet<unsigned>> 570 mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape, 571 ArrayRef<int64_t> reducedShape) { 572 size_t originalRank = originalShape.size(), reducedRank = reducedShape.size(); 573 llvm::SmallDenseSet<unsigned> unusedDims; 574 unsigned reducedIdx = 0; 575 for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) { 576 // Greedily insert `originalIdx` if match. 577 if (reducedIdx < reducedRank && 578 originalShape[originalIdx] == reducedShape[reducedIdx]) { 579 reducedIdx++; 580 continue; 581 } 582 583 unusedDims.insert(originalIdx); 584 // If no match on `originalIdx`, the `originalShape` at this dimension 585 // must be 1, otherwise we bail. 
586 if (originalShape[originalIdx] != 1) 587 return llvm::None; 588 } 589 // The whole reducedShape must be scanned, otherwise we bail. 590 if (reducedIdx != reducedRank) 591 return llvm::None; 592 return unusedDims; 593 } 594 595 SliceVerificationResult 596 mlir::isRankReducedType(ShapedType originalType, 597 ShapedType candidateReducedType) { 598 if (originalType == candidateReducedType) 599 return SliceVerificationResult::Success; 600 601 ShapedType originalShapedType = originalType.cast<ShapedType>(); 602 ShapedType candidateReducedShapedType = 603 candidateReducedType.cast<ShapedType>(); 604 605 // Rank and size logic is valid for all ShapedTypes. 606 ArrayRef<int64_t> originalShape = originalShapedType.getShape(); 607 ArrayRef<int64_t> candidateReducedShape = 608 candidateReducedShapedType.getShape(); 609 unsigned originalRank = originalShape.size(), 610 candidateReducedRank = candidateReducedShape.size(); 611 if (candidateReducedRank > originalRank) 612 return SliceVerificationResult::RankTooLarge; 613 614 auto optionalUnusedDimsMask = 615 computeRankReductionMask(originalShape, candidateReducedShape); 616 617 // Sizes cannot be matched in case empty vector is returned. 618 if (!optionalUnusedDimsMask.hasValue()) 619 return SliceVerificationResult::SizeMismatch; 620 621 if (originalShapedType.getElementType() != 622 candidateReducedShapedType.getElementType()) 623 return SliceVerificationResult::ElemTypeMismatch; 624 625 return SliceVerificationResult::Success; 626 } 627 628 bool mlir::detail::isSupportedMemorySpace(Attribute memorySpace) { 629 // Empty attribute is allowed as default memory space. 630 if (!memorySpace) 631 return true; 632 633 // Supported built-in attributes. 634 if (memorySpace.isa<IntegerAttr, StringAttr, DictionaryAttr>()) 635 return true; 636 637 // Allow custom dialect attributes. 
638 if (!isa<BuiltinDialect>(memorySpace.getDialect())) 639 return true; 640 641 return false; 642 } 643 644 Attribute mlir::detail::wrapIntegerMemorySpace(unsigned memorySpace, 645 MLIRContext *ctx) { 646 if (memorySpace == 0) 647 return nullptr; 648 649 return IntegerAttr::get(IntegerType::get(ctx, 64), memorySpace); 650 } 651 652 Attribute mlir::detail::skipDefaultMemorySpace(Attribute memorySpace) { 653 IntegerAttr intMemorySpace = memorySpace.dyn_cast_or_null<IntegerAttr>(); 654 if (intMemorySpace && intMemorySpace.getValue() == 0) 655 return nullptr; 656 657 return memorySpace; 658 } 659 660 unsigned mlir::detail::getMemorySpaceAsInt(Attribute memorySpace) { 661 if (!memorySpace) 662 return 0; 663 664 assert(memorySpace.isa<IntegerAttr>() && 665 "Using `getMemorySpaceInteger` with non-Integer attribute"); 666 667 return static_cast<unsigned>(memorySpace.cast<IntegerAttr>().getInt()); 668 } 669 670 MemRefType::Builder & 671 MemRefType::Builder::setMemorySpace(unsigned newMemorySpace) { 672 memorySpace = 673 wrapIntegerMemorySpace(newMemorySpace, elementType.getContext()); 674 return *this; 675 } 676 677 unsigned MemRefType::getMemorySpaceAsInt() const { 678 return detail::getMemorySpaceAsInt(getMemorySpace()); 679 } 680 681 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType, 682 MemRefLayoutAttrInterface layout, 683 Attribute memorySpace) { 684 // Use default layout for empty attribute. 685 if (!layout) 686 layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap( 687 shape.size(), elementType.getContext())); 688 689 // Drop default memory space value and replace it with empty attribute. 
690 memorySpace = skipDefaultMemorySpace(memorySpace); 691 692 return Base::get(elementType.getContext(), shape, elementType, layout, 693 memorySpace); 694 } 695 696 MemRefType MemRefType::getChecked( 697 function_ref<InFlightDiagnostic()> emitErrorFn, ArrayRef<int64_t> shape, 698 Type elementType, MemRefLayoutAttrInterface layout, Attribute memorySpace) { 699 700 // Use default layout for empty attribute. 701 if (!layout) 702 layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap( 703 shape.size(), elementType.getContext())); 704 705 // Drop default memory space value and replace it with empty attribute. 706 memorySpace = skipDefaultMemorySpace(memorySpace); 707 708 return Base::getChecked(emitErrorFn, elementType.getContext(), shape, 709 elementType, layout, memorySpace); 710 } 711 712 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType, 713 AffineMap map, Attribute memorySpace) { 714 715 // Use default layout for empty map. 716 if (!map) 717 map = AffineMap::getMultiDimIdentityMap(shape.size(), 718 elementType.getContext()); 719 720 // Wrap AffineMap into Attribute. 721 Attribute layout = AffineMapAttr::get(map); 722 723 // Drop default memory space value and replace it with empty attribute. 724 memorySpace = skipDefaultMemorySpace(memorySpace); 725 726 return Base::get(elementType.getContext(), shape, elementType, layout, 727 memorySpace); 728 } 729 730 MemRefType 731 MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn, 732 ArrayRef<int64_t> shape, Type elementType, AffineMap map, 733 Attribute memorySpace) { 734 735 // Use default layout for empty map. 736 if (!map) 737 map = AffineMap::getMultiDimIdentityMap(shape.size(), 738 elementType.getContext()); 739 740 // Wrap AffineMap into Attribute. 741 Attribute layout = AffineMapAttr::get(map); 742 743 // Drop default memory space value and replace it with empty attribute. 
744 memorySpace = skipDefaultMemorySpace(memorySpace); 745 746 return Base::getChecked(emitErrorFn, elementType.getContext(), shape, 747 elementType, layout, memorySpace); 748 } 749 750 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType, 751 AffineMap map, unsigned memorySpaceInd) { 752 753 // Use default layout for empty map. 754 if (!map) 755 map = AffineMap::getMultiDimIdentityMap(shape.size(), 756 elementType.getContext()); 757 758 // Wrap AffineMap into Attribute. 759 Attribute layout = AffineMapAttr::get(map); 760 761 // Convert deprecated integer-like memory space to Attribute. 762 Attribute memorySpace = 763 wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext()); 764 765 return Base::get(elementType.getContext(), shape, elementType, layout, 766 memorySpace); 767 } 768 769 MemRefType 770 MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn, 771 ArrayRef<int64_t> shape, Type elementType, AffineMap map, 772 unsigned memorySpaceInd) { 773 774 // Use default layout for empty map. 775 if (!map) 776 map = AffineMap::getMultiDimIdentityMap(shape.size(), 777 elementType.getContext()); 778 779 // Wrap AffineMap into Attribute. 780 Attribute layout = AffineMapAttr::get(map); 781 782 // Convert deprecated integer-like memory space to Attribute. 783 Attribute memorySpace = 784 wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext()); 785 786 return Base::getChecked(emitErrorFn, elementType.getContext(), shape, 787 elementType, layout, memorySpace); 788 } 789 790 LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError, 791 ArrayRef<int64_t> shape, Type elementType, 792 MemRefLayoutAttrInterface layout, 793 Attribute memorySpace) { 794 if (!BaseMemRefType::isValidElementType(elementType)) 795 return emitError() << "invalid memref element type"; 796 797 // Negative sizes are not allowed except for `-1` that means dynamic size. 
798 for (int64_t s : shape) 799 if (s < -1) 800 return emitError() << "invalid memref size"; 801 802 assert(layout && "missing layout specification"); 803 if (failed(layout.verifyLayout(shape, emitError))) 804 return failure(); 805 806 if (!isSupportedMemorySpace(memorySpace)) 807 return emitError() << "unsupported memory space Attribute"; 808 809 return success(); 810 } 811 812 void MemRefType::walkImmediateSubElements( 813 function_ref<void(Attribute)> walkAttrsFn, 814 function_ref<void(Type)> walkTypesFn) const { 815 walkTypesFn(getElementType()); 816 if (!getLayout().isIdentity()) 817 walkAttrsFn(getLayout()); 818 walkAttrsFn(getMemorySpace()); 819 } 820 821 //===----------------------------------------------------------------------===// 822 // UnrankedMemRefType 823 //===----------------------------------------------------------------------===// 824 825 unsigned UnrankedMemRefType::getMemorySpaceAsInt() const { 826 return detail::getMemorySpaceAsInt(getMemorySpace()); 827 } 828 829 LogicalResult 830 UnrankedMemRefType::verify(function_ref<InFlightDiagnostic()> emitError, 831 Type elementType, Attribute memorySpace) { 832 if (!BaseMemRefType::isValidElementType(elementType)) 833 return emitError() << "invalid memref element type"; 834 835 if (!isSupportedMemorySpace(memorySpace)) 836 return emitError() << "unsupported memory space Attribute"; 837 838 return success(); 839 } 840 841 // Fallback cases for terminal dim/sym/cst that are not part of a binary op ( 842 // i.e. single term). Accumulate the AffineExpr into the existing one. 
843 static void extractStridesFromTerm(AffineExpr e, 844 AffineExpr multiplicativeFactor, 845 MutableArrayRef<AffineExpr> strides, 846 AffineExpr &offset) { 847 if (auto dim = e.dyn_cast<AffineDimExpr>()) 848 strides[dim.getPosition()] = 849 strides[dim.getPosition()] + multiplicativeFactor; 850 else 851 offset = offset + e * multiplicativeFactor; 852 } 853 854 /// Takes a single AffineExpr `e` and populates the `strides` array with the 855 /// strides expressions for each dim position. 856 /// The convention is that the strides for dimensions d0, .. dn appear in 857 /// order to make indexing intuitive into the result. 858 static LogicalResult extractStrides(AffineExpr e, 859 AffineExpr multiplicativeFactor, 860 MutableArrayRef<AffineExpr> strides, 861 AffineExpr &offset) { 862 auto bin = e.dyn_cast<AffineBinaryOpExpr>(); 863 if (!bin) { 864 extractStridesFromTerm(e, multiplicativeFactor, strides, offset); 865 return success(); 866 } 867 868 if (bin.getKind() == AffineExprKind::CeilDiv || 869 bin.getKind() == AffineExprKind::FloorDiv || 870 bin.getKind() == AffineExprKind::Mod) 871 return failure(); 872 873 if (bin.getKind() == AffineExprKind::Mul) { 874 auto dim = bin.getLHS().dyn_cast<AffineDimExpr>(); 875 if (dim) { 876 strides[dim.getPosition()] = 877 strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor; 878 return success(); 879 } 880 // LHS and RHS may both contain complex expressions of dims. Try one path 881 // and if it fails try the other. This is guaranteed to succeed because 882 // only one path may have a `dim`, otherwise this is not an AffineExpr in 883 // the first place. 
884 if (bin.getLHS().isSymbolicOrConstant()) 885 return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(), 886 strides, offset); 887 return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(), 888 strides, offset); 889 } 890 891 if (bin.getKind() == AffineExprKind::Add) { 892 auto res1 = 893 extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset); 894 auto res2 = 895 extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset); 896 return success(succeeded(res1) && succeeded(res2)); 897 } 898 899 llvm_unreachable("unexpected binary operation"); 900 } 901 902 LogicalResult mlir::getStridesAndOffset(MemRefType t, 903 SmallVectorImpl<AffineExpr> &strides, 904 AffineExpr &offset) { 905 AffineMap m = t.getLayout().getAffineMap(); 906 907 if (m.getNumResults() != 1 && !m.isIdentity()) 908 return failure(); 909 910 auto zero = getAffineConstantExpr(0, t.getContext()); 911 auto one = getAffineConstantExpr(1, t.getContext()); 912 offset = zero; 913 strides.assign(t.getRank(), zero); 914 915 // Canonical case for empty map. 916 if (m.isIdentity()) { 917 // 0-D corner case, offset is already 0. 918 if (t.getRank() == 0) 919 return success(); 920 auto stridedExpr = 921 makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext()); 922 if (succeeded(extractStrides(stridedExpr, one, strides, offset))) 923 return success(); 924 assert(false && "unexpected failure: extract strides in canonical layout"); 925 } 926 927 // Non-canonical case requires more work. 928 auto stridedExpr = 929 simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols()); 930 if (failed(extractStrides(stridedExpr, one, strides, offset))) { 931 offset = AffineExpr(); 932 strides.clear(); 933 return failure(); 934 } 935 936 // Simplify results to allow folding to constants and simple checks. 
937 unsigned numDims = m.getNumDims(); 938 unsigned numSymbols = m.getNumSymbols(); 939 offset = simplifyAffineExpr(offset, numDims, numSymbols); 940 for (auto &stride : strides) 941 stride = simplifyAffineExpr(stride, numDims, numSymbols); 942 943 /// In practice, a strided memref must be internally non-aliasing. Test 944 /// against 0 as a proxy. 945 /// TODO: static cases can have more advanced checks. 946 /// TODO: dynamic cases would require a way to compare symbolic 947 /// expressions and would probably need an affine set context propagated 948 /// everywhere. 949 if (llvm::any_of(strides, [](AffineExpr e) { 950 return e == getAffineConstantExpr(0, e.getContext()); 951 })) { 952 offset = AffineExpr(); 953 strides.clear(); 954 return failure(); 955 } 956 957 return success(); 958 } 959 960 LogicalResult mlir::getStridesAndOffset(MemRefType t, 961 SmallVectorImpl<int64_t> &strides, 962 int64_t &offset) { 963 AffineExpr offsetExpr; 964 SmallVector<AffineExpr, 4> strideExprs; 965 if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr))) 966 return failure(); 967 if (auto cst = offsetExpr.dyn_cast<AffineConstantExpr>()) 968 offset = cst.getValue(); 969 else 970 offset = ShapedType::kDynamicStrideOrOffset; 971 for (auto e : strideExprs) { 972 if (auto c = e.dyn_cast<AffineConstantExpr>()) 973 strides.push_back(c.getValue()); 974 else 975 strides.push_back(ShapedType::kDynamicStrideOrOffset); 976 } 977 return success(); 978 } 979 980 void UnrankedMemRefType::walkImmediateSubElements( 981 function_ref<void(Attribute)> walkAttrsFn, 982 function_ref<void(Type)> walkTypesFn) const { 983 walkTypesFn(getElementType()); 984 walkAttrsFn(getMemorySpace()); 985 } 986 987 //===----------------------------------------------------------------------===// 988 /// TupleType 989 //===----------------------------------------------------------------------===// 990 991 /// Return the elements types for this tuple. 
992 ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); } 993 994 /// Accumulate the types contained in this tuple and tuples nested within it. 995 /// Note that this only flattens nested tuples, not any other container type, 996 /// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to 997 /// (i32, tensor<i32>, f32, i64) 998 void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) { 999 for (Type type : getTypes()) { 1000 if (auto nestedTuple = type.dyn_cast<TupleType>()) 1001 nestedTuple.getFlattenedTypes(types); 1002 else 1003 types.push_back(type); 1004 } 1005 } 1006 1007 /// Return the number of element types. 1008 size_t TupleType::size() const { return getImpl()->size(); } 1009 1010 void TupleType::walkImmediateSubElements( 1011 function_ref<void(Attribute)> walkAttrsFn, 1012 function_ref<void(Type)> walkTypesFn) const { 1013 for (Type type : getTypes()) 1014 walkTypesFn(type); 1015 } 1016 1017 //===----------------------------------------------------------------------===// 1018 // Type Utilities 1019 //===----------------------------------------------------------------------===// 1020 1021 AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides, 1022 int64_t offset, 1023 MLIRContext *context) { 1024 AffineExpr expr; 1025 unsigned nSymbols = 0; 1026 1027 // AffineExpr for offset. 1028 // Static case. 1029 if (offset != MemRefType::getDynamicStrideOrOffset()) { 1030 auto cst = getAffineConstantExpr(offset, context); 1031 expr = cst; 1032 } else { 1033 // Dynamic case, new symbol for the offset. 1034 auto sym = getAffineSymbolExpr(nSymbols++, context); 1035 expr = sym; 1036 } 1037 1038 // AffineExpr for strides. 1039 for (const auto &en : llvm::enumerate(strides)) { 1040 auto dim = en.index(); 1041 auto stride = en.value(); 1042 assert(stride != 0 && "Invalid stride specification"); 1043 auto d = getAffineDimExpr(dim, context); 1044 AffineExpr mult; 1045 // Static case. 
1046 if (stride != MemRefType::getDynamicStrideOrOffset()) 1047 mult = getAffineConstantExpr(stride, context); 1048 else 1049 // Dynamic case, new symbol for each new stride. 1050 mult = getAffineSymbolExpr(nSymbols++, context); 1051 expr = expr + d * mult; 1052 } 1053 1054 return AffineMap::get(strides.size(), nSymbols, expr); 1055 } 1056 1057 /// Return a version of `t` with identity layout if it can be determined 1058 /// statically that the layout is the canonical contiguous strided layout. 1059 /// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of 1060 /// `t` with simplified layout. 1061 /// If `t` has multiple layout maps or a multi-result layout, just return `t`. 1062 MemRefType mlir::canonicalizeStridedLayout(MemRefType t) { 1063 AffineMap m = t.getLayout().getAffineMap(); 1064 1065 // Already in canonical form. 1066 if (m.isIdentity()) 1067 return t; 1068 1069 // Can't reduce to canonical identity form, return in canonical form. 1070 if (m.getNumResults() > 1) 1071 return t; 1072 1073 // Corner-case for 0-D affine maps. 1074 if (m.getNumDims() == 0 && m.getNumSymbols() == 0) { 1075 if (auto cst = m.getResult(0).dyn_cast<AffineConstantExpr>()) 1076 if (cst.getValue() == 0) 1077 return MemRefType::Builder(t).setLayout({}); 1078 return t; 1079 } 1080 1081 // 0-D corner case for empty shape that still have an affine map. Example: 1082 // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a 1 element memref whose 1083 // offset needs to remain, just return t. 1084 if (t.getShape().empty()) 1085 return t; 1086 1087 // If the canonical strided layout for the sizes of `t` is equal to the 1088 // simplified layout of `t` we can just return an empty layout. Otherwise, 1089 // just simplify the existing layout. 
1090 AffineExpr expr = 1091 makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext()); 1092 auto simplifiedLayoutExpr = 1093 simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols()); 1094 if (expr != simplifiedLayoutExpr) 1095 return MemRefType::Builder(t).setLayout(AffineMapAttr::get(AffineMap::get( 1096 m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr))); 1097 return MemRefType::Builder(t).setLayout({}); 1098 } 1099 1100 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes, 1101 ArrayRef<AffineExpr> exprs, 1102 MLIRContext *context) { 1103 assert(!sizes.empty() && !exprs.empty() && 1104 "expected non-empty sizes and exprs"); 1105 1106 // Size 0 corner case is useful for canonicalizations. 1107 if (llvm::is_contained(sizes, 0)) 1108 return getAffineConstantExpr(0, context); 1109 1110 auto maps = AffineMap::inferFromExprList(exprs); 1111 assert(!maps.empty() && "Expected one non-empty map"); 1112 unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols(); 1113 1114 AffineExpr expr; 1115 bool dynamicPoisonBit = false; 1116 int64_t runningSize = 1; 1117 for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) { 1118 int64_t size = std::get<1>(en); 1119 // Degenerate case, no size =-> no stride 1120 if (size == 0) 1121 continue; 1122 AffineExpr dimExpr = std::get<0>(en); 1123 AffineExpr stride = dynamicPoisonBit 1124 ? getAffineSymbolExpr(nSymbols++, context) 1125 : getAffineConstantExpr(runningSize, context); 1126 expr = expr ? expr + dimExpr * stride : dimExpr * stride; 1127 if (size > 0) { 1128 runningSize *= size; 1129 assert(runningSize > 0 && "integer overflow in size computation"); 1130 } else { 1131 dynamicPoisonBit = true; 1132 } 1133 } 1134 return simplifyAffineExpr(expr, numDims, nSymbols); 1135 } 1136 1137 /// Return a version of `t` with a layout that has all dynamic offset and 1138 /// strides. This is used to erase the static layout. 
1139 MemRefType mlir::eraseStridedLayout(MemRefType t) { 1140 auto val = ShapedType::kDynamicStrideOrOffset; 1141 return MemRefType::Builder(t).setLayout( 1142 AffineMapAttr::get(makeStridedLinearLayoutMap( 1143 SmallVector<int64_t, 4>(t.getRank(), val), val, t.getContext()))); 1144 } 1145 1146 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes, 1147 MLIRContext *context) { 1148 SmallVector<AffineExpr, 4> exprs; 1149 exprs.reserve(sizes.size()); 1150 for (auto dim : llvm::seq<unsigned>(0, sizes.size())) 1151 exprs.push_back(getAffineDimExpr(dim, context)); 1152 return makeCanonicalStridedLayoutExpr(sizes, exprs, context); 1153 } 1154 1155 /// Return true if the layout for `t` is compatible with strided semantics. 1156 bool mlir::isStrided(MemRefType t) { 1157 int64_t offset; 1158 SmallVector<int64_t, 4> strides; 1159 auto res = getStridesAndOffset(t, strides, offset); 1160 return succeeded(res); 1161 } 1162 1163 /// Return the layout map in strided linear layout AffineMap form. 1164 /// Return null if the layout is not compatible with a strided layout. 1165 AffineMap mlir::getStridedLinearLayoutMap(MemRefType t) { 1166 int64_t offset; 1167 SmallVector<int64_t, 4> strides; 1168 if (failed(getStridesAndOffset(t, strides, offset))) 1169 return AffineMap(); 1170 return makeStridedLinearLayoutMap(strides, offset, t.getContext()); 1171 } 1172 1173 /// Return the AffineExpr representation of the offset, assuming `memRefType` 1174 /// is a strided memref. 1175 static AffineExpr getOffsetExpr(MemRefType memrefType) { 1176 SmallVector<AffineExpr> strides; 1177 AffineExpr offset; 1178 if (failed(getStridesAndOffset(memrefType, strides, offset))) 1179 assert(false && "expected strided memref"); 1180 return offset; 1181 } 1182 1183 /// Helper to construct a contiguous MemRefType of `shape`, `elementType` and 1184 /// `offset` AffineExpr. 
1185 static MemRefType makeContiguousRowMajorMemRefType(MLIRContext *context, 1186 ArrayRef<int64_t> shape, 1187 Type elementType, 1188 AffineExpr offset) { 1189 AffineExpr canonical = makeCanonicalStridedLayoutExpr(shape, context); 1190 AffineExpr contiguousRowMajor = canonical + offset; 1191 AffineMap contiguousRowMajorMap = 1192 AffineMap::inferFromExprList({contiguousRowMajor})[0]; 1193 return MemRefType::get(shape, elementType, contiguousRowMajorMap); 1194 } 1195 1196 /// Helper determining if a memref is static-shape and contiguous-row-major 1197 /// layout, while still allowing for an arbitrary offset (any static or 1198 /// dynamic value). 1199 bool mlir::isStaticShapeAndContiguousRowMajor(MemRefType memrefType) { 1200 if (!memrefType.hasStaticShape()) 1201 return false; 1202 AffineExpr offset = getOffsetExpr(memrefType); 1203 MemRefType contiguousRowMajorMemRefType = makeContiguousRowMajorMemRefType( 1204 memrefType.getContext(), memrefType.getShape(), 1205 memrefType.getElementType(), offset); 1206 return canonicalizeStridedLayout(memrefType) == 1207 canonicalizeStridedLayout(contiguousRowMajorMemRefType); 1208 } 1209