1 //===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/BuiltinTypes.h" 10 #include "TypeDetail.h" 11 #include "mlir/IR/AffineExpr.h" 12 #include "mlir/IR/AffineMap.h" 13 #include "mlir/IR/BuiltinAttributes.h" 14 #include "mlir/IR/BuiltinDialect.h" 15 #include "mlir/IR/Diagnostics.h" 16 #include "mlir/IR/Dialect.h" 17 #include "mlir/IR/FunctionInterfaces.h" 18 #include "mlir/IR/OpImplementation.h" 19 #include "mlir/IR/TensorEncoding.h" 20 #include "llvm/ADT/APFloat.h" 21 #include "llvm/ADT/BitVector.h" 22 #include "llvm/ADT/Sequence.h" 23 #include "llvm/ADT/Twine.h" 24 #include "llvm/ADT/TypeSwitch.h" 25 26 using namespace mlir; 27 using namespace mlir::detail; 28 29 //===----------------------------------------------------------------------===// 30 /// Tablegen Type Definitions 31 //===----------------------------------------------------------------------===// 32 33 #define GET_TYPEDEF_CLASSES 34 #include "mlir/IR/BuiltinTypes.cpp.inc" 35 36 //===----------------------------------------------------------------------===// 37 // BuiltinDialect 38 //===----------------------------------------------------------------------===// 39 40 void BuiltinDialect::registerTypes() { 41 addTypes< 42 #define GET_TYPEDEF_LIST 43 #include "mlir/IR/BuiltinTypes.cpp.inc" 44 >(); 45 } 46 47 //===----------------------------------------------------------------------===// 48 /// ComplexType 49 //===----------------------------------------------------------------------===// 50 51 /// Verify the construction of an integer type. 52 LogicalResult ComplexType::verify(function_ref<InFlightDiagnostic()> emitError, 53 Type elementType) { 54 if (!elementType.isIntOrFloat()) 55 return emitError() << "invalid element type for complex"; 56 return success(); 57 } 58 59 //===----------------------------------------------------------------------===// 60 // Integer Type 61 //===----------------------------------------------------------------------===// 62 63 // static constexpr must have a definition (until in C++17 and inline variable). 64 constexpr unsigned IntegerType::kMaxWidth; 65 66 /// Verify the construction of an integer type. 67 LogicalResult IntegerType::verify(function_ref<InFlightDiagnostic()> emitError, 68 unsigned width, 69 SignednessSemantics signedness) { 70 if (width > IntegerType::kMaxWidth) { 71 return emitError() << "integer bitwidth is limited to " 72 << IntegerType::kMaxWidth << " bits"; 73 } 74 return success(); 75 } 76 77 unsigned IntegerType::getWidth() const { return getImpl()->width; } 78 79 IntegerType::SignednessSemantics IntegerType::getSignedness() const { 80 return getImpl()->signedness; 81 } 82 83 IntegerType IntegerType::scaleElementBitwidth(unsigned scale) { 84 if (!scale) 85 return IntegerType(); 86 return IntegerType::get(getContext(), scale * getWidth(), getSignedness()); 87 } 88 89 //===----------------------------------------------------------------------===// 90 // Float Type 91 //===----------------------------------------------------------------------===// 92 93 unsigned FloatType::getWidth() { 94 if (isa<Float16Type, BFloat16Type>()) 95 return 16; 96 if (isa<Float32Type>()) 97 return 32; 98 if (isa<Float64Type>()) 99 return 64; 100 if (isa<Float80Type>()) 101 return 80; 102 if (isa<Float128Type>()) 103 return 128; 104 llvm_unreachable("unexpected float type"); 105 } 106 107 /// Returns the floating semantics for the given type. 108 const llvm::fltSemantics &FloatType::getFloatSemantics() { 109 if (isa<BFloat16Type>()) 110 return APFloat::BFloat(); 111 if (isa<Float16Type>()) 112 return APFloat::IEEEhalf(); 113 if (isa<Float32Type>()) 114 return APFloat::IEEEsingle(); 115 if (isa<Float64Type>()) 116 return APFloat::IEEEdouble(); 117 if (isa<Float80Type>()) 118 return APFloat::x87DoubleExtended(); 119 if (isa<Float128Type>()) 120 return APFloat::IEEEquad(); 121 llvm_unreachable("non-floating point type used"); 122 } 123 124 FloatType FloatType::scaleElementBitwidth(unsigned scale) { 125 if (!scale) 126 return FloatType(); 127 MLIRContext *ctx = getContext(); 128 if (isF16() || isBF16()) { 129 if (scale == 2) 130 return FloatType::getF32(ctx); 131 if (scale == 4) 132 return FloatType::getF64(ctx); 133 } 134 if (isF32()) 135 if (scale == 2) 136 return FloatType::getF64(ctx); 137 return FloatType(); 138 } 139 140 //===----------------------------------------------------------------------===// 141 // FunctionType 142 //===----------------------------------------------------------------------===// 143 144 unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; } 145 146 ArrayRef<Type> FunctionType::getInputs() const { 147 return getImpl()->getInputs(); 148 } 149 150 unsigned FunctionType::getNumResults() const { return getImpl()->numResults; } 151 152 ArrayRef<Type> FunctionType::getResults() const { 153 return getImpl()->getResults(); 154 } 155 156 FunctionType FunctionType::clone(TypeRange inputs, TypeRange results) const { 157 return get(getContext(), inputs, results); 158 } 159 160 /// Returns a new function type with the specified arguments and results 161 /// inserted. 162 FunctionType FunctionType::getWithArgsAndResults( 163 ArrayRef<unsigned> argIndices, TypeRange argTypes, 164 ArrayRef<unsigned> resultIndices, TypeRange resultTypes) { 165 SmallVector<Type> argStorage, resultStorage; 166 TypeRange newArgTypes = function_interface_impl::insertTypesInto( 167 getInputs(), argIndices, argTypes, argStorage); 168 TypeRange newResultTypes = function_interface_impl::insertTypesInto( 169 getResults(), resultIndices, resultTypes, resultStorage); 170 return clone(newArgTypes, newResultTypes); 171 } 172 173 /// Returns a new function type without the specified arguments and results. 174 FunctionType 175 FunctionType::getWithoutArgsAndResults(ArrayRef<unsigned> argIndices, 176 ArrayRef<unsigned> resultIndices) { 177 SmallVector<Type> argStorage, resultStorage; 178 TypeRange newArgTypes = function_interface_impl::filterTypesOut( 179 getInputs(), argIndices, argStorage); 180 TypeRange newResultTypes = function_interface_impl::filterTypesOut( 181 getResults(), resultIndices, resultStorage); 182 return clone(newArgTypes, newResultTypes); 183 } 184 185 void FunctionType::walkImmediateSubElements( 186 function_ref<void(Attribute)> walkAttrsFn, 187 function_ref<void(Type)> walkTypesFn) const { 188 for (Type type : llvm::concat<const Type>(getInputs(), getResults())) 189 walkTypesFn(type); 190 } 191 192 //===----------------------------------------------------------------------===// 193 // OpaqueType 194 //===----------------------------------------------------------------------===// 195 196 /// Verify the construction of an opaque type. 197 LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError, 198 StringAttr dialect, StringRef typeData) { 199 if (!Dialect::isValidNamespace(dialect.strref())) 200 return emitError() << "invalid dialect namespace '" << dialect << "'"; 201 202 // Check that the dialect is actually registered. 203 MLIRContext *context = dialect.getContext(); 204 if (!context->allowsUnregisteredDialects() && 205 !context->getLoadedDialect(dialect.strref())) { 206 return emitError() 207 << "`!" << dialect << "<\"" << typeData << "\">" 208 << "` type created with unregistered dialect. If this is " 209 "intended, please call allowUnregisteredDialects() on the " 210 "MLIRContext, or use -allow-unregistered-dialect with " 211 "the MLIR opt tool used"; 212 } 213 214 return success(); 215 } 216 217 //===----------------------------------------------------------------------===// 218 // VectorType 219 //===----------------------------------------------------------------------===// 220 221 LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError, 222 ArrayRef<int64_t> shape, Type elementType, 223 unsigned numScalableDims) { 224 if (!isValidElementType(elementType)) 225 return emitError() 226 << "vector elements must be int/index/float type but got " 227 << elementType; 228 229 if (any_of(shape, [](int64_t i) { return i <= 0; })) 230 return emitError() 231 << "vector types must have positive constant sizes but got " 232 << shape; 233 234 return success(); 235 } 236 237 VectorType VectorType::scaleElementBitwidth(unsigned scale) { 238 if (!scale) 239 return VectorType(); 240 if (auto et = getElementType().dyn_cast<IntegerType>()) 241 if (auto scaledEt = et.scaleElementBitwidth(scale)) 242 return VectorType::get(getShape(), scaledEt, getNumScalableDims()); 243 if (auto et = getElementType().dyn_cast<FloatType>()) 244 if (auto scaledEt = et.scaleElementBitwidth(scale)) 245 return VectorType::get(getShape(), scaledEt, getNumScalableDims()); 246 return VectorType(); 247 } 248 249 void VectorType::walkImmediateSubElements( 250 function_ref<void(Attribute)> walkAttrsFn, 251 function_ref<void(Type)> walkTypesFn) const { 252 walkTypesFn(getElementType()); 253 } 254 255 VectorType VectorType::cloneWith(Optional<ArrayRef<int64_t>> shape, 256 Type elementType) const { 257 return VectorType::get(shape.getValueOr(getShape()), elementType, 258 getNumScalableDims()); 259 } 260 261 //===----------------------------------------------------------------------===// 262 // TensorType 263 //===----------------------------------------------------------------------===// 264 265 Type TensorType::getElementType() const { 266 return llvm::TypeSwitch<TensorType, Type>(*this) 267 .Case<RankedTensorType, UnrankedTensorType>( 268 [](auto type) { return type.getElementType(); }); 269 } 270 271 bool TensorType::hasRank() const { return !isa<UnrankedTensorType>(); } 272 273 ArrayRef<int64_t> TensorType::getShape() const { 274 return cast<RankedTensorType>().getShape(); 275 } 276 277 TensorType TensorType::cloneWith(Optional<ArrayRef<int64_t>> shape, 278 Type elementType) const { 279 if (auto unrankedTy = dyn_cast<UnrankedTensorType>()) { 280 if (shape) 281 return RankedTensorType::get(*shape, elementType); 282 return UnrankedTensorType::get(elementType); 283 } 284 285 auto rankedTy = cast<RankedTensorType>(); 286 if (!shape) 287 return RankedTensorType::get(rankedTy.getShape(), elementType, 288 rankedTy.getEncoding()); 289 return RankedTensorType::get(shape.getValueOr(rankedTy.getShape()), 290 elementType, rankedTy.getEncoding()); 291 } 292 293 // Check if "elementType" can be an element type of a tensor. 294 static LogicalResult 295 checkTensorElementType(function_ref<InFlightDiagnostic()> emitError, 296 Type elementType) { 297 if (!TensorType::isValidElementType(elementType)) 298 return emitError() << "invalid tensor element type: " << elementType; 299 return success(); 300 } 301 302 /// Return true if the specified element type is ok in a tensor. 303 bool TensorType::isValidElementType(Type type) { 304 // Note: Non standard/builtin types are allowed to exist within tensor 305 // types. Dialects are expected to verify that tensor types have a valid 306 // element type within that dialect. 307 return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType, 308 IndexType>() || 309 !llvm::isa<BuiltinDialect>(type.getDialect()); 310 } 311 312 //===----------------------------------------------------------------------===// 313 // RankedTensorType 314 //===----------------------------------------------------------------------===// 315 316 LogicalResult 317 RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError, 318 ArrayRef<int64_t> shape, Type elementType, 319 Attribute encoding) { 320 for (int64_t s : shape) 321 if (s < -1) 322 return emitError() << "invalid tensor dimension size"; 323 if (auto v = encoding.dyn_cast_or_null<VerifiableTensorEncoding>()) 324 if (failed(v.verifyEncoding(shape, elementType, emitError))) 325 return failure(); 326 return checkTensorElementType(emitError, elementType); 327 } 328 329 void RankedTensorType::walkImmediateSubElements( 330 function_ref<void(Attribute)> walkAttrsFn, 331 function_ref<void(Type)> walkTypesFn) const { 332 walkTypesFn(getElementType()); 333 if (Attribute encoding = getEncoding()) 334 walkAttrsFn(encoding); 335 } 336 337 //===----------------------------------------------------------------------===// 338 // UnrankedTensorType 339 //===----------------------------------------------------------------------===// 340 341 LogicalResult 342 UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError, 343 Type elementType) { 344 return checkTensorElementType(emitError, elementType); 345 } 346 347 void UnrankedTensorType::walkImmediateSubElements( 348 function_ref<void(Attribute)> walkAttrsFn, 349 function_ref<void(Type)> walkTypesFn) const { 350 walkTypesFn(getElementType()); 351 } 352 353 //===----------------------------------------------------------------------===// 354 // BaseMemRefType 355 //===----------------------------------------------------------------------===// 356 357 Type BaseMemRefType::getElementType() const { 358 return llvm::TypeSwitch<BaseMemRefType, Type>(*this) 359 .Case<MemRefType, UnrankedMemRefType>( 360 [](auto type) { return type.getElementType(); }); 361 } 362 363 bool BaseMemRefType::hasRank() const { return !isa<UnrankedMemRefType>(); } 364 365 ArrayRef<int64_t> BaseMemRefType::getShape() const { 366 return cast<MemRefType>().getShape(); 367 } 368 369 BaseMemRefType BaseMemRefType::cloneWith(Optional<ArrayRef<int64_t>> shape, 370 Type elementType) const { 371 if (auto unrankedTy = dyn_cast<UnrankedMemRefType>()) { 372 if (!shape) 373 return UnrankedMemRefType::get(elementType, getMemorySpace()); 374 MemRefType::Builder builder(*shape, elementType); 375 builder.setMemorySpace(getMemorySpace()); 376 return builder; 377 } 378 379 MemRefType::Builder builder(cast<MemRefType>()); 380 if (shape) 381 builder.setShape(*shape); 382 builder.setElementType(elementType); 383 return builder; 384 } 385 386 Attribute BaseMemRefType::getMemorySpace() const { 387 if (auto rankedMemRefTy = dyn_cast<MemRefType>()) 388 return rankedMemRefTy.getMemorySpace(); 389 return cast<UnrankedMemRefType>().getMemorySpace(); 390 } 391 392 unsigned BaseMemRefType::getMemorySpaceAsInt() const { 393 if (auto rankedMemRefTy = dyn_cast<MemRefType>()) 394 return rankedMemRefTy.getMemorySpaceAsInt(); 395 return cast<UnrankedMemRefType>().getMemorySpaceAsInt(); 396 } 397 398 //===----------------------------------------------------------------------===// 399 // MemRefType 400 //===----------------------------------------------------------------------===// 401 402 /// Given an `originalShape` and a `reducedShape` assumed to be a subset of 403 /// `originalShape` with some `1` entries erased, return the set of indices 404 /// that specifies which of the entries of `originalShape` are dropped to obtain 405 /// `reducedShape`. The returned mask can be applied as a projection to 406 /// `originalShape` to obtain the `reducedShape`. This mask is useful to track 407 /// which dimensions must be kept when e.g. compute MemRef strides under 408 /// rank-reducing operations. Return None if reducedShape cannot be obtained 409 /// by dropping only `1` entries in `originalShape`. 410 llvm::Optional<llvm::SmallDenseSet<unsigned>> 411 mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape, 412 ArrayRef<int64_t> reducedShape) { 413 size_t originalRank = originalShape.size(), reducedRank = reducedShape.size(); 414 llvm::SmallDenseSet<unsigned> unusedDims; 415 unsigned reducedIdx = 0; 416 for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) { 417 // Greedily insert `originalIdx` if match. 418 if (reducedIdx < reducedRank && 419 originalShape[originalIdx] == reducedShape[reducedIdx]) { 420 reducedIdx++; 421 continue; 422 } 423 424 unusedDims.insert(originalIdx); 425 // If no match on `originalIdx`, the `originalShape` at this dimension 426 // must be 1, otherwise we bail. 427 if (originalShape[originalIdx] != 1) 428 return llvm::None; 429 } 430 // The whole reducedShape must be scanned, otherwise we bail. 431 if (reducedIdx != reducedRank) 432 return llvm::None; 433 return unusedDims; 434 } 435 436 SliceVerificationResult 437 mlir::isRankReducedType(ShapedType originalType, 438 ShapedType candidateReducedType) { 439 if (originalType == candidateReducedType) 440 return SliceVerificationResult::Success; 441 442 ShapedType originalShapedType = originalType.cast<ShapedType>(); 443 ShapedType candidateReducedShapedType = 444 candidateReducedType.cast<ShapedType>(); 445 446 // Rank and size logic is valid for all ShapedTypes. 447 ArrayRef<int64_t> originalShape = originalShapedType.getShape(); 448 ArrayRef<int64_t> candidateReducedShape = 449 candidateReducedShapedType.getShape(); 450 unsigned originalRank = originalShape.size(), 451 candidateReducedRank = candidateReducedShape.size(); 452 if (candidateReducedRank > originalRank) 453 return SliceVerificationResult::RankTooLarge; 454 455 auto optionalUnusedDimsMask = 456 computeRankReductionMask(originalShape, candidateReducedShape); 457 458 // Sizes cannot be matched in case empty vector is returned. 459 if (!optionalUnusedDimsMask.hasValue()) 460 return SliceVerificationResult::SizeMismatch; 461 462 if (originalShapedType.getElementType() != 463 candidateReducedShapedType.getElementType()) 464 return SliceVerificationResult::ElemTypeMismatch; 465 466 return SliceVerificationResult::Success; 467 } 468 469 bool mlir::detail::isSupportedMemorySpace(Attribute memorySpace) { 470 // Empty attribute is allowed as default memory space. 471 if (!memorySpace) 472 return true; 473 474 // Supported built-in attributes. 475 if (memorySpace.isa<IntegerAttr, StringAttr, DictionaryAttr>()) 476 return true; 477 478 // Allow custom dialect attributes. 479 if (!isa<BuiltinDialect>(memorySpace.getDialect())) 480 return true; 481 482 return false; 483 } 484 485 Attribute mlir::detail::wrapIntegerMemorySpace(unsigned memorySpace, 486 MLIRContext *ctx) { 487 if (memorySpace == 0) 488 return nullptr; 489 490 return IntegerAttr::get(IntegerType::get(ctx, 64), memorySpace); 491 } 492 493 Attribute mlir::detail::skipDefaultMemorySpace(Attribute memorySpace) { 494 IntegerAttr intMemorySpace = memorySpace.dyn_cast_or_null<IntegerAttr>(); 495 if (intMemorySpace && intMemorySpace.getValue() == 0) 496 return nullptr; 497 498 return memorySpace; 499 } 500 501 unsigned mlir::detail::getMemorySpaceAsInt(Attribute memorySpace) { 502 if (!memorySpace) 503 return 0; 504 505 assert(memorySpace.isa<IntegerAttr>() && 506 "Using `getMemorySpaceInteger` with non-Integer attribute"); 507 508 return static_cast<unsigned>(memorySpace.cast<IntegerAttr>().getInt()); 509 } 510 511 MemRefType::Builder & 512 MemRefType::Builder::setMemorySpace(unsigned newMemorySpace) { 513 memorySpace = 514 wrapIntegerMemorySpace(newMemorySpace, elementType.getContext()); 515 return *this; 516 } 517 518 unsigned MemRefType::getMemorySpaceAsInt() const { 519 return detail::getMemorySpaceAsInt(getMemorySpace()); 520 } 521 522 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType, 523 MemRefLayoutAttrInterface layout, 524 Attribute memorySpace) { 525 // Use default layout for empty attribute. 526 if (!layout) 527 layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap( 528 shape.size(), elementType.getContext())); 529 530 // Drop default memory space value and replace it with empty attribute. 531 memorySpace = skipDefaultMemorySpace(memorySpace); 532 533 return Base::get(elementType.getContext(), shape, elementType, layout, 534 memorySpace); 535 } 536 537 MemRefType MemRefType::getChecked( 538 function_ref<InFlightDiagnostic()> emitErrorFn, ArrayRef<int64_t> shape, 539 Type elementType, MemRefLayoutAttrInterface layout, Attribute memorySpace) { 540 541 // Use default layout for empty attribute. 542 if (!layout) 543 layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap( 544 shape.size(), elementType.getContext())); 545 546 // Drop default memory space value and replace it with empty attribute. 547 memorySpace = skipDefaultMemorySpace(memorySpace); 548 549 return Base::getChecked(emitErrorFn, elementType.getContext(), shape, 550 elementType, layout, memorySpace); 551 } 552 553 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType, 554 AffineMap map, Attribute memorySpace) { 555 556 // Use default layout for empty map. 557 if (!map) 558 map = AffineMap::getMultiDimIdentityMap(shape.size(), 559 elementType.getContext()); 560 561 // Wrap AffineMap into Attribute. 562 Attribute layout = AffineMapAttr::get(map); 563 564 // Drop default memory space value and replace it with empty attribute. 565 memorySpace = skipDefaultMemorySpace(memorySpace); 566 567 return Base::get(elementType.getContext(), shape, elementType, layout, 568 memorySpace); 569 } 570 571 MemRefType 572 MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn, 573 ArrayRef<int64_t> shape, Type elementType, AffineMap map, 574 Attribute memorySpace) { 575 576 // Use default layout for empty map. 577 if (!map) 578 map = AffineMap::getMultiDimIdentityMap(shape.size(), 579 elementType.getContext()); 580 581 // Wrap AffineMap into Attribute. 582 Attribute layout = AffineMapAttr::get(map); 583 584 // Drop default memory space value and replace it with empty attribute. 585 memorySpace = skipDefaultMemorySpace(memorySpace); 586 587 return Base::getChecked(emitErrorFn, elementType.getContext(), shape, 588 elementType, layout, memorySpace); 589 } 590 591 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType, 592 AffineMap map, unsigned memorySpaceInd) { 593 594 // Use default layout for empty map. 595 if (!map) 596 map = AffineMap::getMultiDimIdentityMap(shape.size(), 597 elementType.getContext()); 598 599 // Wrap AffineMap into Attribute. 600 Attribute layout = AffineMapAttr::get(map); 601 602 // Convert deprecated integer-like memory space to Attribute. 603 Attribute memorySpace = 604 wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext()); 605 606 return Base::get(elementType.getContext(), shape, elementType, layout, 607 memorySpace); 608 } 609 610 MemRefType 611 MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn, 612 ArrayRef<int64_t> shape, Type elementType, AffineMap map, 613 unsigned memorySpaceInd) { 614 615 // Use default layout for empty map. 616 if (!map) 617 map = AffineMap::getMultiDimIdentityMap(shape.size(), 618 elementType.getContext()); 619 620 // Wrap AffineMap into Attribute. 621 Attribute layout = AffineMapAttr::get(map); 622 623 // Convert deprecated integer-like memory space to Attribute. 624 Attribute memorySpace = 625 wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext()); 626 627 return Base::getChecked(emitErrorFn, elementType.getContext(), shape, 628 elementType, layout, memorySpace); 629 } 630 631 LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError, 632 ArrayRef<int64_t> shape, Type elementType, 633 MemRefLayoutAttrInterface layout, 634 Attribute memorySpace) { 635 if (!BaseMemRefType::isValidElementType(elementType)) 636 return emitError() << "invalid memref element type"; 637 638 // Negative sizes are not allowed except for `-1` that means dynamic size. 639 for (int64_t s : shape) 640 if (s < -1) 641 return emitError() << "invalid memref size"; 642 643 assert(layout && "missing layout specification"); 644 if (failed(layout.verifyLayout(shape, emitError))) 645 return failure(); 646 647 if (!isSupportedMemorySpace(memorySpace)) 648 return emitError() << "unsupported memory space Attribute"; 649 650 return success(); 651 } 652 653 void MemRefType::walkImmediateSubElements( 654 function_ref<void(Attribute)> walkAttrsFn, 655 function_ref<void(Type)> walkTypesFn) const { 656 walkTypesFn(getElementType()); 657 if (!getLayout().isIdentity()) 658 walkAttrsFn(getLayout()); 659 walkAttrsFn(getMemorySpace()); 660 } 661 662 //===----------------------------------------------------------------------===// 663 // UnrankedMemRefType 664 //===----------------------------------------------------------------------===// 665 666 unsigned UnrankedMemRefType::getMemorySpaceAsInt() const { 667 return detail::getMemorySpaceAsInt(getMemorySpace()); 668 } 669 670 LogicalResult 671 UnrankedMemRefType::verify(function_ref<InFlightDiagnostic()> emitError, 672 Type elementType, Attribute memorySpace) { 673 if (!BaseMemRefType::isValidElementType(elementType)) 674 return emitError() << "invalid memref element type"; 675 676 if (!isSupportedMemorySpace(memorySpace)) 677 return emitError() << "unsupported memory space Attribute"; 678 679 return success(); 680 } 681 682 // Fallback cases for terminal dim/sym/cst that are not part of a binary op ( 683 // i.e. single term). Accumulate the AffineExpr into the existing one. 684 static void extractStridesFromTerm(AffineExpr e, 685 AffineExpr multiplicativeFactor, 686 MutableArrayRef<AffineExpr> strides, 687 AffineExpr &offset) { 688 if (auto dim = e.dyn_cast<AffineDimExpr>()) 689 strides[dim.getPosition()] = 690 strides[dim.getPosition()] + multiplicativeFactor; 691 else 692 offset = offset + e * multiplicativeFactor; 693 } 694 695 /// Takes a single AffineExpr `e` and populates the `strides` array with the 696 /// strides expressions for each dim position. 697 /// The convention is that the strides for dimensions d0, .. dn appear in 698 /// order to make indexing intuitive into the result. 699 static LogicalResult extractStrides(AffineExpr e, 700 AffineExpr multiplicativeFactor, 701 MutableArrayRef<AffineExpr> strides, 702 AffineExpr &offset) { 703 auto bin = e.dyn_cast<AffineBinaryOpExpr>(); 704 if (!bin) { 705 extractStridesFromTerm(e, multiplicativeFactor, strides, offset); 706 return success(); 707 } 708 709 if (bin.getKind() == AffineExprKind::CeilDiv || 710 bin.getKind() == AffineExprKind::FloorDiv || 711 bin.getKind() == AffineExprKind::Mod) 712 return failure(); 713 714 if (bin.getKind() == AffineExprKind::Mul) { 715 auto dim = bin.getLHS().dyn_cast<AffineDimExpr>(); 716 if (dim) { 717 strides[dim.getPosition()] = 718 strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor; 719 return success(); 720 } 721 // LHS and RHS may both contain complex expressions of dims. Try one path 722 // and if it fails try the other. This is guaranteed to succeed because 723 // only one path may have a `dim`, otherwise this is not an AffineExpr in 724 // the first place. 725 if (bin.getLHS().isSymbolicOrConstant()) 726 return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(), 727 strides, offset); 728 return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(), 729 strides, offset); 730 } 731 732 if (bin.getKind() == AffineExprKind::Add) { 733 auto res1 = 734 extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset); 735 auto res2 = 736 extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset); 737 return success(succeeded(res1) && succeeded(res2)); 738 } 739 740 llvm_unreachable("unexpected binary operation"); 741 } 742 743 LogicalResult mlir::getStridesAndOffset(MemRefType t, 744 SmallVectorImpl<AffineExpr> &strides, 745 AffineExpr &offset) { 746 AffineMap m = t.getLayout().getAffineMap(); 747 748 if (m.getNumResults() != 1 && !m.isIdentity()) 749 return failure(); 750 751 auto zero = getAffineConstantExpr(0, t.getContext()); 752 auto one = getAffineConstantExpr(1, t.getContext()); 753 offset = zero; 754 strides.assign(t.getRank(), zero); 755 756 // Canonical case for empty map. 757 if (m.isIdentity()) { 758 // 0-D corner case, offset is already 0. 759 if (t.getRank() == 0) 760 return success(); 761 auto stridedExpr = 762 makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext()); 763 if (succeeded(extractStrides(stridedExpr, one, strides, offset))) 764 return success(); 765 assert(false && "unexpected failure: extract strides in canonical layout"); 766 } 767 768 // Non-canonical case requires more work. 769 auto stridedExpr = 770 simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols()); 771 if (failed(extractStrides(stridedExpr, one, strides, offset))) { 772 offset = AffineExpr(); 773 strides.clear(); 774 return failure(); 775 } 776 777 // Simplify results to allow folding to constants and simple checks. 778 unsigned numDims = m.getNumDims(); 779 unsigned numSymbols = m.getNumSymbols(); 780 offset = simplifyAffineExpr(offset, numDims, numSymbols); 781 for (auto &stride : strides) 782 stride = simplifyAffineExpr(stride, numDims, numSymbols); 783 784 /// In practice, a strided memref must be internally non-aliasing. Test 785 /// against 0 as a proxy. 786 /// TODO: static cases can have more advanced checks. 787 /// TODO: dynamic cases would require a way to compare symbolic 788 /// expressions and would probably need an affine set context propagated 789 /// everywhere. 790 if (llvm::any_of(strides, [](AffineExpr e) { 791 return e == getAffineConstantExpr(0, e.getContext()); 792 })) { 793 offset = AffineExpr(); 794 strides.clear(); 795 return failure(); 796 } 797 798 return success(); 799 } 800 801 LogicalResult mlir::getStridesAndOffset(MemRefType t, 802 SmallVectorImpl<int64_t> &strides, 803 int64_t &offset) { 804 AffineExpr offsetExpr; 805 SmallVector<AffineExpr, 4> strideExprs; 806 if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr))) 807 return failure(); 808 if (auto cst = offsetExpr.dyn_cast<AffineConstantExpr>()) 809 offset = cst.getValue(); 810 else 811 offset = ShapedType::kDynamicStrideOrOffset; 812 for (auto e : strideExprs) { 813 if (auto c = e.dyn_cast<AffineConstantExpr>()) 814 strides.push_back(c.getValue()); 815 else 816 strides.push_back(ShapedType::kDynamicStrideOrOffset); 817 } 818 return success(); 819 } 820 821 void UnrankedMemRefType::walkImmediateSubElements( 822 function_ref<void(Attribute)> walkAttrsFn, 823 function_ref<void(Type)> walkTypesFn) const { 824 walkTypesFn(getElementType()); 825 walkAttrsFn(getMemorySpace()); 826 } 827 828 //===----------------------------------------------------------------------===// 829 /// TupleType 830 //===----------------------------------------------------------------------===// 831 832 /// Return the elements types for this tuple. 833 ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); } 834 835 /// Accumulate the types contained in this tuple and tuples nested within it. 836 /// Note that this only flattens nested tuples, not any other container type, 837 /// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to 838 /// (i32, tensor<i32>, f32, i64) 839 void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) { 840 for (Type type : getTypes()) { 841 if (auto nestedTuple = type.dyn_cast<TupleType>()) 842 nestedTuple.getFlattenedTypes(types); 843 else 844 types.push_back(type); 845 } 846 } 847 848 /// Return the number of element types. 849 size_t TupleType::size() const { return getImpl()->size(); } 850 851 void TupleType::walkImmediateSubElements( 852 function_ref<void(Attribute)> walkAttrsFn, 853 function_ref<void(Type)> walkTypesFn) const { 854 for (Type type : getTypes()) 855 walkTypesFn(type); 856 } 857 858 //===----------------------------------------------------------------------===// 859 // Type Utilities 860 //===----------------------------------------------------------------------===// 861 862 AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides, 863 int64_t offset, 864 MLIRContext *context) { 865 AffineExpr expr; 866 unsigned nSymbols = 0; 867 868 // AffineExpr for offset. 869 // Static case. 870 if (offset != MemRefType::getDynamicStrideOrOffset()) { 871 auto cst = getAffineConstantExpr(offset, context); 872 expr = cst; 873 } else { 874 // Dynamic case, new symbol for the offset. 875 auto sym = getAffineSymbolExpr(nSymbols++, context); 876 expr = sym; 877 } 878 879 // AffineExpr for strides. 880 for (const auto &en : llvm::enumerate(strides)) { 881 auto dim = en.index(); 882 auto stride = en.value(); 883 assert(stride != 0 && "Invalid stride specification"); 884 auto d = getAffineDimExpr(dim, context); 885 AffineExpr mult; 886 // Static case. 887 if (stride != MemRefType::getDynamicStrideOrOffset()) 888 mult = getAffineConstantExpr(stride, context); 889 else 890 // Dynamic case, new symbol for each new stride. 891 mult = getAffineSymbolExpr(nSymbols++, context); 892 expr = expr + d * mult; 893 } 894 895 return AffineMap::get(strides.size(), nSymbols, expr); 896 } 897 898 /// Return a version of `t` with identity layout if it can be determined 899 /// statically that the layout is the canonical contiguous strided layout. 900 /// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of 901 /// `t` with simplified layout. 902 /// If `t` has multiple layout maps or a multi-result layout, just return `t`. 903 MemRefType mlir::canonicalizeStridedLayout(MemRefType t) { 904 AffineMap m = t.getLayout().getAffineMap(); 905 906 // Already in canonical form. 907 if (m.isIdentity()) 908 return t; 909 910 // Can't reduce to canonical identity form, return in canonical form. 911 if (m.getNumResults() > 1) 912 return t; 913 914 // Corner-case for 0-D affine maps. 915 if (m.getNumDims() == 0 && m.getNumSymbols() == 0) { 916 if (auto cst = m.getResult(0).dyn_cast<AffineConstantExpr>()) 917 if (cst.getValue() == 0) 918 return MemRefType::Builder(t).setLayout({}); 919 return t; 920 } 921 922 // 0-D corner case for empty shape that still have an affine map. Example: 923 // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a 1 element memref whose 924 // offset needs to remain, just return t. 925 if (t.getShape().empty()) 926 return t; 927 928 // If the canonical strided layout for the sizes of `t` is equal to the 929 // simplified layout of `t` we can just return an empty layout. Otherwise, 930 // just simplify the existing layout. 931 AffineExpr expr = 932 makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext()); 933 auto simplifiedLayoutExpr = 934 simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols()); 935 if (expr != simplifiedLayoutExpr) 936 return MemRefType::Builder(t).setLayout(AffineMapAttr::get(AffineMap::get( 937 m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr))); 938 return MemRefType::Builder(t).setLayout({}); 939 } 940 941 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes, 942 ArrayRef<AffineExpr> exprs, 943 MLIRContext *context) { 944 assert(!sizes.empty() && !exprs.empty() && 945 "expected non-empty sizes and exprs"); 946 947 // Size 0 corner case is useful for canonicalizations. 948 if (llvm::is_contained(sizes, 0)) 949 return getAffineConstantExpr(0, context); 950 951 auto maps = AffineMap::inferFromExprList(exprs); 952 assert(!maps.empty() && "Expected one non-empty map"); 953 unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols(); 954 955 AffineExpr expr; 956 bool dynamicPoisonBit = false; 957 int64_t runningSize = 1; 958 for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) { 959 int64_t size = std::get<1>(en); 960 // Degenerate case, no size =-> no stride 961 if (size == 0) 962 continue; 963 AffineExpr dimExpr = std::get<0>(en); 964 AffineExpr stride = dynamicPoisonBit 965 ? getAffineSymbolExpr(nSymbols++, context) 966 : getAffineConstantExpr(runningSize, context); 967 expr = expr ? expr + dimExpr * stride : dimExpr * stride; 968 if (size > 0) { 969 runningSize *= size; 970 assert(runningSize > 0 && "integer overflow in size computation"); 971 } else { 972 dynamicPoisonBit = true; 973 } 974 } 975 return simplifyAffineExpr(expr, numDims, nSymbols); 976 } 977 978 /// Return a version of `t` with a layout that has all dynamic offset and 979 /// strides. This is used to erase the static layout. 980 MemRefType mlir::eraseStridedLayout(MemRefType t) { 981 auto val = ShapedType::kDynamicStrideOrOffset; 982 return MemRefType::Builder(t).setLayout( 983 AffineMapAttr::get(makeStridedLinearLayoutMap( 984 SmallVector<int64_t, 4>(t.getRank(), val), val, t.getContext()))); 985 } 986 987 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes, 988 MLIRContext *context) { 989 SmallVector<AffineExpr, 4> exprs; 990 exprs.reserve(sizes.size()); 991 for (auto dim : llvm::seq<unsigned>(0, sizes.size())) 992 exprs.push_back(getAffineDimExpr(dim, context)); 993 return makeCanonicalStridedLayoutExpr(sizes, exprs, context); 994 } 995 996 /// Return true if the layout for `t` is compatible with strided semantics. 997 bool mlir::isStrided(MemRefType t) { 998 int64_t offset; 999 SmallVector<int64_t, 4> strides; 1000 auto res = getStridesAndOffset(t, strides, offset); 1001 return succeeded(res); 1002 } 1003 1004 /// Return the layout map in strided linear layout AffineMap form. 1005 /// Return null if the layout is not compatible with a strided layout. 1006 AffineMap mlir::getStridedLinearLayoutMap(MemRefType t) { 1007 int64_t offset; 1008 SmallVector<int64_t, 4> strides; 1009 if (failed(getStridesAndOffset(t, strides, offset))) 1010 return AffineMap(); 1011 return makeStridedLinearLayoutMap(strides, offset, t.getContext()); 1012 } 1013 1014 /// Return the AffineExpr representation of the offset, assuming `memRefType` 1015 /// is a strided memref. 1016 static AffineExpr getOffsetExpr(MemRefType memrefType) { 1017 SmallVector<AffineExpr> strides; 1018 AffineExpr offset; 1019 if (failed(getStridesAndOffset(memrefType, strides, offset))) 1020 assert(false && "expected strided memref"); 1021 return offset; 1022 } 1023 1024 /// Helper to construct a contiguous MemRefType of `shape`, `elementType` and 1025 /// `offset` AffineExpr. 1026 static MemRefType makeContiguousRowMajorMemRefType(MLIRContext *context, 1027 ArrayRef<int64_t> shape, 1028 Type elementType, 1029 AffineExpr offset) { 1030 AffineExpr canonical = makeCanonicalStridedLayoutExpr(shape, context); 1031 AffineExpr contiguousRowMajor = canonical + offset; 1032 AffineMap contiguousRowMajorMap = 1033 AffineMap::inferFromExprList({contiguousRowMajor})[0]; 1034 return MemRefType::get(shape, elementType, contiguousRowMajorMap); 1035 } 1036 1037 /// Helper determining if a memref is static-shape and contiguous-row-major 1038 /// layout, while still allowing for an arbitrary offset (any static or 1039 /// dynamic value). 1040 bool mlir::isStaticShapeAndContiguousRowMajor(MemRefType memrefType) { 1041 if (!memrefType.hasStaticShape()) 1042 return false; 1043 AffineExpr offset = getOffsetExpr(memrefType); 1044 MemRefType contiguousRowMajorMemRefType = makeContiguousRowMajorMemRefType( 1045 memrefType.getContext(), memrefType.getShape(), 1046 memrefType.getElementType(), offset); 1047 return canonicalizeStridedLayout(memrefType) == 1048 canonicalizeStridedLayout(contiguousRowMajorMemRefType); 1049 } 1050