//===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/BuiltinTypes.h"
#include "TypeDetail.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TensorEncoding.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/Twine.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::detail;

//===----------------------------------------------------------------------===//
/// Tablegen Type Definitions
//===----------------------------------------------------------------------===//

// Expand the TableGen-generated class definitions of the builtin types.
#define GET_TYPEDEF_CLASSES
#include "mlir/IR/BuiltinTypes.cpp.inc"

//===----------------------------------------------------------------------===//
// BuiltinDialect
//===----------------------------------------------------------------------===//

/// Register all builtin types with the builtin dialect. The template argument
/// list of `addTypes` is produced by re-including the generated file with
/// GET_TYPEDEF_LIST defined.
void BuiltinDialect::registerTypes() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/IR/BuiltinTypes.cpp.inc"
      >();
}

//===----------------------------------------------------------------------===//
/// ComplexType
//===----------------------------------------------------------------------===//

/// Verify the construction of a complex type: its element type must be an
/// integer or floating-point type.
51 LogicalResult ComplexType::verify(function_ref<InFlightDiagnostic()> emitError, 52 Type elementType) { 53 if (!elementType.isIntOrFloat()) 54 return emitError() << "invalid element type for complex"; 55 return success(); 56 } 57 58 //===----------------------------------------------------------------------===// 59 // Integer Type 60 //===----------------------------------------------------------------------===// 61 62 // static constexpr must have a definition (until in C++17 and inline variable). 63 constexpr unsigned IntegerType::kMaxWidth; 64 65 /// Verify the construction of an integer type. 66 LogicalResult IntegerType::verify(function_ref<InFlightDiagnostic()> emitError, 67 unsigned width, 68 SignednessSemantics signedness) { 69 if (width > IntegerType::kMaxWidth) { 70 return emitError() << "integer bitwidth is limited to " 71 << IntegerType::kMaxWidth << " bits"; 72 } 73 return success(); 74 } 75 76 unsigned IntegerType::getWidth() const { return getImpl()->width; } 77 78 IntegerType::SignednessSemantics IntegerType::getSignedness() const { 79 return getImpl()->signedness; 80 } 81 82 IntegerType IntegerType::scaleElementBitwidth(unsigned scale) { 83 if (!scale) 84 return IntegerType(); 85 return IntegerType::get(getContext(), scale * getWidth(), getSignedness()); 86 } 87 88 //===----------------------------------------------------------------------===// 89 // Float Type 90 //===----------------------------------------------------------------------===// 91 92 unsigned FloatType::getWidth() { 93 if (isa<Float16Type, BFloat16Type>()) 94 return 16; 95 if (isa<Float32Type>()) 96 return 32; 97 if (isa<Float64Type>()) 98 return 64; 99 if (isa<Float80Type>()) 100 return 80; 101 if (isa<Float128Type>()) 102 return 128; 103 llvm_unreachable("unexpected float type"); 104 } 105 106 /// Returns the floating semantics for the given type. 
107 const llvm::fltSemantics &FloatType::getFloatSemantics() { 108 if (isa<BFloat16Type>()) 109 return APFloat::BFloat(); 110 if (isa<Float16Type>()) 111 return APFloat::IEEEhalf(); 112 if (isa<Float32Type>()) 113 return APFloat::IEEEsingle(); 114 if (isa<Float64Type>()) 115 return APFloat::IEEEdouble(); 116 if (isa<Float80Type>()) 117 return APFloat::x87DoubleExtended(); 118 if (isa<Float128Type>()) 119 return APFloat::IEEEquad(); 120 llvm_unreachable("non-floating point type used"); 121 } 122 123 FloatType FloatType::scaleElementBitwidth(unsigned scale) { 124 if (!scale) 125 return FloatType(); 126 MLIRContext *ctx = getContext(); 127 if (isF16() || isBF16()) { 128 if (scale == 2) 129 return FloatType::getF32(ctx); 130 if (scale == 4) 131 return FloatType::getF64(ctx); 132 } 133 if (isF32()) 134 if (scale == 2) 135 return FloatType::getF64(ctx); 136 return FloatType(); 137 } 138 139 //===----------------------------------------------------------------------===// 140 // FunctionType 141 //===----------------------------------------------------------------------===// 142 143 unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; } 144 145 ArrayRef<Type> FunctionType::getInputs() const { 146 return getImpl()->getInputs(); 147 } 148 149 unsigned FunctionType::getNumResults() const { return getImpl()->numResults; } 150 151 ArrayRef<Type> FunctionType::getResults() const { 152 return getImpl()->getResults(); 153 } 154 155 /// Helper to call a callback once on each index in the range 156 /// [0, `totalIndices`), *except* for the indices given in `indices`. 157 /// `indices` is allowed to have duplicates and can be in any order. 
158 inline void iterateIndicesExcept(unsigned totalIndices, 159 ArrayRef<unsigned> indices, 160 function_ref<void(unsigned)> callback) { 161 llvm::BitVector skipIndices(totalIndices); 162 for (unsigned i : indices) 163 skipIndices.set(i); 164 165 for (unsigned i = 0; i < totalIndices; ++i) 166 if (!skipIndices.test(i)) 167 callback(i); 168 } 169 170 /// Returns a new function type with the specified arguments and results 171 /// inserted. 172 FunctionType FunctionType::getWithArgsAndResults( 173 ArrayRef<unsigned> argIndices, TypeRange argTypes, 174 ArrayRef<unsigned> resultIndices, TypeRange resultTypes) { 175 assert(argIndices.size() == argTypes.size()); 176 assert(resultIndices.size() == resultTypes.size()); 177 178 ArrayRef<Type> newInputTypes = getInputs(); 179 SmallVector<Type, 4> newInputTypesBuffer; 180 if (!argIndices.empty()) { 181 const auto *fromIt = newInputTypes.begin(); 182 for (auto it : llvm::zip(argIndices, argTypes)) { 183 const auto *toIt = newInputTypes.begin() + std::get<0>(it); 184 newInputTypesBuffer.append(fromIt, toIt); 185 newInputTypesBuffer.push_back(std::get<1>(it)); 186 fromIt = toIt; 187 } 188 newInputTypesBuffer.append(fromIt, newInputTypes.end()); 189 newInputTypes = newInputTypesBuffer; 190 } 191 192 ArrayRef<Type> newResultTypes = getResults(); 193 SmallVector<Type, 4> newResultTypesBuffer; 194 if (!resultIndices.empty()) { 195 const auto *fromIt = newResultTypes.begin(); 196 for (auto it : llvm::zip(resultIndices, resultTypes)) { 197 const auto *toIt = newResultTypes.begin() + std::get<0>(it); 198 newResultTypesBuffer.append(fromIt, toIt); 199 newResultTypesBuffer.push_back(std::get<1>(it)); 200 fromIt = toIt; 201 } 202 newResultTypesBuffer.append(fromIt, newResultTypes.end()); 203 newResultTypes = newResultTypesBuffer; 204 } 205 206 return FunctionType::get(getContext(), newInputTypes, newResultTypes); 207 } 208 209 /// Returns a new function type without the specified arguments and results. 
210 FunctionType 211 FunctionType::getWithoutArgsAndResults(ArrayRef<unsigned> argIndices, 212 ArrayRef<unsigned> resultIndices) { 213 ArrayRef<Type> newInputTypes = getInputs(); 214 SmallVector<Type, 4> newInputTypesBuffer; 215 if (!argIndices.empty()) { 216 unsigned originalNumArgs = getNumInputs(); 217 iterateIndicesExcept(originalNumArgs, argIndices, [&](unsigned i) { 218 newInputTypesBuffer.emplace_back(getInput(i)); 219 }); 220 newInputTypes = newInputTypesBuffer; 221 } 222 223 ArrayRef<Type> newResultTypes = getResults(); 224 SmallVector<Type, 4> newResultTypesBuffer; 225 if (!resultIndices.empty()) { 226 unsigned originalNumResults = getNumResults(); 227 iterateIndicesExcept(originalNumResults, resultIndices, [&](unsigned i) { 228 newResultTypesBuffer.emplace_back(getResult(i)); 229 }); 230 newResultTypes = newResultTypesBuffer; 231 } 232 233 return get(getContext(), newInputTypes, newResultTypes); 234 } 235 236 void FunctionType::walkImmediateSubElements( 237 function_ref<void(Attribute)> walkAttrsFn, 238 function_ref<void(Type)> walkTypesFn) const { 239 for (Type type : llvm::concat<const Type>(getInputs(), getResults())) 240 walkTypesFn(type); 241 } 242 243 //===----------------------------------------------------------------------===// 244 // OpaqueType 245 //===----------------------------------------------------------------------===// 246 247 /// Verify the construction of an opaque type. 248 LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError, 249 StringAttr dialect, StringRef typeData) { 250 if (!Dialect::isValidNamespace(dialect.strref())) 251 return emitError() << "invalid dialect namespace '" << dialect << "'"; 252 253 // Check that the dialect is actually registered. 254 MLIRContext *context = dialect.getContext(); 255 if (!context->allowsUnregisteredDialects() && 256 !context->getLoadedDialect(dialect.strref())) { 257 return emitError() 258 << "`!" 
<< dialect << "<\"" << typeData << "\">" 259 << "` type created with unregistered dialect. If this is " 260 "intended, please call allowUnregisteredDialects() on the " 261 "MLIRContext, or use -allow-unregistered-dialect with " 262 "the MLIR opt tool used"; 263 } 264 265 return success(); 266 } 267 268 //===----------------------------------------------------------------------===// 269 // VectorType 270 //===----------------------------------------------------------------------===// 271 272 LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError, 273 ArrayRef<int64_t> shape, Type elementType, 274 unsigned numScalableDims) { 275 if (!isValidElementType(elementType)) 276 return emitError() 277 << "vector elements must be int/index/float type but got " 278 << elementType; 279 280 if (any_of(shape, [](int64_t i) { return i <= 0; })) 281 return emitError() 282 << "vector types must have positive constant sizes but got " 283 << shape; 284 285 return success(); 286 } 287 288 VectorType VectorType::scaleElementBitwidth(unsigned scale) { 289 if (!scale) 290 return VectorType(); 291 if (auto et = getElementType().dyn_cast<IntegerType>()) 292 if (auto scaledEt = et.scaleElementBitwidth(scale)) 293 return VectorType::get(getShape(), scaledEt, getNumScalableDims()); 294 if (auto et = getElementType().dyn_cast<FloatType>()) 295 if (auto scaledEt = et.scaleElementBitwidth(scale)) 296 return VectorType::get(getShape(), scaledEt, getNumScalableDims()); 297 return VectorType(); 298 } 299 300 void VectorType::walkImmediateSubElements( 301 function_ref<void(Attribute)> walkAttrsFn, 302 function_ref<void(Type)> walkTypesFn) const { 303 walkTypesFn(getElementType()); 304 } 305 306 VectorType VectorType::cloneWith(Optional<ArrayRef<int64_t>> shape, 307 Type elementType) const { 308 return VectorType::get(shape.getValueOr(getShape()), elementType, 309 getNumScalableDims()); 310 } 311 312 //===----------------------------------------------------------------------===// 
313 // TensorType 314 //===----------------------------------------------------------------------===// 315 316 Type TensorType::getElementType() const { 317 return llvm::TypeSwitch<TensorType, Type>(*this) 318 .Case<RankedTensorType, UnrankedTensorType>( 319 [](auto type) { return type.getElementType(); }); 320 } 321 322 bool TensorType::hasRank() const { return !isa<UnrankedTensorType>(); } 323 324 ArrayRef<int64_t> TensorType::getShape() const { 325 return cast<RankedTensorType>().getShape(); 326 } 327 328 TensorType TensorType::cloneWith(Optional<ArrayRef<int64_t>> shape, 329 Type elementType) const { 330 if (auto unrankedTy = dyn_cast<UnrankedTensorType>()) { 331 if (shape) 332 return RankedTensorType::get(*shape, elementType); 333 return UnrankedTensorType::get(elementType); 334 } 335 336 auto rankedTy = cast<RankedTensorType>(); 337 if (!shape) 338 return RankedTensorType::get(rankedTy.getShape(), elementType, 339 rankedTy.getEncoding()); 340 return RankedTensorType::get(shape.getValueOr(rankedTy.getShape()), 341 elementType, rankedTy.getEncoding()); 342 } 343 344 // Check if "elementType" can be an element type of a tensor. 345 static LogicalResult 346 checkTensorElementType(function_ref<InFlightDiagnostic()> emitError, 347 Type elementType) { 348 if (!TensorType::isValidElementType(elementType)) 349 return emitError() << "invalid tensor element type: " << elementType; 350 return success(); 351 } 352 353 /// Return true if the specified element type is ok in a tensor. 354 bool TensorType::isValidElementType(Type type) { 355 // Note: Non standard/builtin types are allowed to exist within tensor 356 // types. Dialects are expected to verify that tensor types have a valid 357 // element type within that dialect. 
358 return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType, 359 IndexType>() || 360 !llvm::isa<BuiltinDialect>(type.getDialect()); 361 } 362 363 //===----------------------------------------------------------------------===// 364 // RankedTensorType 365 //===----------------------------------------------------------------------===// 366 367 LogicalResult 368 RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError, 369 ArrayRef<int64_t> shape, Type elementType, 370 Attribute encoding) { 371 for (int64_t s : shape) 372 if (s < -1) 373 return emitError() << "invalid tensor dimension size"; 374 if (auto v = encoding.dyn_cast_or_null<VerifiableTensorEncoding>()) 375 if (failed(v.verifyEncoding(shape, elementType, emitError))) 376 return failure(); 377 return checkTensorElementType(emitError, elementType); 378 } 379 380 void RankedTensorType::walkImmediateSubElements( 381 function_ref<void(Attribute)> walkAttrsFn, 382 function_ref<void(Type)> walkTypesFn) const { 383 walkTypesFn(getElementType()); 384 if (Attribute encoding = getEncoding()) 385 walkAttrsFn(encoding); 386 } 387 388 //===----------------------------------------------------------------------===// 389 // UnrankedTensorType 390 //===----------------------------------------------------------------------===// 391 392 LogicalResult 393 UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError, 394 Type elementType) { 395 return checkTensorElementType(emitError, elementType); 396 } 397 398 void UnrankedTensorType::walkImmediateSubElements( 399 function_ref<void(Attribute)> walkAttrsFn, 400 function_ref<void(Type)> walkTypesFn) const { 401 walkTypesFn(getElementType()); 402 } 403 404 //===----------------------------------------------------------------------===// 405 // BaseMemRefType 406 //===----------------------------------------------------------------------===// 407 408 Type BaseMemRefType::getElementType() const { 409 return llvm::TypeSwitch<BaseMemRefType, 
Type>(*this) 410 .Case<MemRefType, UnrankedMemRefType>( 411 [](auto type) { return type.getElementType(); }); 412 } 413 414 bool BaseMemRefType::hasRank() const { return !isa<UnrankedMemRefType>(); } 415 416 ArrayRef<int64_t> BaseMemRefType::getShape() const { 417 return cast<MemRefType>().getShape(); 418 } 419 420 BaseMemRefType BaseMemRefType::cloneWith(Optional<ArrayRef<int64_t>> shape, 421 Type elementType) const { 422 if (auto unrankedTy = dyn_cast<UnrankedMemRefType>()) { 423 if (!shape) 424 return UnrankedMemRefType::get(elementType, getMemorySpace()); 425 MemRefType::Builder builder(*shape, elementType); 426 builder.setMemorySpace(getMemorySpace()); 427 return builder; 428 } 429 430 MemRefType::Builder builder(cast<MemRefType>()); 431 if (shape) 432 builder.setShape(*shape); 433 builder.setElementType(elementType); 434 return builder; 435 } 436 437 Attribute BaseMemRefType::getMemorySpace() const { 438 if (auto rankedMemRefTy = dyn_cast<MemRefType>()) 439 return rankedMemRefTy.getMemorySpace(); 440 return cast<UnrankedMemRefType>().getMemorySpace(); 441 } 442 443 unsigned BaseMemRefType::getMemorySpaceAsInt() const { 444 if (auto rankedMemRefTy = dyn_cast<MemRefType>()) 445 return rankedMemRefTy.getMemorySpaceAsInt(); 446 return cast<UnrankedMemRefType>().getMemorySpaceAsInt(); 447 } 448 449 //===----------------------------------------------------------------------===// 450 // MemRefType 451 //===----------------------------------------------------------------------===// 452 453 /// Given an `originalShape` and a `reducedShape` assumed to be a subset of 454 /// `originalShape` with some `1` entries erased, return the set of indices 455 /// that specifies which of the entries of `originalShape` are dropped to obtain 456 /// `reducedShape`. The returned mask can be applied as a projection to 457 /// `originalShape` to obtain the `reducedShape`. This mask is useful to track 458 /// which dimensions must be kept when e.g. 
compute MemRef strides under 459 /// rank-reducing operations. Return None if reducedShape cannot be obtained 460 /// by dropping only `1` entries in `originalShape`. 461 llvm::Optional<llvm::SmallDenseSet<unsigned>> 462 mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape, 463 ArrayRef<int64_t> reducedShape) { 464 size_t originalRank = originalShape.size(), reducedRank = reducedShape.size(); 465 llvm::SmallDenseSet<unsigned> unusedDims; 466 unsigned reducedIdx = 0; 467 for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) { 468 // Greedily insert `originalIdx` if match. 469 if (reducedIdx < reducedRank && 470 originalShape[originalIdx] == reducedShape[reducedIdx]) { 471 reducedIdx++; 472 continue; 473 } 474 475 unusedDims.insert(originalIdx); 476 // If no match on `originalIdx`, the `originalShape` at this dimension 477 // must be 1, otherwise we bail. 478 if (originalShape[originalIdx] != 1) 479 return llvm::None; 480 } 481 // The whole reducedShape must be scanned, otherwise we bail. 482 if (reducedIdx != reducedRank) 483 return llvm::None; 484 return unusedDims; 485 } 486 487 SliceVerificationResult 488 mlir::isRankReducedType(ShapedType originalType, 489 ShapedType candidateReducedType) { 490 if (originalType == candidateReducedType) 491 return SliceVerificationResult::Success; 492 493 ShapedType originalShapedType = originalType.cast<ShapedType>(); 494 ShapedType candidateReducedShapedType = 495 candidateReducedType.cast<ShapedType>(); 496 497 // Rank and size logic is valid for all ShapedTypes. 
498 ArrayRef<int64_t> originalShape = originalShapedType.getShape(); 499 ArrayRef<int64_t> candidateReducedShape = 500 candidateReducedShapedType.getShape(); 501 unsigned originalRank = originalShape.size(), 502 candidateReducedRank = candidateReducedShape.size(); 503 if (candidateReducedRank > originalRank) 504 return SliceVerificationResult::RankTooLarge; 505 506 auto optionalUnusedDimsMask = 507 computeRankReductionMask(originalShape, candidateReducedShape); 508 509 // Sizes cannot be matched in case empty vector is returned. 510 if (!optionalUnusedDimsMask.hasValue()) 511 return SliceVerificationResult::SizeMismatch; 512 513 if (originalShapedType.getElementType() != 514 candidateReducedShapedType.getElementType()) 515 return SliceVerificationResult::ElemTypeMismatch; 516 517 return SliceVerificationResult::Success; 518 } 519 520 bool mlir::detail::isSupportedMemorySpace(Attribute memorySpace) { 521 // Empty attribute is allowed as default memory space. 522 if (!memorySpace) 523 return true; 524 525 // Supported built-in attributes. 526 if (memorySpace.isa<IntegerAttr, StringAttr, DictionaryAttr>()) 527 return true; 528 529 // Allow custom dialect attributes. 
530 if (!isa<BuiltinDialect>(memorySpace.getDialect())) 531 return true; 532 533 return false; 534 } 535 536 Attribute mlir::detail::wrapIntegerMemorySpace(unsigned memorySpace, 537 MLIRContext *ctx) { 538 if (memorySpace == 0) 539 return nullptr; 540 541 return IntegerAttr::get(IntegerType::get(ctx, 64), memorySpace); 542 } 543 544 Attribute mlir::detail::skipDefaultMemorySpace(Attribute memorySpace) { 545 IntegerAttr intMemorySpace = memorySpace.dyn_cast_or_null<IntegerAttr>(); 546 if (intMemorySpace && intMemorySpace.getValue() == 0) 547 return nullptr; 548 549 return memorySpace; 550 } 551 552 unsigned mlir::detail::getMemorySpaceAsInt(Attribute memorySpace) { 553 if (!memorySpace) 554 return 0; 555 556 assert(memorySpace.isa<IntegerAttr>() && 557 "Using `getMemorySpaceInteger` with non-Integer attribute"); 558 559 return static_cast<unsigned>(memorySpace.cast<IntegerAttr>().getInt()); 560 } 561 562 MemRefType::Builder & 563 MemRefType::Builder::setMemorySpace(unsigned newMemorySpace) { 564 memorySpace = 565 wrapIntegerMemorySpace(newMemorySpace, elementType.getContext()); 566 return *this; 567 } 568 569 unsigned MemRefType::getMemorySpaceAsInt() const { 570 return detail::getMemorySpaceAsInt(getMemorySpace()); 571 } 572 573 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType, 574 MemRefLayoutAttrInterface layout, 575 Attribute memorySpace) { 576 // Use default layout for empty attribute. 577 if (!layout) 578 layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap( 579 shape.size(), elementType.getContext())); 580 581 // Drop default memory space value and replace it with empty attribute. 
582 memorySpace = skipDefaultMemorySpace(memorySpace); 583 584 return Base::get(elementType.getContext(), shape, elementType, layout, 585 memorySpace); 586 } 587 588 MemRefType MemRefType::getChecked( 589 function_ref<InFlightDiagnostic()> emitErrorFn, ArrayRef<int64_t> shape, 590 Type elementType, MemRefLayoutAttrInterface layout, Attribute memorySpace) { 591 592 // Use default layout for empty attribute. 593 if (!layout) 594 layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap( 595 shape.size(), elementType.getContext())); 596 597 // Drop default memory space value and replace it with empty attribute. 598 memorySpace = skipDefaultMemorySpace(memorySpace); 599 600 return Base::getChecked(emitErrorFn, elementType.getContext(), shape, 601 elementType, layout, memorySpace); 602 } 603 604 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType, 605 AffineMap map, Attribute memorySpace) { 606 607 // Use default layout for empty map. 608 if (!map) 609 map = AffineMap::getMultiDimIdentityMap(shape.size(), 610 elementType.getContext()); 611 612 // Wrap AffineMap into Attribute. 613 Attribute layout = AffineMapAttr::get(map); 614 615 // Drop default memory space value and replace it with empty attribute. 616 memorySpace = skipDefaultMemorySpace(memorySpace); 617 618 return Base::get(elementType.getContext(), shape, elementType, layout, 619 memorySpace); 620 } 621 622 MemRefType 623 MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn, 624 ArrayRef<int64_t> shape, Type elementType, AffineMap map, 625 Attribute memorySpace) { 626 627 // Use default layout for empty map. 628 if (!map) 629 map = AffineMap::getMultiDimIdentityMap(shape.size(), 630 elementType.getContext()); 631 632 // Wrap AffineMap into Attribute. 633 Attribute layout = AffineMapAttr::get(map); 634 635 // Drop default memory space value and replace it with empty attribute. 
636 memorySpace = skipDefaultMemorySpace(memorySpace); 637 638 return Base::getChecked(emitErrorFn, elementType.getContext(), shape, 639 elementType, layout, memorySpace); 640 } 641 642 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType, 643 AffineMap map, unsigned memorySpaceInd) { 644 645 // Use default layout for empty map. 646 if (!map) 647 map = AffineMap::getMultiDimIdentityMap(shape.size(), 648 elementType.getContext()); 649 650 // Wrap AffineMap into Attribute. 651 Attribute layout = AffineMapAttr::get(map); 652 653 // Convert deprecated integer-like memory space to Attribute. 654 Attribute memorySpace = 655 wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext()); 656 657 return Base::get(elementType.getContext(), shape, elementType, layout, 658 memorySpace); 659 } 660 661 MemRefType 662 MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn, 663 ArrayRef<int64_t> shape, Type elementType, AffineMap map, 664 unsigned memorySpaceInd) { 665 666 // Use default layout for empty map. 667 if (!map) 668 map = AffineMap::getMultiDimIdentityMap(shape.size(), 669 elementType.getContext()); 670 671 // Wrap AffineMap into Attribute. 672 Attribute layout = AffineMapAttr::get(map); 673 674 // Convert deprecated integer-like memory space to Attribute. 675 Attribute memorySpace = 676 wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext()); 677 678 return Base::getChecked(emitErrorFn, elementType.getContext(), shape, 679 elementType, layout, memorySpace); 680 } 681 682 LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError, 683 ArrayRef<int64_t> shape, Type elementType, 684 MemRefLayoutAttrInterface layout, 685 Attribute memorySpace) { 686 if (!BaseMemRefType::isValidElementType(elementType)) 687 return emitError() << "invalid memref element type"; 688 689 // Negative sizes are not allowed except for `-1` that means dynamic size. 
690 for (int64_t s : shape) 691 if (s < -1) 692 return emitError() << "invalid memref size"; 693 694 assert(layout && "missing layout specification"); 695 if (failed(layout.verifyLayout(shape, emitError))) 696 return failure(); 697 698 if (!isSupportedMemorySpace(memorySpace)) 699 return emitError() << "unsupported memory space Attribute"; 700 701 return success(); 702 } 703 704 void MemRefType::walkImmediateSubElements( 705 function_ref<void(Attribute)> walkAttrsFn, 706 function_ref<void(Type)> walkTypesFn) const { 707 walkTypesFn(getElementType()); 708 if (!getLayout().isIdentity()) 709 walkAttrsFn(getLayout()); 710 walkAttrsFn(getMemorySpace()); 711 } 712 713 //===----------------------------------------------------------------------===// 714 // UnrankedMemRefType 715 //===----------------------------------------------------------------------===// 716 717 unsigned UnrankedMemRefType::getMemorySpaceAsInt() const { 718 return detail::getMemorySpaceAsInt(getMemorySpace()); 719 } 720 721 LogicalResult 722 UnrankedMemRefType::verify(function_ref<InFlightDiagnostic()> emitError, 723 Type elementType, Attribute memorySpace) { 724 if (!BaseMemRefType::isValidElementType(elementType)) 725 return emitError() << "invalid memref element type"; 726 727 if (!isSupportedMemorySpace(memorySpace)) 728 return emitError() << "unsupported memory space Attribute"; 729 730 return success(); 731 } 732 733 // Fallback cases for terminal dim/sym/cst that are not part of a binary op ( 734 // i.e. single term). Accumulate the AffineExpr into the existing one. 
735 static void extractStridesFromTerm(AffineExpr e, 736 AffineExpr multiplicativeFactor, 737 MutableArrayRef<AffineExpr> strides, 738 AffineExpr &offset) { 739 if (auto dim = e.dyn_cast<AffineDimExpr>()) 740 strides[dim.getPosition()] = 741 strides[dim.getPosition()] + multiplicativeFactor; 742 else 743 offset = offset + e * multiplicativeFactor; 744 } 745 746 /// Takes a single AffineExpr `e` and populates the `strides` array with the 747 /// strides expressions for each dim position. 748 /// The convention is that the strides for dimensions d0, .. dn appear in 749 /// order to make indexing intuitive into the result. 750 static LogicalResult extractStrides(AffineExpr e, 751 AffineExpr multiplicativeFactor, 752 MutableArrayRef<AffineExpr> strides, 753 AffineExpr &offset) { 754 auto bin = e.dyn_cast<AffineBinaryOpExpr>(); 755 if (!bin) { 756 extractStridesFromTerm(e, multiplicativeFactor, strides, offset); 757 return success(); 758 } 759 760 if (bin.getKind() == AffineExprKind::CeilDiv || 761 bin.getKind() == AffineExprKind::FloorDiv || 762 bin.getKind() == AffineExprKind::Mod) 763 return failure(); 764 765 if (bin.getKind() == AffineExprKind::Mul) { 766 auto dim = bin.getLHS().dyn_cast<AffineDimExpr>(); 767 if (dim) { 768 strides[dim.getPosition()] = 769 strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor; 770 return success(); 771 } 772 // LHS and RHS may both contain complex expressions of dims. Try one path 773 // and if it fails try the other. This is guaranteed to succeed because 774 // only one path may have a `dim`, otherwise this is not an AffineExpr in 775 // the first place. 
776 if (bin.getLHS().isSymbolicOrConstant()) 777 return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(), 778 strides, offset); 779 return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(), 780 strides, offset); 781 } 782 783 if (bin.getKind() == AffineExprKind::Add) { 784 auto res1 = 785 extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset); 786 auto res2 = 787 extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset); 788 return success(succeeded(res1) && succeeded(res2)); 789 } 790 791 llvm_unreachable("unexpected binary operation"); 792 } 793 794 LogicalResult mlir::getStridesAndOffset(MemRefType t, 795 SmallVectorImpl<AffineExpr> &strides, 796 AffineExpr &offset) { 797 AffineMap m = t.getLayout().getAffineMap(); 798 799 if (m.getNumResults() != 1 && !m.isIdentity()) 800 return failure(); 801 802 auto zero = getAffineConstantExpr(0, t.getContext()); 803 auto one = getAffineConstantExpr(1, t.getContext()); 804 offset = zero; 805 strides.assign(t.getRank(), zero); 806 807 // Canonical case for empty map. 808 if (m.isIdentity()) { 809 // 0-D corner case, offset is already 0. 810 if (t.getRank() == 0) 811 return success(); 812 auto stridedExpr = 813 makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext()); 814 if (succeeded(extractStrides(stridedExpr, one, strides, offset))) 815 return success(); 816 assert(false && "unexpected failure: extract strides in canonical layout"); 817 } 818 819 // Non-canonical case requires more work. 820 auto stridedExpr = 821 simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols()); 822 if (failed(extractStrides(stridedExpr, one, strides, offset))) { 823 offset = AffineExpr(); 824 strides.clear(); 825 return failure(); 826 } 827 828 // Simplify results to allow folding to constants and simple checks. 
829 unsigned numDims = m.getNumDims(); 830 unsigned numSymbols = m.getNumSymbols(); 831 offset = simplifyAffineExpr(offset, numDims, numSymbols); 832 for (auto &stride : strides) 833 stride = simplifyAffineExpr(stride, numDims, numSymbols); 834 835 /// In practice, a strided memref must be internally non-aliasing. Test 836 /// against 0 as a proxy. 837 /// TODO: static cases can have more advanced checks. 838 /// TODO: dynamic cases would require a way to compare symbolic 839 /// expressions and would probably need an affine set context propagated 840 /// everywhere. 841 if (llvm::any_of(strides, [](AffineExpr e) { 842 return e == getAffineConstantExpr(0, e.getContext()); 843 })) { 844 offset = AffineExpr(); 845 strides.clear(); 846 return failure(); 847 } 848 849 return success(); 850 } 851 852 LogicalResult mlir::getStridesAndOffset(MemRefType t, 853 SmallVectorImpl<int64_t> &strides, 854 int64_t &offset) { 855 AffineExpr offsetExpr; 856 SmallVector<AffineExpr, 4> strideExprs; 857 if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr))) 858 return failure(); 859 if (auto cst = offsetExpr.dyn_cast<AffineConstantExpr>()) 860 offset = cst.getValue(); 861 else 862 offset = ShapedType::kDynamicStrideOrOffset; 863 for (auto e : strideExprs) { 864 if (auto c = e.dyn_cast<AffineConstantExpr>()) 865 strides.push_back(c.getValue()); 866 else 867 strides.push_back(ShapedType::kDynamicStrideOrOffset); 868 } 869 return success(); 870 } 871 872 void UnrankedMemRefType::walkImmediateSubElements( 873 function_ref<void(Attribute)> walkAttrsFn, 874 function_ref<void(Type)> walkTypesFn) const { 875 walkTypesFn(getElementType()); 876 walkAttrsFn(getMemorySpace()); 877 } 878 879 //===----------------------------------------------------------------------===// 880 /// TupleType 881 //===----------------------------------------------------------------------===// 882 883 /// Return the elements types for this tuple. 
884 ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); } 885 886 /// Accumulate the types contained in this tuple and tuples nested within it. 887 /// Note that this only flattens nested tuples, not any other container type, 888 /// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to 889 /// (i32, tensor<i32>, f32, i64) 890 void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) { 891 for (Type type : getTypes()) { 892 if (auto nestedTuple = type.dyn_cast<TupleType>()) 893 nestedTuple.getFlattenedTypes(types); 894 else 895 types.push_back(type); 896 } 897 } 898 899 /// Return the number of element types. 900 size_t TupleType::size() const { return getImpl()->size(); } 901 902 void TupleType::walkImmediateSubElements( 903 function_ref<void(Attribute)> walkAttrsFn, 904 function_ref<void(Type)> walkTypesFn) const { 905 for (Type type : getTypes()) 906 walkTypesFn(type); 907 } 908 909 //===----------------------------------------------------------------------===// 910 // Type Utilities 911 //===----------------------------------------------------------------------===// 912 913 AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides, 914 int64_t offset, 915 MLIRContext *context) { 916 AffineExpr expr; 917 unsigned nSymbols = 0; 918 919 // AffineExpr for offset. 920 // Static case. 921 if (offset != MemRefType::getDynamicStrideOrOffset()) { 922 auto cst = getAffineConstantExpr(offset, context); 923 expr = cst; 924 } else { 925 // Dynamic case, new symbol for the offset. 926 auto sym = getAffineSymbolExpr(nSymbols++, context); 927 expr = sym; 928 } 929 930 // AffineExpr for strides. 931 for (const auto &en : llvm::enumerate(strides)) { 932 auto dim = en.index(); 933 auto stride = en.value(); 934 assert(stride != 0 && "Invalid stride specification"); 935 auto d = getAffineDimExpr(dim, context); 936 AffineExpr mult; 937 // Static case. 
938 if (stride != MemRefType::getDynamicStrideOrOffset()) 939 mult = getAffineConstantExpr(stride, context); 940 else 941 // Dynamic case, new symbol for each new stride. 942 mult = getAffineSymbolExpr(nSymbols++, context); 943 expr = expr + d * mult; 944 } 945 946 return AffineMap::get(strides.size(), nSymbols, expr); 947 } 948 949 /// Return a version of `t` with identity layout if it can be determined 950 /// statically that the layout is the canonical contiguous strided layout. 951 /// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of 952 /// `t` with simplified layout. 953 /// If `t` has multiple layout maps or a multi-result layout, just return `t`. 954 MemRefType mlir::canonicalizeStridedLayout(MemRefType t) { 955 AffineMap m = t.getLayout().getAffineMap(); 956 957 // Already in canonical form. 958 if (m.isIdentity()) 959 return t; 960 961 // Can't reduce to canonical identity form, return in canonical form. 962 if (m.getNumResults() > 1) 963 return t; 964 965 // Corner-case for 0-D affine maps. 966 if (m.getNumDims() == 0 && m.getNumSymbols() == 0) { 967 if (auto cst = m.getResult(0).dyn_cast<AffineConstantExpr>()) 968 if (cst.getValue() == 0) 969 return MemRefType::Builder(t).setLayout({}); 970 return t; 971 } 972 973 // 0-D corner case for empty shape that still have an affine map. Example: 974 // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a 1 element memref whose 975 // offset needs to remain, just return t. 976 if (t.getShape().empty()) 977 return t; 978 979 // If the canonical strided layout for the sizes of `t` is equal to the 980 // simplified layout of `t` we can just return an empty layout. Otherwise, 981 // just simplify the existing layout. 
982 AffineExpr expr = 983 makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext()); 984 auto simplifiedLayoutExpr = 985 simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols()); 986 if (expr != simplifiedLayoutExpr) 987 return MemRefType::Builder(t).setLayout(AffineMapAttr::get(AffineMap::get( 988 m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr))); 989 return MemRefType::Builder(t).setLayout({}); 990 } 991 992 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes, 993 ArrayRef<AffineExpr> exprs, 994 MLIRContext *context) { 995 assert(!sizes.empty() && !exprs.empty() && 996 "expected non-empty sizes and exprs"); 997 998 // Size 0 corner case is useful for canonicalizations. 999 if (llvm::is_contained(sizes, 0)) 1000 return getAffineConstantExpr(0, context); 1001 1002 auto maps = AffineMap::inferFromExprList(exprs); 1003 assert(!maps.empty() && "Expected one non-empty map"); 1004 unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols(); 1005 1006 AffineExpr expr; 1007 bool dynamicPoisonBit = false; 1008 int64_t runningSize = 1; 1009 for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) { 1010 int64_t size = std::get<1>(en); 1011 // Degenerate case, no size =-> no stride 1012 if (size == 0) 1013 continue; 1014 AffineExpr dimExpr = std::get<0>(en); 1015 AffineExpr stride = dynamicPoisonBit 1016 ? getAffineSymbolExpr(nSymbols++, context) 1017 : getAffineConstantExpr(runningSize, context); 1018 expr = expr ? expr + dimExpr * stride : dimExpr * stride; 1019 if (size > 0) { 1020 runningSize *= size; 1021 assert(runningSize > 0 && "integer overflow in size computation"); 1022 } else { 1023 dynamicPoisonBit = true; 1024 } 1025 } 1026 return simplifyAffineExpr(expr, numDims, nSymbols); 1027 } 1028 1029 /// Return a version of `t` with a layout that has all dynamic offset and 1030 /// strides. This is used to erase the static layout. 
1031 MemRefType mlir::eraseStridedLayout(MemRefType t) { 1032 auto val = ShapedType::kDynamicStrideOrOffset; 1033 return MemRefType::Builder(t).setLayout( 1034 AffineMapAttr::get(makeStridedLinearLayoutMap( 1035 SmallVector<int64_t, 4>(t.getRank(), val), val, t.getContext()))); 1036 } 1037 1038 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes, 1039 MLIRContext *context) { 1040 SmallVector<AffineExpr, 4> exprs; 1041 exprs.reserve(sizes.size()); 1042 for (auto dim : llvm::seq<unsigned>(0, sizes.size())) 1043 exprs.push_back(getAffineDimExpr(dim, context)); 1044 return makeCanonicalStridedLayoutExpr(sizes, exprs, context); 1045 } 1046 1047 /// Return true if the layout for `t` is compatible with strided semantics. 1048 bool mlir::isStrided(MemRefType t) { 1049 int64_t offset; 1050 SmallVector<int64_t, 4> strides; 1051 auto res = getStridesAndOffset(t, strides, offset); 1052 return succeeded(res); 1053 } 1054 1055 /// Return the layout map in strided linear layout AffineMap form. 1056 /// Return null if the layout is not compatible with a strided layout. 1057 AffineMap mlir::getStridedLinearLayoutMap(MemRefType t) { 1058 int64_t offset; 1059 SmallVector<int64_t, 4> strides; 1060 if (failed(getStridesAndOffset(t, strides, offset))) 1061 return AffineMap(); 1062 return makeStridedLinearLayoutMap(strides, offset, t.getContext()); 1063 } 1064 1065 /// Return the AffineExpr representation of the offset, assuming `memRefType` 1066 /// is a strided memref. 1067 static AffineExpr getOffsetExpr(MemRefType memrefType) { 1068 SmallVector<AffineExpr> strides; 1069 AffineExpr offset; 1070 if (failed(getStridesAndOffset(memrefType, strides, offset))) 1071 assert(false && "expected strided memref"); 1072 return offset; 1073 } 1074 1075 /// Helper to construct a contiguous MemRefType of `shape`, `elementType` and 1076 /// `offset` AffineExpr. 
1077 static MemRefType makeContiguousRowMajorMemRefType(MLIRContext *context, 1078 ArrayRef<int64_t> shape, 1079 Type elementType, 1080 AffineExpr offset) { 1081 AffineExpr canonical = makeCanonicalStridedLayoutExpr(shape, context); 1082 AffineExpr contiguousRowMajor = canonical + offset; 1083 AffineMap contiguousRowMajorMap = 1084 AffineMap::inferFromExprList({contiguousRowMajor})[0]; 1085 return MemRefType::get(shape, elementType, contiguousRowMajorMap); 1086 } 1087 1088 /// Helper determining if a memref is static-shape and contiguous-row-major 1089 /// layout, while still allowing for an arbitrary offset (any static or 1090 /// dynamic value). 1091 bool mlir::isStaticShapeAndContiguousRowMajor(MemRefType memrefType) { 1092 if (!memrefType.hasStaticShape()) 1093 return false; 1094 AffineExpr offset = getOffsetExpr(memrefType); 1095 MemRefType contiguousRowMajorMemRefType = makeContiguousRowMajorMemRefType( 1096 memrefType.getContext(), memrefType.getShape(), 1097 memrefType.getElementType(), offset); 1098 return canonicalizeStridedLayout(memrefType) == 1099 canonicalizeStridedLayout(contiguousRowMajorMemRefType); 1100 } 1101