//===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/BuiltinTypes.h"
#include "TypeDetail.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TensorEncoding.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/Twine.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::detail;

//===----------------------------------------------------------------------===//
/// Tablegen Type Definitions
//===----------------------------------------------------------------------===//

#define GET_TYPEDEF_CLASSES
#include "mlir/IR/BuiltinTypes.cpp.inc"

//===----------------------------------------------------------------------===//
// BuiltinDialect
//===----------------------------------------------------------------------===//

/// Register the builtin types with the builtin dialect. The template argument
/// list of `addTypes` is produced by the tablegen-generated include when
/// `GET_TYPEDEF_LIST` is defined (presumably one entry per generated type
/// class — see the generated BuiltinTypes.cpp.inc).
void BuiltinDialect::registerTypes() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/IR/BuiltinTypes.cpp.inc"
      >();
}

//===----------------------------------------------------------------------===//
/// ComplexType
//===----------------------------------------------------------------------===//

/// Verify the construction of a complex type: the element type must be an
/// integer or floating-point type.
52 LogicalResult ComplexType::verify(function_ref<InFlightDiagnostic()> emitError, 53 Type elementType) { 54 if (!elementType.isIntOrFloat()) 55 return emitError() << "invalid element type for complex"; 56 return success(); 57 } 58 59 //===----------------------------------------------------------------------===// 60 // Integer Type 61 //===----------------------------------------------------------------------===// 62 63 // static constexpr must have a definition (until in C++17 and inline variable). 64 constexpr unsigned IntegerType::kMaxWidth; 65 66 /// Verify the construction of an integer type. 67 LogicalResult IntegerType::verify(function_ref<InFlightDiagnostic()> emitError, 68 unsigned width, 69 SignednessSemantics signedness) { 70 if (width > IntegerType::kMaxWidth) { 71 return emitError() << "integer bitwidth is limited to " 72 << IntegerType::kMaxWidth << " bits"; 73 } 74 return success(); 75 } 76 77 unsigned IntegerType::getWidth() const { return getImpl()->width; } 78 79 IntegerType::SignednessSemantics IntegerType::getSignedness() const { 80 return getImpl()->signedness; 81 } 82 83 IntegerType IntegerType::scaleElementBitwidth(unsigned scale) { 84 if (!scale) 85 return IntegerType(); 86 return IntegerType::get(getContext(), scale * getWidth(), getSignedness()); 87 } 88 89 //===----------------------------------------------------------------------===// 90 // Float Type 91 //===----------------------------------------------------------------------===// 92 93 unsigned FloatType::getWidth() { 94 if (isa<Float16Type, BFloat16Type>()) 95 return 16; 96 if (isa<Float32Type>()) 97 return 32; 98 if (isa<Float64Type>()) 99 return 64; 100 if (isa<Float80Type>()) 101 return 80; 102 if (isa<Float128Type>()) 103 return 128; 104 llvm_unreachable("unexpected float type"); 105 } 106 107 /// Returns the floating semantics for the given type. 
108 const llvm::fltSemantics &FloatType::getFloatSemantics() { 109 if (isa<BFloat16Type>()) 110 return APFloat::BFloat(); 111 if (isa<Float16Type>()) 112 return APFloat::IEEEhalf(); 113 if (isa<Float32Type>()) 114 return APFloat::IEEEsingle(); 115 if (isa<Float64Type>()) 116 return APFloat::IEEEdouble(); 117 if (isa<Float80Type>()) 118 return APFloat::x87DoubleExtended(); 119 if (isa<Float128Type>()) 120 return APFloat::IEEEquad(); 121 llvm_unreachable("non-floating point type used"); 122 } 123 124 FloatType FloatType::scaleElementBitwidth(unsigned scale) { 125 if (!scale) 126 return FloatType(); 127 MLIRContext *ctx = getContext(); 128 if (isF16() || isBF16()) { 129 if (scale == 2) 130 return FloatType::getF32(ctx); 131 if (scale == 4) 132 return FloatType::getF64(ctx); 133 } 134 if (isF32()) 135 if (scale == 2) 136 return FloatType::getF64(ctx); 137 return FloatType(); 138 } 139 140 unsigned FloatType::getFPMantissaWidth() { 141 return APFloat::semanticsPrecision(getFloatSemantics()); 142 } 143 144 //===----------------------------------------------------------------------===// 145 // FunctionType 146 //===----------------------------------------------------------------------===// 147 148 unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; } 149 150 ArrayRef<Type> FunctionType::getInputs() const { 151 return getImpl()->getInputs(); 152 } 153 154 unsigned FunctionType::getNumResults() const { return getImpl()->numResults; } 155 156 ArrayRef<Type> FunctionType::getResults() const { 157 return getImpl()->getResults(); 158 } 159 160 FunctionType FunctionType::clone(TypeRange inputs, TypeRange results) const { 161 return get(getContext(), inputs, results); 162 } 163 164 /// Returns a new function type with the specified arguments and results 165 /// inserted. 
166 FunctionType FunctionType::getWithArgsAndResults( 167 ArrayRef<unsigned> argIndices, TypeRange argTypes, 168 ArrayRef<unsigned> resultIndices, TypeRange resultTypes) { 169 SmallVector<Type> argStorage, resultStorage; 170 TypeRange newArgTypes = function_interface_impl::insertTypesInto( 171 getInputs(), argIndices, argTypes, argStorage); 172 TypeRange newResultTypes = function_interface_impl::insertTypesInto( 173 getResults(), resultIndices, resultTypes, resultStorage); 174 return clone(newArgTypes, newResultTypes); 175 } 176 177 /// Returns a new function type without the specified arguments and results. 178 FunctionType 179 FunctionType::getWithoutArgsAndResults(const BitVector &argIndices, 180 const BitVector &resultIndices) { 181 SmallVector<Type> argStorage, resultStorage; 182 TypeRange newArgTypes = function_interface_impl::filterTypesOut( 183 getInputs(), argIndices, argStorage); 184 TypeRange newResultTypes = function_interface_impl::filterTypesOut( 185 getResults(), resultIndices, resultStorage); 186 return clone(newArgTypes, newResultTypes); 187 } 188 189 void FunctionType::walkImmediateSubElements( 190 function_ref<void(Attribute)> walkAttrsFn, 191 function_ref<void(Type)> walkTypesFn) const { 192 for (Type type : llvm::concat<const Type>(getInputs(), getResults())) 193 walkTypesFn(type); 194 } 195 196 //===----------------------------------------------------------------------===// 197 // OpaqueType 198 //===----------------------------------------------------------------------===// 199 200 /// Verify the construction of an opaque type. 201 LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError, 202 StringAttr dialect, StringRef typeData) { 203 if (!Dialect::isValidNamespace(dialect.strref())) 204 return emitError() << "invalid dialect namespace '" << dialect << "'"; 205 206 // Check that the dialect is actually registered. 
207 MLIRContext *context = dialect.getContext(); 208 if (!context->allowsUnregisteredDialects() && 209 !context->getLoadedDialect(dialect.strref())) { 210 return emitError() 211 << "`!" << dialect << "<\"" << typeData << "\">" 212 << "` type created with unregistered dialect. If this is " 213 "intended, please call allowUnregisteredDialects() on the " 214 "MLIRContext, or use -allow-unregistered-dialect with " 215 "the MLIR opt tool used"; 216 } 217 218 return success(); 219 } 220 221 //===----------------------------------------------------------------------===// 222 // VectorType 223 //===----------------------------------------------------------------------===// 224 225 LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError, 226 ArrayRef<int64_t> shape, Type elementType, 227 unsigned numScalableDims) { 228 if (!isValidElementType(elementType)) 229 return emitError() 230 << "vector elements must be int/index/float type but got " 231 << elementType; 232 233 if (any_of(shape, [](int64_t i) { return i <= 0; })) 234 return emitError() 235 << "vector types must have positive constant sizes but got " 236 << shape; 237 238 return success(); 239 } 240 241 VectorType VectorType::scaleElementBitwidth(unsigned scale) { 242 if (!scale) 243 return VectorType(); 244 if (auto et = getElementType().dyn_cast<IntegerType>()) 245 if (auto scaledEt = et.scaleElementBitwidth(scale)) 246 return VectorType::get(getShape(), scaledEt, getNumScalableDims()); 247 if (auto et = getElementType().dyn_cast<FloatType>()) 248 if (auto scaledEt = et.scaleElementBitwidth(scale)) 249 return VectorType::get(getShape(), scaledEt, getNumScalableDims()); 250 return VectorType(); 251 } 252 253 void VectorType::walkImmediateSubElements( 254 function_ref<void(Attribute)> walkAttrsFn, 255 function_ref<void(Type)> walkTypesFn) const { 256 walkTypesFn(getElementType()); 257 } 258 259 VectorType VectorType::cloneWith(Optional<ArrayRef<int64_t>> shape, 260 Type elementType) const { 261 
return VectorType::get(shape.value_or(getShape()), elementType, 262 getNumScalableDims()); 263 } 264 265 //===----------------------------------------------------------------------===// 266 // TensorType 267 //===----------------------------------------------------------------------===// 268 269 Type TensorType::getElementType() const { 270 return llvm::TypeSwitch<TensorType, Type>(*this) 271 .Case<RankedTensorType, UnrankedTensorType>( 272 [](auto type) { return type.getElementType(); }); 273 } 274 275 bool TensorType::hasRank() const { return !isa<UnrankedTensorType>(); } 276 277 ArrayRef<int64_t> TensorType::getShape() const { 278 return cast<RankedTensorType>().getShape(); 279 } 280 281 TensorType TensorType::cloneWith(Optional<ArrayRef<int64_t>> shape, 282 Type elementType) const { 283 if (auto unrankedTy = dyn_cast<UnrankedTensorType>()) { 284 if (shape) 285 return RankedTensorType::get(*shape, elementType); 286 return UnrankedTensorType::get(elementType); 287 } 288 289 auto rankedTy = cast<RankedTensorType>(); 290 if (!shape) 291 return RankedTensorType::get(rankedTy.getShape(), elementType, 292 rankedTy.getEncoding()); 293 return RankedTensorType::get(shape.value_or(rankedTy.getShape()), elementType, 294 rankedTy.getEncoding()); 295 } 296 297 // Check if "elementType" can be an element type of a tensor. 298 static LogicalResult 299 checkTensorElementType(function_ref<InFlightDiagnostic()> emitError, 300 Type elementType) { 301 if (!TensorType::isValidElementType(elementType)) 302 return emitError() << "invalid tensor element type: " << elementType; 303 return success(); 304 } 305 306 /// Return true if the specified element type is ok in a tensor. 307 bool TensorType::isValidElementType(Type type) { 308 // Note: Non standard/builtin types are allowed to exist within tensor 309 // types. Dialects are expected to verify that tensor types have a valid 310 // element type within that dialect. 
311 return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType, 312 IndexType>() || 313 !llvm::isa<BuiltinDialect>(type.getDialect()); 314 } 315 316 //===----------------------------------------------------------------------===// 317 // RankedTensorType 318 //===----------------------------------------------------------------------===// 319 320 LogicalResult 321 RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError, 322 ArrayRef<int64_t> shape, Type elementType, 323 Attribute encoding) { 324 for (int64_t s : shape) 325 if (s < -1) 326 return emitError() << "invalid tensor dimension size"; 327 if (auto v = encoding.dyn_cast_or_null<VerifiableTensorEncoding>()) 328 if (failed(v.verifyEncoding(shape, elementType, emitError))) 329 return failure(); 330 return checkTensorElementType(emitError, elementType); 331 } 332 333 void RankedTensorType::walkImmediateSubElements( 334 function_ref<void(Attribute)> walkAttrsFn, 335 function_ref<void(Type)> walkTypesFn) const { 336 walkTypesFn(getElementType()); 337 if (Attribute encoding = getEncoding()) 338 walkAttrsFn(encoding); 339 } 340 341 //===----------------------------------------------------------------------===// 342 // UnrankedTensorType 343 //===----------------------------------------------------------------------===// 344 345 LogicalResult 346 UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError, 347 Type elementType) { 348 return checkTensorElementType(emitError, elementType); 349 } 350 351 void UnrankedTensorType::walkImmediateSubElements( 352 function_ref<void(Attribute)> walkAttrsFn, 353 function_ref<void(Type)> walkTypesFn) const { 354 walkTypesFn(getElementType()); 355 } 356 357 //===----------------------------------------------------------------------===// 358 // BaseMemRefType 359 //===----------------------------------------------------------------------===// 360 361 Type BaseMemRefType::getElementType() const { 362 return llvm::TypeSwitch<BaseMemRefType, 
Type>(*this) 363 .Case<MemRefType, UnrankedMemRefType>( 364 [](auto type) { return type.getElementType(); }); 365 } 366 367 bool BaseMemRefType::hasRank() const { return !isa<UnrankedMemRefType>(); } 368 369 ArrayRef<int64_t> BaseMemRefType::getShape() const { 370 return cast<MemRefType>().getShape(); 371 } 372 373 BaseMemRefType BaseMemRefType::cloneWith(Optional<ArrayRef<int64_t>> shape, 374 Type elementType) const { 375 if (auto unrankedTy = dyn_cast<UnrankedMemRefType>()) { 376 if (!shape) 377 return UnrankedMemRefType::get(elementType, getMemorySpace()); 378 MemRefType::Builder builder(*shape, elementType); 379 builder.setMemorySpace(getMemorySpace()); 380 return builder; 381 } 382 383 MemRefType::Builder builder(cast<MemRefType>()); 384 if (shape) 385 builder.setShape(*shape); 386 builder.setElementType(elementType); 387 return builder; 388 } 389 390 Attribute BaseMemRefType::getMemorySpace() const { 391 if (auto rankedMemRefTy = dyn_cast<MemRefType>()) 392 return rankedMemRefTy.getMemorySpace(); 393 return cast<UnrankedMemRefType>().getMemorySpace(); 394 } 395 396 unsigned BaseMemRefType::getMemorySpaceAsInt() const { 397 if (auto rankedMemRefTy = dyn_cast<MemRefType>()) 398 return rankedMemRefTy.getMemorySpaceAsInt(); 399 return cast<UnrankedMemRefType>().getMemorySpaceAsInt(); 400 } 401 402 //===----------------------------------------------------------------------===// 403 // MemRefType 404 //===----------------------------------------------------------------------===// 405 406 /// Given an `originalShape` and a `reducedShape` assumed to be a subset of 407 /// `originalShape` with some `1` entries erased, return the set of indices 408 /// that specifies which of the entries of `originalShape` are dropped to obtain 409 /// `reducedShape`. The returned mask can be applied as a projection to 410 /// `originalShape` to obtain the `reducedShape`. This mask is useful to track 411 /// which dimensions must be kept when e.g. 
compute MemRef strides under 412 /// rank-reducing operations. Return None if reducedShape cannot be obtained 413 /// by dropping only `1` entries in `originalShape`. 414 llvm::Optional<llvm::SmallDenseSet<unsigned>> 415 mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape, 416 ArrayRef<int64_t> reducedShape) { 417 size_t originalRank = originalShape.size(), reducedRank = reducedShape.size(); 418 llvm::SmallDenseSet<unsigned> unusedDims; 419 unsigned reducedIdx = 0; 420 for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) { 421 // Greedily insert `originalIdx` if match. 422 if (reducedIdx < reducedRank && 423 originalShape[originalIdx] == reducedShape[reducedIdx]) { 424 reducedIdx++; 425 continue; 426 } 427 428 unusedDims.insert(originalIdx); 429 // If no match on `originalIdx`, the `originalShape` at this dimension 430 // must be 1, otherwise we bail. 431 if (originalShape[originalIdx] != 1) 432 return llvm::None; 433 } 434 // The whole reducedShape must be scanned, otherwise we bail. 435 if (reducedIdx != reducedRank) 436 return llvm::None; 437 return unusedDims; 438 } 439 440 SliceVerificationResult 441 mlir::isRankReducedType(ShapedType originalType, 442 ShapedType candidateReducedType) { 443 if (originalType == candidateReducedType) 444 return SliceVerificationResult::Success; 445 446 ShapedType originalShapedType = originalType.cast<ShapedType>(); 447 ShapedType candidateReducedShapedType = 448 candidateReducedType.cast<ShapedType>(); 449 450 // Rank and size logic is valid for all ShapedTypes. 
451 ArrayRef<int64_t> originalShape = originalShapedType.getShape(); 452 ArrayRef<int64_t> candidateReducedShape = 453 candidateReducedShapedType.getShape(); 454 unsigned originalRank = originalShape.size(), 455 candidateReducedRank = candidateReducedShape.size(); 456 if (candidateReducedRank > originalRank) 457 return SliceVerificationResult::RankTooLarge; 458 459 auto optionalUnusedDimsMask = 460 computeRankReductionMask(originalShape, candidateReducedShape); 461 462 // Sizes cannot be matched in case empty vector is returned. 463 if (!optionalUnusedDimsMask) 464 return SliceVerificationResult::SizeMismatch; 465 466 if (originalShapedType.getElementType() != 467 candidateReducedShapedType.getElementType()) 468 return SliceVerificationResult::ElemTypeMismatch; 469 470 return SliceVerificationResult::Success; 471 } 472 473 bool mlir::detail::isSupportedMemorySpace(Attribute memorySpace) { 474 // Empty attribute is allowed as default memory space. 475 if (!memorySpace) 476 return true; 477 478 // Supported built-in attributes. 479 if (memorySpace.isa<IntegerAttr, StringAttr, DictionaryAttr>()) 480 return true; 481 482 // Allow custom dialect attributes. 
483 if (!isa<BuiltinDialect>(memorySpace.getDialect())) 484 return true; 485 486 return false; 487 } 488 489 Attribute mlir::detail::wrapIntegerMemorySpace(unsigned memorySpace, 490 MLIRContext *ctx) { 491 if (memorySpace == 0) 492 return nullptr; 493 494 return IntegerAttr::get(IntegerType::get(ctx, 64), memorySpace); 495 } 496 497 Attribute mlir::detail::skipDefaultMemorySpace(Attribute memorySpace) { 498 IntegerAttr intMemorySpace = memorySpace.dyn_cast_or_null<IntegerAttr>(); 499 if (intMemorySpace && intMemorySpace.getValue() == 0) 500 return nullptr; 501 502 return memorySpace; 503 } 504 505 unsigned mlir::detail::getMemorySpaceAsInt(Attribute memorySpace) { 506 if (!memorySpace) 507 return 0; 508 509 assert(memorySpace.isa<IntegerAttr>() && 510 "Using `getMemorySpaceInteger` with non-Integer attribute"); 511 512 return static_cast<unsigned>(memorySpace.cast<IntegerAttr>().getInt()); 513 } 514 515 MemRefType::Builder & 516 MemRefType::Builder::setMemorySpace(unsigned newMemorySpace) { 517 memorySpace = 518 wrapIntegerMemorySpace(newMemorySpace, elementType.getContext()); 519 return *this; 520 } 521 522 unsigned MemRefType::getMemorySpaceAsInt() const { 523 return detail::getMemorySpaceAsInt(getMemorySpace()); 524 } 525 526 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType, 527 MemRefLayoutAttrInterface layout, 528 Attribute memorySpace) { 529 // Use default layout for empty attribute. 530 if (!layout) 531 layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap( 532 shape.size(), elementType.getContext())); 533 534 // Drop default memory space value and replace it with empty attribute. 
535 memorySpace = skipDefaultMemorySpace(memorySpace); 536 537 return Base::get(elementType.getContext(), shape, elementType, layout, 538 memorySpace); 539 } 540 541 MemRefType MemRefType::getChecked( 542 function_ref<InFlightDiagnostic()> emitErrorFn, ArrayRef<int64_t> shape, 543 Type elementType, MemRefLayoutAttrInterface layout, Attribute memorySpace) { 544 545 // Use default layout for empty attribute. 546 if (!layout) 547 layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap( 548 shape.size(), elementType.getContext())); 549 550 // Drop default memory space value and replace it with empty attribute. 551 memorySpace = skipDefaultMemorySpace(memorySpace); 552 553 return Base::getChecked(emitErrorFn, elementType.getContext(), shape, 554 elementType, layout, memorySpace); 555 } 556 557 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType, 558 AffineMap map, Attribute memorySpace) { 559 560 // Use default layout for empty map. 561 if (!map) 562 map = AffineMap::getMultiDimIdentityMap(shape.size(), 563 elementType.getContext()); 564 565 // Wrap AffineMap into Attribute. 566 Attribute layout = AffineMapAttr::get(map); 567 568 // Drop default memory space value and replace it with empty attribute. 569 memorySpace = skipDefaultMemorySpace(memorySpace); 570 571 return Base::get(elementType.getContext(), shape, elementType, layout, 572 memorySpace); 573 } 574 575 MemRefType 576 MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn, 577 ArrayRef<int64_t> shape, Type elementType, AffineMap map, 578 Attribute memorySpace) { 579 580 // Use default layout for empty map. 581 if (!map) 582 map = AffineMap::getMultiDimIdentityMap(shape.size(), 583 elementType.getContext()); 584 585 // Wrap AffineMap into Attribute. 586 Attribute layout = AffineMapAttr::get(map); 587 588 // Drop default memory space value and replace it with empty attribute. 
589 memorySpace = skipDefaultMemorySpace(memorySpace); 590 591 return Base::getChecked(emitErrorFn, elementType.getContext(), shape, 592 elementType, layout, memorySpace); 593 } 594 595 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType, 596 AffineMap map, unsigned memorySpaceInd) { 597 598 // Use default layout for empty map. 599 if (!map) 600 map = AffineMap::getMultiDimIdentityMap(shape.size(), 601 elementType.getContext()); 602 603 // Wrap AffineMap into Attribute. 604 Attribute layout = AffineMapAttr::get(map); 605 606 // Convert deprecated integer-like memory space to Attribute. 607 Attribute memorySpace = 608 wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext()); 609 610 return Base::get(elementType.getContext(), shape, elementType, layout, 611 memorySpace); 612 } 613 614 MemRefType 615 MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn, 616 ArrayRef<int64_t> shape, Type elementType, AffineMap map, 617 unsigned memorySpaceInd) { 618 619 // Use default layout for empty map. 620 if (!map) 621 map = AffineMap::getMultiDimIdentityMap(shape.size(), 622 elementType.getContext()); 623 624 // Wrap AffineMap into Attribute. 625 Attribute layout = AffineMapAttr::get(map); 626 627 // Convert deprecated integer-like memory space to Attribute. 628 Attribute memorySpace = 629 wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext()); 630 631 return Base::getChecked(emitErrorFn, elementType.getContext(), shape, 632 elementType, layout, memorySpace); 633 } 634 635 LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError, 636 ArrayRef<int64_t> shape, Type elementType, 637 MemRefLayoutAttrInterface layout, 638 Attribute memorySpace) { 639 if (!BaseMemRefType::isValidElementType(elementType)) 640 return emitError() << "invalid memref element type"; 641 642 // Negative sizes are not allowed except for `-1` that means dynamic size. 
643 for (int64_t s : shape) 644 if (s < -1) 645 return emitError() << "invalid memref size"; 646 647 assert(layout && "missing layout specification"); 648 if (failed(layout.verifyLayout(shape, emitError))) 649 return failure(); 650 651 if (!isSupportedMemorySpace(memorySpace)) 652 return emitError() << "unsupported memory space Attribute"; 653 654 return success(); 655 } 656 657 void MemRefType::walkImmediateSubElements( 658 function_ref<void(Attribute)> walkAttrsFn, 659 function_ref<void(Type)> walkTypesFn) const { 660 walkTypesFn(getElementType()); 661 if (!getLayout().isIdentity()) 662 walkAttrsFn(getLayout()); 663 walkAttrsFn(getMemorySpace()); 664 } 665 666 //===----------------------------------------------------------------------===// 667 // UnrankedMemRefType 668 //===----------------------------------------------------------------------===// 669 670 unsigned UnrankedMemRefType::getMemorySpaceAsInt() const { 671 return detail::getMemorySpaceAsInt(getMemorySpace()); 672 } 673 674 LogicalResult 675 UnrankedMemRefType::verify(function_ref<InFlightDiagnostic()> emitError, 676 Type elementType, Attribute memorySpace) { 677 if (!BaseMemRefType::isValidElementType(elementType)) 678 return emitError() << "invalid memref element type"; 679 680 if (!isSupportedMemorySpace(memorySpace)) 681 return emitError() << "unsupported memory space Attribute"; 682 683 return success(); 684 } 685 686 // Fallback cases for terminal dim/sym/cst that are not part of a binary op ( 687 // i.e. single term). Accumulate the AffineExpr into the existing one. 
688 static void extractStridesFromTerm(AffineExpr e, 689 AffineExpr multiplicativeFactor, 690 MutableArrayRef<AffineExpr> strides, 691 AffineExpr &offset) { 692 if (auto dim = e.dyn_cast<AffineDimExpr>()) 693 strides[dim.getPosition()] = 694 strides[dim.getPosition()] + multiplicativeFactor; 695 else 696 offset = offset + e * multiplicativeFactor; 697 } 698 699 /// Takes a single AffineExpr `e` and populates the `strides` array with the 700 /// strides expressions for each dim position. 701 /// The convention is that the strides for dimensions d0, .. dn appear in 702 /// order to make indexing intuitive into the result. 703 static LogicalResult extractStrides(AffineExpr e, 704 AffineExpr multiplicativeFactor, 705 MutableArrayRef<AffineExpr> strides, 706 AffineExpr &offset) { 707 auto bin = e.dyn_cast<AffineBinaryOpExpr>(); 708 if (!bin) { 709 extractStridesFromTerm(e, multiplicativeFactor, strides, offset); 710 return success(); 711 } 712 713 if (bin.getKind() == AffineExprKind::CeilDiv || 714 bin.getKind() == AffineExprKind::FloorDiv || 715 bin.getKind() == AffineExprKind::Mod) 716 return failure(); 717 718 if (bin.getKind() == AffineExprKind::Mul) { 719 auto dim = bin.getLHS().dyn_cast<AffineDimExpr>(); 720 if (dim) { 721 strides[dim.getPosition()] = 722 strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor; 723 return success(); 724 } 725 // LHS and RHS may both contain complex expressions of dims. Try one path 726 // and if it fails try the other. This is guaranteed to succeed because 727 // only one path may have a `dim`, otherwise this is not an AffineExpr in 728 // the first place. 
729 if (bin.getLHS().isSymbolicOrConstant()) 730 return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(), 731 strides, offset); 732 return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(), 733 strides, offset); 734 } 735 736 if (bin.getKind() == AffineExprKind::Add) { 737 auto res1 = 738 extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset); 739 auto res2 = 740 extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset); 741 return success(succeeded(res1) && succeeded(res2)); 742 } 743 744 llvm_unreachable("unexpected binary operation"); 745 } 746 747 LogicalResult mlir::getStridesAndOffset(MemRefType t, 748 SmallVectorImpl<AffineExpr> &strides, 749 AffineExpr &offset) { 750 AffineMap m = t.getLayout().getAffineMap(); 751 752 if (m.getNumResults() != 1 && !m.isIdentity()) 753 return failure(); 754 755 auto zero = getAffineConstantExpr(0, t.getContext()); 756 auto one = getAffineConstantExpr(1, t.getContext()); 757 offset = zero; 758 strides.assign(t.getRank(), zero); 759 760 // Canonical case for empty map. 761 if (m.isIdentity()) { 762 // 0-D corner case, offset is already 0. 763 if (t.getRank() == 0) 764 return success(); 765 auto stridedExpr = 766 makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext()); 767 if (succeeded(extractStrides(stridedExpr, one, strides, offset))) 768 return success(); 769 assert(false && "unexpected failure: extract strides in canonical layout"); 770 } 771 772 // Non-canonical case requires more work. 773 auto stridedExpr = 774 simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols()); 775 if (failed(extractStrides(stridedExpr, one, strides, offset))) { 776 offset = AffineExpr(); 777 strides.clear(); 778 return failure(); 779 } 780 781 // Simplify results to allow folding to constants and simple checks. 
782 unsigned numDims = m.getNumDims(); 783 unsigned numSymbols = m.getNumSymbols(); 784 offset = simplifyAffineExpr(offset, numDims, numSymbols); 785 for (auto &stride : strides) 786 stride = simplifyAffineExpr(stride, numDims, numSymbols); 787 788 /// In practice, a strided memref must be internally non-aliasing. Test 789 /// against 0 as a proxy. 790 /// TODO: static cases can have more advanced checks. 791 /// TODO: dynamic cases would require a way to compare symbolic 792 /// expressions and would probably need an affine set context propagated 793 /// everywhere. 794 if (llvm::any_of(strides, [](AffineExpr e) { 795 return e == getAffineConstantExpr(0, e.getContext()); 796 })) { 797 offset = AffineExpr(); 798 strides.clear(); 799 return failure(); 800 } 801 802 return success(); 803 } 804 805 LogicalResult mlir::getStridesAndOffset(MemRefType t, 806 SmallVectorImpl<int64_t> &strides, 807 int64_t &offset) { 808 AffineExpr offsetExpr; 809 SmallVector<AffineExpr, 4> strideExprs; 810 if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr))) 811 return failure(); 812 if (auto cst = offsetExpr.dyn_cast<AffineConstantExpr>()) 813 offset = cst.getValue(); 814 else 815 offset = ShapedType::kDynamicStrideOrOffset; 816 for (auto e : strideExprs) { 817 if (auto c = e.dyn_cast<AffineConstantExpr>()) 818 strides.push_back(c.getValue()); 819 else 820 strides.push_back(ShapedType::kDynamicStrideOrOffset); 821 } 822 return success(); 823 } 824 825 void UnrankedMemRefType::walkImmediateSubElements( 826 function_ref<void(Attribute)> walkAttrsFn, 827 function_ref<void(Type)> walkTypesFn) const { 828 walkTypesFn(getElementType()); 829 walkAttrsFn(getMemorySpace()); 830 } 831 832 //===----------------------------------------------------------------------===// 833 /// TupleType 834 //===----------------------------------------------------------------------===// 835 836 /// Return the elements types for this tuple. 
837 ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); } 838 839 /// Accumulate the types contained in this tuple and tuples nested within it. 840 /// Note that this only flattens nested tuples, not any other container type, 841 /// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to 842 /// (i32, tensor<i32>, f32, i64) 843 void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) { 844 for (Type type : getTypes()) { 845 if (auto nestedTuple = type.dyn_cast<TupleType>()) 846 nestedTuple.getFlattenedTypes(types); 847 else 848 types.push_back(type); 849 } 850 } 851 852 /// Return the number of element types. 853 size_t TupleType::size() const { return getImpl()->size(); } 854 855 void TupleType::walkImmediateSubElements( 856 function_ref<void(Attribute)> walkAttrsFn, 857 function_ref<void(Type)> walkTypesFn) const { 858 for (Type type : getTypes()) 859 walkTypesFn(type); 860 } 861 862 //===----------------------------------------------------------------------===// 863 // Type Utilities 864 //===----------------------------------------------------------------------===// 865 866 AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides, 867 int64_t offset, 868 MLIRContext *context) { 869 AffineExpr expr; 870 unsigned nSymbols = 0; 871 872 // AffineExpr for offset. 873 // Static case. 874 if (offset != MemRefType::getDynamicStrideOrOffset()) { 875 auto cst = getAffineConstantExpr(offset, context); 876 expr = cst; 877 } else { 878 // Dynamic case, new symbol for the offset. 879 auto sym = getAffineSymbolExpr(nSymbols++, context); 880 expr = sym; 881 } 882 883 // AffineExpr for strides. 884 for (const auto &en : llvm::enumerate(strides)) { 885 auto dim = en.index(); 886 auto stride = en.value(); 887 assert(stride != 0 && "Invalid stride specification"); 888 auto d = getAffineDimExpr(dim, context); 889 AffineExpr mult; 890 // Static case. 
891 if (stride != MemRefType::getDynamicStrideOrOffset()) 892 mult = getAffineConstantExpr(stride, context); 893 else 894 // Dynamic case, new symbol for each new stride. 895 mult = getAffineSymbolExpr(nSymbols++, context); 896 expr = expr + d * mult; 897 } 898 899 return AffineMap::get(strides.size(), nSymbols, expr); 900 } 901 902 /// Return a version of `t` with identity layout if it can be determined 903 /// statically that the layout is the canonical contiguous strided layout. 904 /// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of 905 /// `t` with simplified layout. 906 /// If `t` has multiple layout maps or a multi-result layout, just return `t`. 907 MemRefType mlir::canonicalizeStridedLayout(MemRefType t) { 908 AffineMap m = t.getLayout().getAffineMap(); 909 910 // Already in canonical form. 911 if (m.isIdentity()) 912 return t; 913 914 // Can't reduce to canonical identity form, return in canonical form. 915 if (m.getNumResults() > 1) 916 return t; 917 918 // Corner-case for 0-D affine maps. 919 if (m.getNumDims() == 0 && m.getNumSymbols() == 0) { 920 if (auto cst = m.getResult(0).dyn_cast<AffineConstantExpr>()) 921 if (cst.getValue() == 0) 922 return MemRefType::Builder(t).setLayout({}); 923 return t; 924 } 925 926 // 0-D corner case for empty shape that still have an affine map. Example: 927 // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a 1 element memref whose 928 // offset needs to remain, just return t. 929 if (t.getShape().empty()) 930 return t; 931 932 // If the canonical strided layout for the sizes of `t` is equal to the 933 // simplified layout of `t` we can just return an empty layout. Otherwise, 934 // just simplify the existing layout. 
935 AffineExpr expr = 936 makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext()); 937 auto simplifiedLayoutExpr = 938 simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols()); 939 if (expr != simplifiedLayoutExpr) 940 return MemRefType::Builder(t).setLayout(AffineMapAttr::get(AffineMap::get( 941 m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr))); 942 return MemRefType::Builder(t).setLayout({}); 943 } 944 945 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes, 946 ArrayRef<AffineExpr> exprs, 947 MLIRContext *context) { 948 // Size 0 corner case is useful for canonicalizations. 949 if (sizes.empty() || llvm::is_contained(sizes, 0)) 950 return getAffineConstantExpr(0, context); 951 952 assert(!exprs.empty() && "expected exprs"); 953 auto maps = AffineMap::inferFromExprList(exprs); 954 assert(!maps.empty() && "Expected one non-empty map"); 955 unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols(); 956 957 AffineExpr expr; 958 bool dynamicPoisonBit = false; 959 int64_t runningSize = 1; 960 for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) { 961 int64_t size = std::get<1>(en); 962 // Degenerate case, no size =-> no stride 963 if (size == 0) 964 continue; 965 AffineExpr dimExpr = std::get<0>(en); 966 AffineExpr stride = dynamicPoisonBit 967 ? getAffineSymbolExpr(nSymbols++, context) 968 : getAffineConstantExpr(runningSize, context); 969 expr = expr ? expr + dimExpr * stride : dimExpr * stride; 970 if (size > 0) { 971 runningSize *= size; 972 assert(runningSize > 0 && "integer overflow in size computation"); 973 } else { 974 dynamicPoisonBit = true; 975 } 976 } 977 return simplifyAffineExpr(expr, numDims, nSymbols); 978 } 979 980 /// Return a version of `t` with a layout that has all dynamic offset and 981 /// strides. This is used to erase the static layout. 
982 MemRefType mlir::eraseStridedLayout(MemRefType t) { 983 auto val = ShapedType::kDynamicStrideOrOffset; 984 return MemRefType::Builder(t).setLayout( 985 AffineMapAttr::get(makeStridedLinearLayoutMap( 986 SmallVector<int64_t, 4>(t.getRank(), val), val, t.getContext()))); 987 } 988 989 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes, 990 MLIRContext *context) { 991 SmallVector<AffineExpr, 4> exprs; 992 exprs.reserve(sizes.size()); 993 for (auto dim : llvm::seq<unsigned>(0, sizes.size())) 994 exprs.push_back(getAffineDimExpr(dim, context)); 995 return makeCanonicalStridedLayoutExpr(sizes, exprs, context); 996 } 997 998 /// Return true if the layout for `t` is compatible with strided semantics. 999 bool mlir::isStrided(MemRefType t) { 1000 int64_t offset; 1001 SmallVector<int64_t, 4> strides; 1002 auto res = getStridesAndOffset(t, strides, offset); 1003 return succeeded(res); 1004 } 1005 1006 /// Return the layout map in strided linear layout AffineMap form. 1007 /// Return null if the layout is not compatible with a strided layout. 1008 AffineMap mlir::getStridedLinearLayoutMap(MemRefType t) { 1009 int64_t offset; 1010 SmallVector<int64_t, 4> strides; 1011 if (failed(getStridesAndOffset(t, strides, offset))) 1012 return AffineMap(); 1013 return makeStridedLinearLayoutMap(strides, offset, t.getContext()); 1014 } 1015 1016 /// Return the AffineExpr representation of the offset, assuming `memRefType` 1017 /// is a strided memref. 1018 static AffineExpr getOffsetExpr(MemRefType memrefType) { 1019 SmallVector<AffineExpr> strides; 1020 AffineExpr offset; 1021 if (failed(getStridesAndOffset(memrefType, strides, offset))) 1022 assert(false && "expected strided memref"); 1023 return offset; 1024 } 1025 1026 /// Helper to construct a contiguous MemRefType of `shape`, `elementType` and 1027 /// `offset` AffineExpr. 
1028 static MemRefType makeContiguousRowMajorMemRefType(MLIRContext *context, 1029 ArrayRef<int64_t> shape, 1030 Type elementType, 1031 AffineExpr offset) { 1032 AffineExpr canonical = makeCanonicalStridedLayoutExpr(shape, context); 1033 AffineExpr contiguousRowMajor = canonical + offset; 1034 AffineMap contiguousRowMajorMap = 1035 AffineMap::inferFromExprList({contiguousRowMajor})[0]; 1036 return MemRefType::get(shape, elementType, contiguousRowMajorMap); 1037 } 1038 1039 /// Helper determining if a memref is static-shape and contiguous-row-major 1040 /// layout, while still allowing for an arbitrary offset (any static or 1041 /// dynamic value). 1042 bool mlir::isStaticShapeAndContiguousRowMajor(MemRefType memrefType) { 1043 if (!memrefType.hasStaticShape()) 1044 return false; 1045 AffineExpr offset = getOffsetExpr(memrefType); 1046 MemRefType contiguousRowMajorMemRefType = makeContiguousRowMajorMemRefType( 1047 memrefType.getContext(), memrefType.getShape(), 1048 memrefType.getElementType(), offset); 1049 return canonicalizeStridedLayout(memrefType) == 1050 canonicalizeStridedLayout(contiguousRowMajorMemRefType); 1051 } 1052