//===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/BuiltinTypes.h"
#include "TypeDetail.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TensorEncoding.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/Twine.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::detail;

//===----------------------------------------------------------------------===//
/// Tablegen Type Definitions
//===----------------------------------------------------------------------===//

// Pull in the TableGen-generated class definitions for the builtin types.
#define GET_TYPEDEF_CLASSES
#include "mlir/IR/BuiltinTypes.cpp.inc"

//===----------------------------------------------------------------------===//
// BuiltinDialect
//===----------------------------------------------------------------------===//

/// Registers every builtin type with the builtin dialect. The template
/// argument list of `addTypes` is expanded from the TableGen-generated
/// `GET_TYPEDEF_LIST` section of BuiltinTypes.cpp.inc.
void BuiltinDialect::registerTypes() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/IR/BuiltinTypes.cpp.inc"
      >();
}

//===----------------------------------------------------------------------===//
/// ComplexType
//===----------------------------------------------------------------------===//

/// Verify the construction of a complex type: the element type must be an
/// integer or floating-point type.
52 LogicalResult ComplexType::verify(function_ref<InFlightDiagnostic()> emitError, 53 Type elementType) { 54 if (!elementType.isIntOrFloat()) 55 return emitError() << "invalid element type for complex"; 56 return success(); 57 } 58 59 //===----------------------------------------------------------------------===// 60 // Integer Type 61 //===----------------------------------------------------------------------===// 62 63 // static constexpr must have a definition (until in C++17 and inline variable). 64 constexpr unsigned IntegerType::kMaxWidth; 65 66 /// Verify the construction of an integer type. 67 LogicalResult IntegerType::verify(function_ref<InFlightDiagnostic()> emitError, 68 unsigned width, 69 SignednessSemantics signedness) { 70 if (width > IntegerType::kMaxWidth) { 71 return emitError() << "integer bitwidth is limited to " 72 << IntegerType::kMaxWidth << " bits"; 73 } 74 return success(); 75 } 76 77 unsigned IntegerType::getWidth() const { return getImpl()->width; } 78 79 IntegerType::SignednessSemantics IntegerType::getSignedness() const { 80 return getImpl()->signedness; 81 } 82 83 IntegerType IntegerType::scaleElementBitwidth(unsigned scale) { 84 if (!scale) 85 return IntegerType(); 86 return IntegerType::get(getContext(), scale * getWidth(), getSignedness()); 87 } 88 89 //===----------------------------------------------------------------------===// 90 // Float Type 91 //===----------------------------------------------------------------------===// 92 93 unsigned FloatType::getWidth() { 94 if (isa<Float16Type, BFloat16Type>()) 95 return 16; 96 if (isa<Float32Type>()) 97 return 32; 98 if (isa<Float64Type>()) 99 return 64; 100 if (isa<Float80Type>()) 101 return 80; 102 if (isa<Float128Type>()) 103 return 128; 104 llvm_unreachable("unexpected float type"); 105 } 106 107 /// Returns the floating semantics for the given type. 
108 const llvm::fltSemantics &FloatType::getFloatSemantics() { 109 if (isa<BFloat16Type>()) 110 return APFloat::BFloat(); 111 if (isa<Float16Type>()) 112 return APFloat::IEEEhalf(); 113 if (isa<Float32Type>()) 114 return APFloat::IEEEsingle(); 115 if (isa<Float64Type>()) 116 return APFloat::IEEEdouble(); 117 if (isa<Float80Type>()) 118 return APFloat::x87DoubleExtended(); 119 if (isa<Float128Type>()) 120 return APFloat::IEEEquad(); 121 llvm_unreachable("non-floating point type used"); 122 } 123 124 FloatType FloatType::scaleElementBitwidth(unsigned scale) { 125 if (!scale) 126 return FloatType(); 127 MLIRContext *ctx = getContext(); 128 if (isF16() || isBF16()) { 129 if (scale == 2) 130 return FloatType::getF32(ctx); 131 if (scale == 4) 132 return FloatType::getF64(ctx); 133 } 134 if (isF32()) 135 if (scale == 2) 136 return FloatType::getF64(ctx); 137 return FloatType(); 138 } 139 140 unsigned FloatType::getFPMantissaWidth() { 141 return APFloat::semanticsPrecision(getFloatSemantics()); 142 } 143 144 //===----------------------------------------------------------------------===// 145 // FunctionType 146 //===----------------------------------------------------------------------===// 147 148 unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; } 149 150 ArrayRef<Type> FunctionType::getInputs() const { 151 return getImpl()->getInputs(); 152 } 153 154 unsigned FunctionType::getNumResults() const { return getImpl()->numResults; } 155 156 ArrayRef<Type> FunctionType::getResults() const { 157 return getImpl()->getResults(); 158 } 159 160 FunctionType FunctionType::clone(TypeRange inputs, TypeRange results) const { 161 return get(getContext(), inputs, results); 162 } 163 164 /// Returns a new function type with the specified arguments and results 165 /// inserted. 
166 FunctionType FunctionType::getWithArgsAndResults( 167 ArrayRef<unsigned> argIndices, TypeRange argTypes, 168 ArrayRef<unsigned> resultIndices, TypeRange resultTypes) { 169 SmallVector<Type> argStorage, resultStorage; 170 TypeRange newArgTypes = function_interface_impl::insertTypesInto( 171 getInputs(), argIndices, argTypes, argStorage); 172 TypeRange newResultTypes = function_interface_impl::insertTypesInto( 173 getResults(), resultIndices, resultTypes, resultStorage); 174 return clone(newArgTypes, newResultTypes); 175 } 176 177 /// Returns a new function type without the specified arguments and results. 178 FunctionType 179 FunctionType::getWithoutArgsAndResults(const BitVector &argIndices, 180 const BitVector &resultIndices) { 181 SmallVector<Type> argStorage, resultStorage; 182 TypeRange newArgTypes = function_interface_impl::filterTypesOut( 183 getInputs(), argIndices, argStorage); 184 TypeRange newResultTypes = function_interface_impl::filterTypesOut( 185 getResults(), resultIndices, resultStorage); 186 return clone(newArgTypes, newResultTypes); 187 } 188 189 void FunctionType::walkImmediateSubElements( 190 function_ref<void(Attribute)> walkAttrsFn, 191 function_ref<void(Type)> walkTypesFn) const { 192 for (Type type : llvm::concat<const Type>(getInputs(), getResults())) 193 walkTypesFn(type); 194 } 195 196 Type FunctionType::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs, 197 ArrayRef<Type> replTypes) const { 198 unsigned numInputs = getNumInputs(); 199 return get(getContext(), replTypes.take_front(numInputs), 200 replTypes.drop_front(numInputs)); 201 } 202 203 //===----------------------------------------------------------------------===// 204 // OpaqueType 205 //===----------------------------------------------------------------------===// 206 207 /// Verify the construction of an opaque type. 
208 LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError, 209 StringAttr dialect, StringRef typeData) { 210 if (!Dialect::isValidNamespace(dialect.strref())) 211 return emitError() << "invalid dialect namespace '" << dialect << "'"; 212 213 // Check that the dialect is actually registered. 214 MLIRContext *context = dialect.getContext(); 215 if (!context->allowsUnregisteredDialects() && 216 !context->getLoadedDialect(dialect.strref())) { 217 return emitError() 218 << "`!" << dialect << "<\"" << typeData << "\">" 219 << "` type created with unregistered dialect. If this is " 220 "intended, please call allowUnregisteredDialects() on the " 221 "MLIRContext, or use -allow-unregistered-dialect with " 222 "the MLIR opt tool used"; 223 } 224 225 return success(); 226 } 227 228 //===----------------------------------------------------------------------===// 229 // VectorType 230 //===----------------------------------------------------------------------===// 231 232 LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError, 233 ArrayRef<int64_t> shape, Type elementType, 234 unsigned numScalableDims) { 235 if (!isValidElementType(elementType)) 236 return emitError() 237 << "vector elements must be int/index/float type but got " 238 << elementType; 239 240 if (any_of(shape, [](int64_t i) { return i <= 0; })) 241 return emitError() 242 << "vector types must have positive constant sizes but got " 243 << shape; 244 245 return success(); 246 } 247 248 VectorType VectorType::scaleElementBitwidth(unsigned scale) { 249 if (!scale) 250 return VectorType(); 251 if (auto et = getElementType().dyn_cast<IntegerType>()) 252 if (auto scaledEt = et.scaleElementBitwidth(scale)) 253 return VectorType::get(getShape(), scaledEt, getNumScalableDims()); 254 if (auto et = getElementType().dyn_cast<FloatType>()) 255 if (auto scaledEt = et.scaleElementBitwidth(scale)) 256 return VectorType::get(getShape(), scaledEt, getNumScalableDims()); 257 return 
VectorType(); 258 } 259 260 void VectorType::walkImmediateSubElements( 261 function_ref<void(Attribute)> walkAttrsFn, 262 function_ref<void(Type)> walkTypesFn) const { 263 walkTypesFn(getElementType()); 264 } 265 266 Type VectorType::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs, 267 ArrayRef<Type> replTypes) const { 268 return get(getShape(), replTypes.front(), getNumScalableDims()); 269 } 270 271 VectorType VectorType::cloneWith(Optional<ArrayRef<int64_t>> shape, 272 Type elementType) const { 273 return VectorType::get(shape.value_or(getShape()), elementType, 274 getNumScalableDims()); 275 } 276 277 //===----------------------------------------------------------------------===// 278 // TensorType 279 //===----------------------------------------------------------------------===// 280 281 Type TensorType::getElementType() const { 282 return llvm::TypeSwitch<TensorType, Type>(*this) 283 .Case<RankedTensorType, UnrankedTensorType>( 284 [](auto type) { return type.getElementType(); }); 285 } 286 287 bool TensorType::hasRank() const { return !isa<UnrankedTensorType>(); } 288 289 ArrayRef<int64_t> TensorType::getShape() const { 290 return cast<RankedTensorType>().getShape(); 291 } 292 293 TensorType TensorType::cloneWith(Optional<ArrayRef<int64_t>> shape, 294 Type elementType) const { 295 if (auto unrankedTy = dyn_cast<UnrankedTensorType>()) { 296 if (shape) 297 return RankedTensorType::get(*shape, elementType); 298 return UnrankedTensorType::get(elementType); 299 } 300 301 auto rankedTy = cast<RankedTensorType>(); 302 if (!shape) 303 return RankedTensorType::get(rankedTy.getShape(), elementType, 304 rankedTy.getEncoding()); 305 return RankedTensorType::get(shape.value_or(rankedTy.getShape()), elementType, 306 rankedTy.getEncoding()); 307 } 308 309 // Check if "elementType" can be an element type of a tensor. 
310 static LogicalResult 311 checkTensorElementType(function_ref<InFlightDiagnostic()> emitError, 312 Type elementType) { 313 if (!TensorType::isValidElementType(elementType)) 314 return emitError() << "invalid tensor element type: " << elementType; 315 return success(); 316 } 317 318 /// Return true if the specified element type is ok in a tensor. 319 bool TensorType::isValidElementType(Type type) { 320 // Note: Non standard/builtin types are allowed to exist within tensor 321 // types. Dialects are expected to verify that tensor types have a valid 322 // element type within that dialect. 323 return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType, 324 IndexType>() || 325 !llvm::isa<BuiltinDialect>(type.getDialect()); 326 } 327 328 //===----------------------------------------------------------------------===// 329 // RankedTensorType 330 //===----------------------------------------------------------------------===// 331 332 LogicalResult 333 RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError, 334 ArrayRef<int64_t> shape, Type elementType, 335 Attribute encoding) { 336 for (int64_t s : shape) 337 if (s < -1) 338 return emitError() << "invalid tensor dimension size"; 339 if (auto v = encoding.dyn_cast_or_null<VerifiableTensorEncoding>()) 340 if (failed(v.verifyEncoding(shape, elementType, emitError))) 341 return failure(); 342 return checkTensorElementType(emitError, elementType); 343 } 344 345 void RankedTensorType::walkImmediateSubElements( 346 function_ref<void(Attribute)> walkAttrsFn, 347 function_ref<void(Type)> walkTypesFn) const { 348 walkTypesFn(getElementType()); 349 if (Attribute encoding = getEncoding()) 350 walkAttrsFn(encoding); 351 } 352 353 Type RankedTensorType::replaceImmediateSubElements( 354 ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const { 355 return get(getShape(), replTypes.front(), 356 replAttrs.empty() ? 
Attribute() : replAttrs.back()); 357 } 358 359 //===----------------------------------------------------------------------===// 360 // UnrankedTensorType 361 //===----------------------------------------------------------------------===// 362 363 LogicalResult 364 UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError, 365 Type elementType) { 366 return checkTensorElementType(emitError, elementType); 367 } 368 369 void UnrankedTensorType::walkImmediateSubElements( 370 function_ref<void(Attribute)> walkAttrsFn, 371 function_ref<void(Type)> walkTypesFn) const { 372 walkTypesFn(getElementType()); 373 } 374 375 Type UnrankedTensorType::replaceImmediateSubElements( 376 ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const { 377 return get(replTypes.front()); 378 } 379 380 //===----------------------------------------------------------------------===// 381 // BaseMemRefType 382 //===----------------------------------------------------------------------===// 383 384 Type BaseMemRefType::getElementType() const { 385 return llvm::TypeSwitch<BaseMemRefType, Type>(*this) 386 .Case<MemRefType, UnrankedMemRefType>( 387 [](auto type) { return type.getElementType(); }); 388 } 389 390 bool BaseMemRefType::hasRank() const { return !isa<UnrankedMemRefType>(); } 391 392 ArrayRef<int64_t> BaseMemRefType::getShape() const { 393 return cast<MemRefType>().getShape(); 394 } 395 396 BaseMemRefType BaseMemRefType::cloneWith(Optional<ArrayRef<int64_t>> shape, 397 Type elementType) const { 398 if (auto unrankedTy = dyn_cast<UnrankedMemRefType>()) { 399 if (!shape) 400 return UnrankedMemRefType::get(elementType, getMemorySpace()); 401 MemRefType::Builder builder(*shape, elementType); 402 builder.setMemorySpace(getMemorySpace()); 403 return builder; 404 } 405 406 MemRefType::Builder builder(cast<MemRefType>()); 407 if (shape) 408 builder.setShape(*shape); 409 builder.setElementType(elementType); 410 return builder; 411 } 412 413 Attribute 
BaseMemRefType::getMemorySpace() const { 414 if (auto rankedMemRefTy = dyn_cast<MemRefType>()) 415 return rankedMemRefTy.getMemorySpace(); 416 return cast<UnrankedMemRefType>().getMemorySpace(); 417 } 418 419 unsigned BaseMemRefType::getMemorySpaceAsInt() const { 420 if (auto rankedMemRefTy = dyn_cast<MemRefType>()) 421 return rankedMemRefTy.getMemorySpaceAsInt(); 422 return cast<UnrankedMemRefType>().getMemorySpaceAsInt(); 423 } 424 425 //===----------------------------------------------------------------------===// 426 // MemRefType 427 //===----------------------------------------------------------------------===// 428 429 /// Given an `originalShape` and a `reducedShape` assumed to be a subset of 430 /// `originalShape` with some `1` entries erased, return the set of indices 431 /// that specifies which of the entries of `originalShape` are dropped to obtain 432 /// `reducedShape`. The returned mask can be applied as a projection to 433 /// `originalShape` to obtain the `reducedShape`. This mask is useful to track 434 /// which dimensions must be kept when e.g. compute MemRef strides under 435 /// rank-reducing operations. Return None if reducedShape cannot be obtained 436 /// by dropping only `1` entries in `originalShape`. 437 llvm::Optional<llvm::SmallDenseSet<unsigned>> 438 mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape, 439 ArrayRef<int64_t> reducedShape) { 440 size_t originalRank = originalShape.size(), reducedRank = reducedShape.size(); 441 llvm::SmallDenseSet<unsigned> unusedDims; 442 unsigned reducedIdx = 0; 443 for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) { 444 // Greedily insert `originalIdx` if match. 445 if (reducedIdx < reducedRank && 446 originalShape[originalIdx] == reducedShape[reducedIdx]) { 447 reducedIdx++; 448 continue; 449 } 450 451 unusedDims.insert(originalIdx); 452 // If no match on `originalIdx`, the `originalShape` at this dimension 453 // must be 1, otherwise we bail. 
454 if (originalShape[originalIdx] != 1) 455 return llvm::None; 456 } 457 // The whole reducedShape must be scanned, otherwise we bail. 458 if (reducedIdx != reducedRank) 459 return llvm::None; 460 return unusedDims; 461 } 462 463 SliceVerificationResult 464 mlir::isRankReducedType(ShapedType originalType, 465 ShapedType candidateReducedType) { 466 if (originalType == candidateReducedType) 467 return SliceVerificationResult::Success; 468 469 ShapedType originalShapedType = originalType.cast<ShapedType>(); 470 ShapedType candidateReducedShapedType = 471 candidateReducedType.cast<ShapedType>(); 472 473 // Rank and size logic is valid for all ShapedTypes. 474 ArrayRef<int64_t> originalShape = originalShapedType.getShape(); 475 ArrayRef<int64_t> candidateReducedShape = 476 candidateReducedShapedType.getShape(); 477 unsigned originalRank = originalShape.size(), 478 candidateReducedRank = candidateReducedShape.size(); 479 if (candidateReducedRank > originalRank) 480 return SliceVerificationResult::RankTooLarge; 481 482 auto optionalUnusedDimsMask = 483 computeRankReductionMask(originalShape, candidateReducedShape); 484 485 // Sizes cannot be matched in case empty vector is returned. 486 if (!optionalUnusedDimsMask) 487 return SliceVerificationResult::SizeMismatch; 488 489 if (originalShapedType.getElementType() != 490 candidateReducedShapedType.getElementType()) 491 return SliceVerificationResult::ElemTypeMismatch; 492 493 return SliceVerificationResult::Success; 494 } 495 496 bool mlir::detail::isSupportedMemorySpace(Attribute memorySpace) { 497 // Empty attribute is allowed as default memory space. 498 if (!memorySpace) 499 return true; 500 501 // Supported built-in attributes. 502 if (memorySpace.isa<IntegerAttr, StringAttr, DictionaryAttr>()) 503 return true; 504 505 // Allow custom dialect attributes. 
506 if (!isa<BuiltinDialect>(memorySpace.getDialect())) 507 return true; 508 509 return false; 510 } 511 512 Attribute mlir::detail::wrapIntegerMemorySpace(unsigned memorySpace, 513 MLIRContext *ctx) { 514 if (memorySpace == 0) 515 return nullptr; 516 517 return IntegerAttr::get(IntegerType::get(ctx, 64), memorySpace); 518 } 519 520 Attribute mlir::detail::skipDefaultMemorySpace(Attribute memorySpace) { 521 IntegerAttr intMemorySpace = memorySpace.dyn_cast_or_null<IntegerAttr>(); 522 if (intMemorySpace && intMemorySpace.getValue() == 0) 523 return nullptr; 524 525 return memorySpace; 526 } 527 528 unsigned mlir::detail::getMemorySpaceAsInt(Attribute memorySpace) { 529 if (!memorySpace) 530 return 0; 531 532 assert(memorySpace.isa<IntegerAttr>() && 533 "Using `getMemorySpaceInteger` with non-Integer attribute"); 534 535 return static_cast<unsigned>(memorySpace.cast<IntegerAttr>().getInt()); 536 } 537 538 MemRefType::Builder & 539 MemRefType::Builder::setMemorySpace(unsigned newMemorySpace) { 540 memorySpace = 541 wrapIntegerMemorySpace(newMemorySpace, elementType.getContext()); 542 return *this; 543 } 544 545 unsigned MemRefType::getMemorySpaceAsInt() const { 546 return detail::getMemorySpaceAsInt(getMemorySpace()); 547 } 548 549 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType, 550 MemRefLayoutAttrInterface layout, 551 Attribute memorySpace) { 552 // Use default layout for empty attribute. 553 if (!layout) 554 layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap( 555 shape.size(), elementType.getContext())); 556 557 // Drop default memory space value and replace it with empty attribute. 
558 memorySpace = skipDefaultMemorySpace(memorySpace); 559 560 return Base::get(elementType.getContext(), shape, elementType, layout, 561 memorySpace); 562 } 563 564 MemRefType MemRefType::getChecked( 565 function_ref<InFlightDiagnostic()> emitErrorFn, ArrayRef<int64_t> shape, 566 Type elementType, MemRefLayoutAttrInterface layout, Attribute memorySpace) { 567 568 // Use default layout for empty attribute. 569 if (!layout) 570 layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap( 571 shape.size(), elementType.getContext())); 572 573 // Drop default memory space value and replace it with empty attribute. 574 memorySpace = skipDefaultMemorySpace(memorySpace); 575 576 return Base::getChecked(emitErrorFn, elementType.getContext(), shape, 577 elementType, layout, memorySpace); 578 } 579 580 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType, 581 AffineMap map, Attribute memorySpace) { 582 583 // Use default layout for empty map. 584 if (!map) 585 map = AffineMap::getMultiDimIdentityMap(shape.size(), 586 elementType.getContext()); 587 588 // Wrap AffineMap into Attribute. 589 Attribute layout = AffineMapAttr::get(map); 590 591 // Drop default memory space value and replace it with empty attribute. 592 memorySpace = skipDefaultMemorySpace(memorySpace); 593 594 return Base::get(elementType.getContext(), shape, elementType, layout, 595 memorySpace); 596 } 597 598 MemRefType 599 MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn, 600 ArrayRef<int64_t> shape, Type elementType, AffineMap map, 601 Attribute memorySpace) { 602 603 // Use default layout for empty map. 604 if (!map) 605 map = AffineMap::getMultiDimIdentityMap(shape.size(), 606 elementType.getContext()); 607 608 // Wrap AffineMap into Attribute. 609 Attribute layout = AffineMapAttr::get(map); 610 611 // Drop default memory space value and replace it with empty attribute. 
612 memorySpace = skipDefaultMemorySpace(memorySpace); 613 614 return Base::getChecked(emitErrorFn, elementType.getContext(), shape, 615 elementType, layout, memorySpace); 616 } 617 618 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType, 619 AffineMap map, unsigned memorySpaceInd) { 620 621 // Use default layout for empty map. 622 if (!map) 623 map = AffineMap::getMultiDimIdentityMap(shape.size(), 624 elementType.getContext()); 625 626 // Wrap AffineMap into Attribute. 627 Attribute layout = AffineMapAttr::get(map); 628 629 // Convert deprecated integer-like memory space to Attribute. 630 Attribute memorySpace = 631 wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext()); 632 633 return Base::get(elementType.getContext(), shape, elementType, layout, 634 memorySpace); 635 } 636 637 MemRefType 638 MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn, 639 ArrayRef<int64_t> shape, Type elementType, AffineMap map, 640 unsigned memorySpaceInd) { 641 642 // Use default layout for empty map. 643 if (!map) 644 map = AffineMap::getMultiDimIdentityMap(shape.size(), 645 elementType.getContext()); 646 647 // Wrap AffineMap into Attribute. 648 Attribute layout = AffineMapAttr::get(map); 649 650 // Convert deprecated integer-like memory space to Attribute. 651 Attribute memorySpace = 652 wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext()); 653 654 return Base::getChecked(emitErrorFn, elementType.getContext(), shape, 655 elementType, layout, memorySpace); 656 } 657 658 LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError, 659 ArrayRef<int64_t> shape, Type elementType, 660 MemRefLayoutAttrInterface layout, 661 Attribute memorySpace) { 662 if (!BaseMemRefType::isValidElementType(elementType)) 663 return emitError() << "invalid memref element type"; 664 665 // Negative sizes are not allowed except for `-1` that means dynamic size. 
666 for (int64_t s : shape) 667 if (s < -1) 668 return emitError() << "invalid memref size"; 669 670 assert(layout && "missing layout specification"); 671 if (failed(layout.verifyLayout(shape, emitError))) 672 return failure(); 673 674 if (!isSupportedMemorySpace(memorySpace)) 675 return emitError() << "unsupported memory space Attribute"; 676 677 return success(); 678 } 679 680 void MemRefType::walkImmediateSubElements( 681 function_ref<void(Attribute)> walkAttrsFn, 682 function_ref<void(Type)> walkTypesFn) const { 683 walkTypesFn(getElementType()); 684 if (!getLayout().isIdentity()) 685 walkAttrsFn(getLayout()); 686 walkAttrsFn(getMemorySpace()); 687 } 688 689 Type MemRefType::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs, 690 ArrayRef<Type> replTypes) const { 691 bool hasLayout = replAttrs.size() > 1; 692 return get(getShape(), replTypes[0], 693 hasLayout ? replAttrs[0].dyn_cast<MemRefLayoutAttrInterface>() 694 : MemRefLayoutAttrInterface(), 695 hasLayout ? replAttrs[1] : replAttrs[0]); 696 } 697 698 //===----------------------------------------------------------------------===// 699 // UnrankedMemRefType 700 //===----------------------------------------------------------------------===// 701 702 unsigned UnrankedMemRefType::getMemorySpaceAsInt() const { 703 return detail::getMemorySpaceAsInt(getMemorySpace()); 704 } 705 706 LogicalResult 707 UnrankedMemRefType::verify(function_ref<InFlightDiagnostic()> emitError, 708 Type elementType, Attribute memorySpace) { 709 if (!BaseMemRefType::isValidElementType(elementType)) 710 return emitError() << "invalid memref element type"; 711 712 if (!isSupportedMemorySpace(memorySpace)) 713 return emitError() << "unsupported memory space Attribute"; 714 715 return success(); 716 } 717 718 // Fallback cases for terminal dim/sym/cst that are not part of a binary op ( 719 // i.e. single term). Accumulate the AffineExpr into the existing one. 
720 static void extractStridesFromTerm(AffineExpr e, 721 AffineExpr multiplicativeFactor, 722 MutableArrayRef<AffineExpr> strides, 723 AffineExpr &offset) { 724 if (auto dim = e.dyn_cast<AffineDimExpr>()) 725 strides[dim.getPosition()] = 726 strides[dim.getPosition()] + multiplicativeFactor; 727 else 728 offset = offset + e * multiplicativeFactor; 729 } 730 731 /// Takes a single AffineExpr `e` and populates the `strides` array with the 732 /// strides expressions for each dim position. 733 /// The convention is that the strides for dimensions d0, .. dn appear in 734 /// order to make indexing intuitive into the result. 735 static LogicalResult extractStrides(AffineExpr e, 736 AffineExpr multiplicativeFactor, 737 MutableArrayRef<AffineExpr> strides, 738 AffineExpr &offset) { 739 auto bin = e.dyn_cast<AffineBinaryOpExpr>(); 740 if (!bin) { 741 extractStridesFromTerm(e, multiplicativeFactor, strides, offset); 742 return success(); 743 } 744 745 if (bin.getKind() == AffineExprKind::CeilDiv || 746 bin.getKind() == AffineExprKind::FloorDiv || 747 bin.getKind() == AffineExprKind::Mod) 748 return failure(); 749 750 if (bin.getKind() == AffineExprKind::Mul) { 751 auto dim = bin.getLHS().dyn_cast<AffineDimExpr>(); 752 if (dim) { 753 strides[dim.getPosition()] = 754 strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor; 755 return success(); 756 } 757 // LHS and RHS may both contain complex expressions of dims. Try one path 758 // and if it fails try the other. This is guaranteed to succeed because 759 // only one path may have a `dim`, otherwise this is not an AffineExpr in 760 // the first place. 
761 if (bin.getLHS().isSymbolicOrConstant()) 762 return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(), 763 strides, offset); 764 return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(), 765 strides, offset); 766 } 767 768 if (bin.getKind() == AffineExprKind::Add) { 769 auto res1 = 770 extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset); 771 auto res2 = 772 extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset); 773 return success(succeeded(res1) && succeeded(res2)); 774 } 775 776 llvm_unreachable("unexpected binary operation"); 777 } 778 779 LogicalResult mlir::getStridesAndOffset(MemRefType t, 780 SmallVectorImpl<AffineExpr> &strides, 781 AffineExpr &offset) { 782 AffineMap m = t.getLayout().getAffineMap(); 783 784 if (m.getNumResults() != 1 && !m.isIdentity()) 785 return failure(); 786 787 auto zero = getAffineConstantExpr(0, t.getContext()); 788 auto one = getAffineConstantExpr(1, t.getContext()); 789 offset = zero; 790 strides.assign(t.getRank(), zero); 791 792 // Canonical case for empty map. 793 if (m.isIdentity()) { 794 // 0-D corner case, offset is already 0. 795 if (t.getRank() == 0) 796 return success(); 797 auto stridedExpr = 798 makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext()); 799 if (succeeded(extractStrides(stridedExpr, one, strides, offset))) 800 return success(); 801 assert(false && "unexpected failure: extract strides in canonical layout"); 802 } 803 804 // Non-canonical case requires more work. 805 auto stridedExpr = 806 simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols()); 807 if (failed(extractStrides(stridedExpr, one, strides, offset))) { 808 offset = AffineExpr(); 809 strides.clear(); 810 return failure(); 811 } 812 813 // Simplify results to allow folding to constants and simple checks. 
814 unsigned numDims = m.getNumDims(); 815 unsigned numSymbols = m.getNumSymbols(); 816 offset = simplifyAffineExpr(offset, numDims, numSymbols); 817 for (auto &stride : strides) 818 stride = simplifyAffineExpr(stride, numDims, numSymbols); 819 820 /// In practice, a strided memref must be internally non-aliasing. Test 821 /// against 0 as a proxy. 822 /// TODO: static cases can have more advanced checks. 823 /// TODO: dynamic cases would require a way to compare symbolic 824 /// expressions and would probably need an affine set context propagated 825 /// everywhere. 826 if (llvm::any_of(strides, [](AffineExpr e) { 827 return e == getAffineConstantExpr(0, e.getContext()); 828 })) { 829 offset = AffineExpr(); 830 strides.clear(); 831 return failure(); 832 } 833 834 return success(); 835 } 836 837 LogicalResult mlir::getStridesAndOffset(MemRefType t, 838 SmallVectorImpl<int64_t> &strides, 839 int64_t &offset) { 840 AffineExpr offsetExpr; 841 SmallVector<AffineExpr, 4> strideExprs; 842 if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr))) 843 return failure(); 844 if (auto cst = offsetExpr.dyn_cast<AffineConstantExpr>()) 845 offset = cst.getValue(); 846 else 847 offset = ShapedType::kDynamicStrideOrOffset; 848 for (auto e : strideExprs) { 849 if (auto c = e.dyn_cast<AffineConstantExpr>()) 850 strides.push_back(c.getValue()); 851 else 852 strides.push_back(ShapedType::kDynamicStrideOrOffset); 853 } 854 return success(); 855 } 856 857 void UnrankedMemRefType::walkImmediateSubElements( 858 function_ref<void(Attribute)> walkAttrsFn, 859 function_ref<void(Type)> walkTypesFn) const { 860 walkTypesFn(getElementType()); 861 walkAttrsFn(getMemorySpace()); 862 } 863 864 Type UnrankedMemRefType::replaceImmediateSubElements( 865 ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const { 866 return get(replTypes.front(), replAttrs.front()); 867 } 868 869 //===----------------------------------------------------------------------===// 870 /// TupleType 871 
//===----------------------------------------------------------------------===// 872 873 /// Return the elements types for this tuple. 874 ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); } 875 876 /// Accumulate the types contained in this tuple and tuples nested within it. 877 /// Note that this only flattens nested tuples, not any other container type, 878 /// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to 879 /// (i32, tensor<i32>, f32, i64) 880 void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) { 881 for (Type type : getTypes()) { 882 if (auto nestedTuple = type.dyn_cast<TupleType>()) 883 nestedTuple.getFlattenedTypes(types); 884 else 885 types.push_back(type); 886 } 887 } 888 889 /// Return the number of element types. 890 size_t TupleType::size() const { return getImpl()->size(); } 891 892 void TupleType::walkImmediateSubElements( 893 function_ref<void(Attribute)> walkAttrsFn, 894 function_ref<void(Type)> walkTypesFn) const { 895 for (Type type : getTypes()) 896 walkTypesFn(type); 897 } 898 899 Type TupleType::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs, 900 ArrayRef<Type> replTypes) const { 901 return get(getContext(), replTypes); 902 } 903 904 //===----------------------------------------------------------------------===// 905 // Type Utilities 906 //===----------------------------------------------------------------------===// 907 908 AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides, 909 int64_t offset, 910 MLIRContext *context) { 911 AffineExpr expr; 912 unsigned nSymbols = 0; 913 914 // AffineExpr for offset. 915 // Static case. 916 if (offset != MemRefType::getDynamicStrideOrOffset()) { 917 auto cst = getAffineConstantExpr(offset, context); 918 expr = cst; 919 } else { 920 // Dynamic case, new symbol for the offset. 921 auto sym = getAffineSymbolExpr(nSymbols++, context); 922 expr = sym; 923 } 924 925 // AffineExpr for strides. 
926 for (const auto &en : llvm::enumerate(strides)) { 927 auto dim = en.index(); 928 auto stride = en.value(); 929 assert(stride != 0 && "Invalid stride specification"); 930 auto d = getAffineDimExpr(dim, context); 931 AffineExpr mult; 932 // Static case. 933 if (stride != MemRefType::getDynamicStrideOrOffset()) 934 mult = getAffineConstantExpr(stride, context); 935 else 936 // Dynamic case, new symbol for each new stride. 937 mult = getAffineSymbolExpr(nSymbols++, context); 938 expr = expr + d * mult; 939 } 940 941 return AffineMap::get(strides.size(), nSymbols, expr); 942 } 943 944 /// Return a version of `t` with identity layout if it can be determined 945 /// statically that the layout is the canonical contiguous strided layout. 946 /// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of 947 /// `t` with simplified layout. 948 /// If `t` has multiple layout maps or a multi-result layout, just return `t`. 949 MemRefType mlir::canonicalizeStridedLayout(MemRefType t) { 950 AffineMap m = t.getLayout().getAffineMap(); 951 952 // Already in canonical form. 953 if (m.isIdentity()) 954 return t; 955 956 // Can't reduce to canonical identity form, return in canonical form. 957 if (m.getNumResults() > 1) 958 return t; 959 960 // Corner-case for 0-D affine maps. 961 if (m.getNumDims() == 0 && m.getNumSymbols() == 0) { 962 if (auto cst = m.getResult(0).dyn_cast<AffineConstantExpr>()) 963 if (cst.getValue() == 0) 964 return MemRefType::Builder(t).setLayout({}); 965 return t; 966 } 967 968 // 0-D corner case for empty shape that still have an affine map. Example: 969 // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a 1 element memref whose 970 // offset needs to remain, just return t. 971 if (t.getShape().empty()) 972 return t; 973 974 // If the canonical strided layout for the sizes of `t` is equal to the 975 // simplified layout of `t` we can just return an empty layout. Otherwise, 976 // just simplify the existing layout. 
977 AffineExpr expr = 978 makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext()); 979 auto simplifiedLayoutExpr = 980 simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols()); 981 if (expr != simplifiedLayoutExpr) 982 return MemRefType::Builder(t).setLayout(AffineMapAttr::get(AffineMap::get( 983 m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr))); 984 return MemRefType::Builder(t).setLayout({}); 985 } 986 987 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes, 988 ArrayRef<AffineExpr> exprs, 989 MLIRContext *context) { 990 // Size 0 corner case is useful for canonicalizations. 991 if (sizes.empty() || llvm::is_contained(sizes, 0)) 992 return getAffineConstantExpr(0, context); 993 994 assert(!exprs.empty() && "expected exprs"); 995 auto maps = AffineMap::inferFromExprList(exprs); 996 assert(!maps.empty() && "Expected one non-empty map"); 997 unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols(); 998 999 AffineExpr expr; 1000 bool dynamicPoisonBit = false; 1001 int64_t runningSize = 1; 1002 for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) { 1003 int64_t size = std::get<1>(en); 1004 // Degenerate case, no size =-> no stride 1005 if (size == 0) 1006 continue; 1007 AffineExpr dimExpr = std::get<0>(en); 1008 AffineExpr stride = dynamicPoisonBit 1009 ? getAffineSymbolExpr(nSymbols++, context) 1010 : getAffineConstantExpr(runningSize, context); 1011 expr = expr ? expr + dimExpr * stride : dimExpr * stride; 1012 if (size > 0) { 1013 runningSize *= size; 1014 assert(runningSize > 0 && "integer overflow in size computation"); 1015 } else { 1016 dynamicPoisonBit = true; 1017 } 1018 } 1019 return simplifyAffineExpr(expr, numDims, nSymbols); 1020 } 1021 1022 /// Return a version of `t` with a layout that has all dynamic offset and 1023 /// strides. This is used to erase the static layout. 
1024 MemRefType mlir::eraseStridedLayout(MemRefType t) { 1025 auto val = ShapedType::kDynamicStrideOrOffset; 1026 return MemRefType::Builder(t).setLayout( 1027 AffineMapAttr::get(makeStridedLinearLayoutMap( 1028 SmallVector<int64_t, 4>(t.getRank(), val), val, t.getContext()))); 1029 } 1030 1031 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes, 1032 MLIRContext *context) { 1033 SmallVector<AffineExpr, 4> exprs; 1034 exprs.reserve(sizes.size()); 1035 for (auto dim : llvm::seq<unsigned>(0, sizes.size())) 1036 exprs.push_back(getAffineDimExpr(dim, context)); 1037 return makeCanonicalStridedLayoutExpr(sizes, exprs, context); 1038 } 1039 1040 /// Return true if the layout for `t` is compatible with strided semantics. 1041 bool mlir::isStrided(MemRefType t) { 1042 int64_t offset; 1043 SmallVector<int64_t, 4> strides; 1044 auto res = getStridesAndOffset(t, strides, offset); 1045 return succeeded(res); 1046 } 1047 1048 /// Return the layout map in strided linear layout AffineMap form. 1049 /// Return null if the layout is not compatible with a strided layout. 1050 AffineMap mlir::getStridedLinearLayoutMap(MemRefType t) { 1051 int64_t offset; 1052 SmallVector<int64_t, 4> strides; 1053 if (failed(getStridesAndOffset(t, strides, offset))) 1054 return AffineMap(); 1055 return makeStridedLinearLayoutMap(strides, offset, t.getContext()); 1056 } 1057 1058 /// Return the AffineExpr representation of the offset, assuming `memRefType` 1059 /// is a strided memref. 1060 static AffineExpr getOffsetExpr(MemRefType memrefType) { 1061 SmallVector<AffineExpr> strides; 1062 AffineExpr offset; 1063 if (failed(getStridesAndOffset(memrefType, strides, offset))) 1064 assert(false && "expected strided memref"); 1065 return offset; 1066 } 1067 1068 /// Helper to construct a contiguous MemRefType of `shape`, `elementType` and 1069 /// `offset` AffineExpr. 
1070 static MemRefType makeContiguousRowMajorMemRefType(MLIRContext *context, 1071 ArrayRef<int64_t> shape, 1072 Type elementType, 1073 AffineExpr offset) { 1074 AffineExpr canonical = makeCanonicalStridedLayoutExpr(shape, context); 1075 AffineExpr contiguousRowMajor = canonical + offset; 1076 AffineMap contiguousRowMajorMap = 1077 AffineMap::inferFromExprList({contiguousRowMajor})[0]; 1078 return MemRefType::get(shape, elementType, contiguousRowMajorMap); 1079 } 1080 1081 /// Helper determining if a memref is static-shape and contiguous-row-major 1082 /// layout, while still allowing for an arbitrary offset (any static or 1083 /// dynamic value). 1084 bool mlir::isStaticShapeAndContiguousRowMajor(MemRefType memrefType) { 1085 if (!memrefType.hasStaticShape()) 1086 return false; 1087 AffineExpr offset = getOffsetExpr(memrefType); 1088 MemRefType contiguousRowMajorMemRefType = makeContiguousRowMajorMemRefType( 1089 memrefType.getContext(), memrefType.getShape(), 1090 memrefType.getElementType(), offset); 1091 return canonicalizeStridedLayout(memrefType) == 1092 canonicalizeStridedLayout(contiguousRowMajorMemRefType); 1093 } 1094