//===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/BuiltinTypes.h"
#include "TypeDetail.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/TensorEncoding.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/Twine.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::detail;

//===----------------------------------------------------------------------===//
// Tablegen Type Definitions
//===----------------------------------------------------------------------===//

// Pull in the ODS-generated type class definitions.
#define GET_TYPEDEF_CLASSES
#include "mlir/IR/BuiltinTypes.cpp.inc"

//===----------------------------------------------------------------------===//
// BuiltinDialect
//===----------------------------------------------------------------------===//

/// Register all builtin types with the builtin dialect. The template argument
/// list of `addTypes` is expanded from the generated `BuiltinTypes.cpp.inc`
/// with GET_TYPEDEF_LIST defined.
void BuiltinDialect::registerTypes() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/IR/BuiltinTypes.cpp.inc"
      >();
}

//===----------------------------------------------------------------------===//
// ComplexType
//===----------------------------------------------------------------------===//

/// Verify the construction of a complex type: the element type must be an
/// integer or floating-point type.
50 LogicalResult ComplexType::verify(function_ref<InFlightDiagnostic()> emitError, 51 Type elementType) { 52 if (!elementType.isIntOrFloat()) 53 return emitError() << "invalid element type for complex"; 54 return success(); 55 } 56 57 //===----------------------------------------------------------------------===// 58 // Integer Type 59 //===----------------------------------------------------------------------===// 60 61 // static constexpr must have a definition (until in C++17 and inline variable). 62 constexpr unsigned IntegerType::kMaxWidth; 63 64 /// Verify the construction of an integer type. 65 LogicalResult IntegerType::verify(function_ref<InFlightDiagnostic()> emitError, 66 unsigned width, 67 SignednessSemantics signedness) { 68 if (width > IntegerType::kMaxWidth) { 69 return emitError() << "integer bitwidth is limited to " 70 << IntegerType::kMaxWidth << " bits"; 71 } 72 return success(); 73 } 74 75 unsigned IntegerType::getWidth() const { return getImpl()->width; } 76 77 IntegerType::SignednessSemantics IntegerType::getSignedness() const { 78 return getImpl()->signedness; 79 } 80 81 IntegerType IntegerType::scaleElementBitwidth(unsigned scale) { 82 if (!scale) 83 return IntegerType(); 84 return IntegerType::get(getContext(), scale * getWidth(), getSignedness()); 85 } 86 87 //===----------------------------------------------------------------------===// 88 // Float Type 89 //===----------------------------------------------------------------------===// 90 91 unsigned FloatType::getWidth() { 92 if (isa<Float16Type, BFloat16Type>()) 93 return 16; 94 if (isa<Float32Type>()) 95 return 32; 96 if (isa<Float64Type>()) 97 return 64; 98 if (isa<Float80Type>()) 99 return 80; 100 if (isa<Float128Type>()) 101 return 128; 102 llvm_unreachable("unexpected float type"); 103 } 104 105 /// Returns the floating semantics for the given type. 
106 const llvm::fltSemantics &FloatType::getFloatSemantics() { 107 if (isa<BFloat16Type>()) 108 return APFloat::BFloat(); 109 if (isa<Float16Type>()) 110 return APFloat::IEEEhalf(); 111 if (isa<Float32Type>()) 112 return APFloat::IEEEsingle(); 113 if (isa<Float64Type>()) 114 return APFloat::IEEEdouble(); 115 if (isa<Float80Type>()) 116 return APFloat::x87DoubleExtended(); 117 if (isa<Float128Type>()) 118 return APFloat::IEEEquad(); 119 llvm_unreachable("non-floating point type used"); 120 } 121 122 FloatType FloatType::scaleElementBitwidth(unsigned scale) { 123 if (!scale) 124 return FloatType(); 125 MLIRContext *ctx = getContext(); 126 if (isF16() || isBF16()) { 127 if (scale == 2) 128 return FloatType::getF32(ctx); 129 if (scale == 4) 130 return FloatType::getF64(ctx); 131 } 132 if (isF32()) 133 if (scale == 2) 134 return FloatType::getF64(ctx); 135 return FloatType(); 136 } 137 138 //===----------------------------------------------------------------------===// 139 // FunctionType 140 //===----------------------------------------------------------------------===// 141 142 unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; } 143 144 ArrayRef<Type> FunctionType::getInputs() const { 145 return getImpl()->getInputs(); 146 } 147 148 unsigned FunctionType::getNumResults() const { return getImpl()->numResults; } 149 150 ArrayRef<Type> FunctionType::getResults() const { 151 return getImpl()->getResults(); 152 } 153 154 /// Helper to call a callback once on each index in the range 155 /// [0, `totalIndices`), *except* for the indices given in `indices`. 156 /// `indices` is allowed to have duplicates and can be in any order. 
157 inline void iterateIndicesExcept(unsigned totalIndices, 158 ArrayRef<unsigned> indices, 159 function_ref<void(unsigned)> callback) { 160 llvm::BitVector skipIndices(totalIndices); 161 for (unsigned i : indices) 162 skipIndices.set(i); 163 164 for (unsigned i = 0; i < totalIndices; ++i) 165 if (!skipIndices.test(i)) 166 callback(i); 167 } 168 169 /// Returns a new function type without the specified arguments and results. 170 FunctionType 171 FunctionType::getWithoutArgsAndResults(ArrayRef<unsigned> argIndices, 172 ArrayRef<unsigned> resultIndices) { 173 ArrayRef<Type> newInputTypes = getInputs(); 174 SmallVector<Type, 4> newInputTypesBuffer; 175 if (!argIndices.empty()) { 176 unsigned originalNumArgs = getNumInputs(); 177 iterateIndicesExcept(originalNumArgs, argIndices, [&](unsigned i) { 178 newInputTypesBuffer.emplace_back(getInput(i)); 179 }); 180 newInputTypes = newInputTypesBuffer; 181 } 182 183 ArrayRef<Type> newResultTypes = getResults(); 184 SmallVector<Type, 4> newResultTypesBuffer; 185 if (!resultIndices.empty()) { 186 unsigned originalNumResults = getNumResults(); 187 iterateIndicesExcept(originalNumResults, resultIndices, [&](unsigned i) { 188 newResultTypesBuffer.emplace_back(getResult(i)); 189 }); 190 newResultTypes = newResultTypesBuffer; 191 } 192 193 return get(getContext(), newInputTypes, newResultTypes); 194 } 195 196 //===----------------------------------------------------------------------===// 197 // OpaqueType 198 //===----------------------------------------------------------------------===// 199 200 /// Verify the construction of an opaque type. 201 LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError, 202 Identifier dialect, StringRef typeData) { 203 if (!Dialect::isValidNamespace(dialect.strref())) 204 return emitError() << "invalid dialect namespace '" << dialect << "'"; 205 206 // Check that the dialect is actually registered. 
207 MLIRContext *context = dialect.getContext(); 208 if (!context->allowsUnregisteredDialects() && 209 !context->getLoadedDialect(dialect.strref())) { 210 return emitError() 211 << "`!" << dialect << "<\"" << typeData << "\">" 212 << "` type created with unregistered dialect. If this is " 213 "intended, please call allowUnregisteredDialects() on the " 214 "MLIRContext, or use -allow-unregistered-dialect with " 215 "mlir-opt"; 216 } 217 218 return success(); 219 } 220 221 //===----------------------------------------------------------------------===// 222 // ShapedType 223 //===----------------------------------------------------------------------===// 224 constexpr int64_t ShapedType::kDynamicSize; 225 constexpr int64_t ShapedType::kDynamicStrideOrOffset; 226 227 ShapedType ShapedType::clone(ArrayRef<int64_t> shape, Type elementType) { 228 if (auto other = dyn_cast<MemRefType>()) { 229 MemRefType::Builder b(other); 230 b.setShape(shape); 231 b.setElementType(elementType); 232 return b; 233 } 234 235 if (auto other = dyn_cast<UnrankedMemRefType>()) { 236 MemRefType::Builder b(shape, elementType); 237 b.setMemorySpace(other.getMemorySpace()); 238 return b; 239 } 240 241 if (isa<TensorType>()) 242 return RankedTensorType::get(shape, elementType); 243 244 if (isa<VectorType>()) 245 return VectorType::get(shape, elementType); 246 247 llvm_unreachable("Unhandled ShapedType clone case"); 248 } 249 250 ShapedType ShapedType::clone(ArrayRef<int64_t> shape) { 251 if (auto other = dyn_cast<MemRefType>()) { 252 MemRefType::Builder b(other); 253 b.setShape(shape); 254 return b; 255 } 256 257 if (auto other = dyn_cast<UnrankedMemRefType>()) { 258 MemRefType::Builder b(shape, other.getElementType()); 259 b.setShape(shape); 260 b.setMemorySpace(other.getMemorySpace()); 261 return b; 262 } 263 264 if (isa<TensorType>()) 265 return RankedTensorType::get(shape, getElementType()); 266 267 if (isa<VectorType>()) 268 return VectorType::get(shape, getElementType()); 269 270 
llvm_unreachable("Unhandled ShapedType clone case"); 271 } 272 273 ShapedType ShapedType::clone(Type elementType) { 274 if (auto other = dyn_cast<MemRefType>()) { 275 MemRefType::Builder b(other); 276 b.setElementType(elementType); 277 return b; 278 } 279 280 if (auto other = dyn_cast<UnrankedMemRefType>()) { 281 return UnrankedMemRefType::get(elementType, other.getMemorySpace()); 282 } 283 284 if (isa<TensorType>()) { 285 if (hasRank()) 286 return RankedTensorType::get(getShape(), elementType); 287 return UnrankedTensorType::get(elementType); 288 } 289 290 if (isa<VectorType>()) 291 return VectorType::get(getShape(), elementType); 292 293 llvm_unreachable("Unhandled ShapedType clone hit"); 294 } 295 296 Type ShapedType::getElementType() const { 297 return TypeSwitch<Type, Type>(*this) 298 .Case<VectorType, RankedTensorType, UnrankedTensorType, MemRefType, 299 UnrankedMemRefType>([](auto ty) { return ty.getElementType(); }); 300 } 301 302 unsigned ShapedType::getElementTypeBitWidth() const { 303 return getElementType().getIntOrFloatBitWidth(); 304 } 305 306 int64_t ShapedType::getNumElements() const { 307 assert(hasStaticShape() && "cannot get element count of dynamic shaped type"); 308 auto shape = getShape(); 309 int64_t num = 1; 310 for (auto dim : shape) { 311 num *= dim; 312 assert(num >= 0 && "integer overflow in element count computation"); 313 } 314 return num; 315 } 316 317 int64_t ShapedType::getRank() const { 318 assert(hasRank() && "cannot query rank of unranked shaped type"); 319 return getShape().size(); 320 } 321 322 bool ShapedType::hasRank() const { 323 return !isa<UnrankedMemRefType, UnrankedTensorType>(); 324 } 325 326 int64_t ShapedType::getDimSize(unsigned idx) const { 327 assert(idx < getRank() && "invalid index for shaped type"); 328 return getShape()[idx]; 329 } 330 331 bool ShapedType::isDynamicDim(unsigned idx) const { 332 assert(idx < getRank() && "invalid index for shaped type"); 333 return isDynamic(getShape()[idx]); 334 } 335 336 
unsigned ShapedType::getDynamicDimIndex(unsigned index) const { 337 assert(index < getRank() && "invalid index"); 338 assert(ShapedType::isDynamic(getDimSize(index)) && "invalid index"); 339 return llvm::count_if(getShape().take_front(index), ShapedType::isDynamic); 340 } 341 342 /// Get the number of bits require to store a value of the given shaped type. 343 /// Compute the value recursively since tensors are allowed to have vectors as 344 /// elements. 345 int64_t ShapedType::getSizeInBits() const { 346 assert(hasStaticShape() && 347 "cannot get the bit size of an aggregate with a dynamic shape"); 348 349 auto elementType = getElementType(); 350 if (elementType.isIntOrFloat()) 351 return elementType.getIntOrFloatBitWidth() * getNumElements(); 352 353 if (auto complexType = elementType.dyn_cast<ComplexType>()) { 354 elementType = complexType.getElementType(); 355 return elementType.getIntOrFloatBitWidth() * getNumElements() * 2; 356 } 357 358 // Tensors can have vectors and other tensors as elements, other shaped types 359 // cannot. 
360 assert(isa<TensorType>() && "unsupported element type"); 361 assert((elementType.isa<VectorType, TensorType>()) && 362 "unsupported tensor element type"); 363 return getNumElements() * elementType.cast<ShapedType>().getSizeInBits(); 364 } 365 366 ArrayRef<int64_t> ShapedType::getShape() const { 367 if (auto vectorType = dyn_cast<VectorType>()) 368 return vectorType.getShape(); 369 if (auto tensorType = dyn_cast<RankedTensorType>()) 370 return tensorType.getShape(); 371 return cast<MemRefType>().getShape(); 372 } 373 374 int64_t ShapedType::getNumDynamicDims() const { 375 return llvm::count_if(getShape(), isDynamic); 376 } 377 378 bool ShapedType::hasStaticShape() const { 379 return hasRank() && llvm::none_of(getShape(), isDynamic); 380 } 381 382 bool ShapedType::hasStaticShape(ArrayRef<int64_t> shape) const { 383 return hasStaticShape() && getShape() == shape; 384 } 385 386 //===----------------------------------------------------------------------===// 387 // VectorType 388 //===----------------------------------------------------------------------===// 389 390 LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError, 391 ArrayRef<int64_t> shape, Type elementType) { 392 if (shape.empty()) 393 return emitError() << "vector types must have at least one dimension"; 394 395 if (!isValidElementType(elementType)) 396 return emitError() << "vector elements must be int/index/float type"; 397 398 if (any_of(shape, [](int64_t i) { return i <= 0; })) 399 return emitError() << "vector types must have positive constant sizes"; 400 401 return success(); 402 } 403 404 VectorType VectorType::scaleElementBitwidth(unsigned scale) { 405 if (!scale) 406 return VectorType(); 407 if (auto et = getElementType().dyn_cast<IntegerType>()) 408 if (auto scaledEt = et.scaleElementBitwidth(scale)) 409 return VectorType::get(getShape(), scaledEt); 410 if (auto et = getElementType().dyn_cast<FloatType>()) 411 if (auto scaledEt = et.scaleElementBitwidth(scale)) 412 return 
VectorType::get(getShape(), scaledEt); 413 return VectorType(); 414 } 415 416 //===----------------------------------------------------------------------===// 417 // TensorType 418 //===----------------------------------------------------------------------===// 419 420 // Check if "elementType" can be an element type of a tensor. 421 static LogicalResult 422 checkTensorElementType(function_ref<InFlightDiagnostic()> emitError, 423 Type elementType) { 424 if (!TensorType::isValidElementType(elementType)) 425 return emitError() << "invalid tensor element type: " << elementType; 426 return success(); 427 } 428 429 /// Return true if the specified element type is ok in a tensor. 430 bool TensorType::isValidElementType(Type type) { 431 // Note: Non standard/builtin types are allowed to exist within tensor 432 // types. Dialects are expected to verify that tensor types have a valid 433 // element type within that dialect. 434 return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType, 435 IndexType>() || 436 !type.getDialect().getNamespace().empty(); 437 } 438 439 //===----------------------------------------------------------------------===// 440 // RankedTensorType 441 //===----------------------------------------------------------------------===// 442 443 LogicalResult 444 RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError, 445 ArrayRef<int64_t> shape, Type elementType, 446 Attribute encoding) { 447 for (int64_t s : shape) 448 if (s < -1) 449 return emitError() << "invalid tensor dimension size"; 450 if (auto v = encoding.dyn_cast_or_null<VerifiableTensorEncoding>()) 451 if (failed(v.verifyEncoding(shape, elementType, emitError))) 452 return failure(); 453 return checkTensorElementType(emitError, elementType); 454 } 455 456 //===----------------------------------------------------------------------===// 457 // UnrankedTensorType 458 //===----------------------------------------------------------------------===// 459 460 LogicalResult 
461 UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError, 462 Type elementType) { 463 return checkTensorElementType(emitError, elementType); 464 } 465 466 //===----------------------------------------------------------------------===// 467 // BaseMemRefType 468 //===----------------------------------------------------------------------===// 469 470 Attribute BaseMemRefType::getMemorySpace() const { 471 if (auto rankedMemRefTy = dyn_cast<MemRefType>()) 472 return rankedMemRefTy.getMemorySpace(); 473 return cast<UnrankedMemRefType>().getMemorySpace(); 474 } 475 476 unsigned BaseMemRefType::getMemorySpaceAsInt() const { 477 if (auto rankedMemRefTy = dyn_cast<MemRefType>()) 478 return rankedMemRefTy.getMemorySpaceAsInt(); 479 return cast<UnrankedMemRefType>().getMemorySpaceAsInt(); 480 } 481 482 //===----------------------------------------------------------------------===// 483 // MemRefType 484 //===----------------------------------------------------------------------===// 485 486 /// Given an `originalShape` and a `reducedShape` assumed to be a subset of 487 /// `originalShape` with some `1` entries erased, return the set of indices 488 /// that specifies which of the entries of `originalShape` are dropped to obtain 489 /// `reducedShape`. The returned mask can be applied as a projection to 490 /// `originalShape` to obtain the `reducedShape`. This mask is useful to track 491 /// which dimensions must be kept when e.g. compute MemRef strides under 492 /// rank-reducing operations. Return None if reducedShape cannot be obtained 493 /// by dropping only `1` entries in `originalShape`. 
494 llvm::Optional<llvm::SmallDenseSet<unsigned>> 495 mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape, 496 ArrayRef<int64_t> reducedShape) { 497 size_t originalRank = originalShape.size(), reducedRank = reducedShape.size(); 498 llvm::SmallDenseSet<unsigned> unusedDims; 499 unsigned reducedIdx = 0; 500 for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) { 501 // Greedily insert `originalIdx` if no match. 502 if (reducedIdx < reducedRank && 503 originalShape[originalIdx] == reducedShape[reducedIdx]) { 504 reducedIdx++; 505 continue; 506 } 507 508 unusedDims.insert(originalIdx); 509 // If no match on `originalIdx`, the `originalShape` at this dimension 510 // must be 1, otherwise we bail. 511 if (originalShape[originalIdx] != 1) 512 return llvm::None; 513 } 514 // The whole reducedShape must be scanned, otherwise we bail. 515 if (reducedIdx != reducedRank) 516 return llvm::None; 517 return unusedDims; 518 } 519 520 bool mlir::detail::isSupportedMemorySpace(Attribute memorySpace) { 521 // Empty attribute is allowed as default memory space. 522 if (!memorySpace) 523 return true; 524 525 // Supported built-in attributes. 526 if (memorySpace.isa<IntegerAttr, StringAttr, DictionaryAttr>()) 527 return true; 528 529 // Allow custom dialect attributes. 
530 if (!::mlir::isa<BuiltinDialect>(memorySpace.getDialect())) 531 return true; 532 533 return false; 534 } 535 536 Attribute mlir::detail::wrapIntegerMemorySpace(unsigned memorySpace, 537 MLIRContext *ctx) { 538 if (memorySpace == 0) 539 return nullptr; 540 541 return IntegerAttr::get(IntegerType::get(ctx, 64), memorySpace); 542 } 543 544 Attribute mlir::detail::skipDefaultMemorySpace(Attribute memorySpace) { 545 IntegerAttr intMemorySpace = memorySpace.dyn_cast_or_null<IntegerAttr>(); 546 if (intMemorySpace && intMemorySpace.getValue() == 0) 547 return nullptr; 548 549 return memorySpace; 550 } 551 552 unsigned mlir::detail::getMemorySpaceAsInt(Attribute memorySpace) { 553 if (!memorySpace) 554 return 0; 555 556 assert(memorySpace.isa<IntegerAttr>() && 557 "Using `getMemorySpaceInteger` with non-Integer attribute"); 558 559 return static_cast<unsigned>(memorySpace.cast<IntegerAttr>().getInt()); 560 } 561 562 MemRefType::Builder & 563 MemRefType::Builder::setMemorySpace(unsigned newMemorySpace) { 564 memorySpace = 565 wrapIntegerMemorySpace(newMemorySpace, elementType.getContext()); 566 return *this; 567 } 568 569 unsigned MemRefType::getMemorySpaceAsInt() const { 570 return detail::getMemorySpaceAsInt(getMemorySpace()); 571 } 572 573 LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError, 574 ArrayRef<int64_t> shape, Type elementType, 575 ArrayRef<AffineMap> affineMapComposition, 576 Attribute memorySpace) { 577 if (!BaseMemRefType::isValidElementType(elementType)) 578 return emitError() << "invalid memref element type"; 579 580 // Negative sizes are not allowed except for `-1` that means dynamic size. 581 for (int64_t s : shape) 582 if (s < -1) 583 return emitError() << "invalid memref size"; 584 585 // Check that the structure of the composition is valid, i.e. that each 586 // subsequent affine map has as many inputs as the previous map has results. 587 // Take the dimensionality of the MemRef for the first map. 
588 size_t dim = shape.size(); 589 for (auto it : llvm::enumerate(affineMapComposition)) { 590 AffineMap map = it.value(); 591 if (map.getNumDims() == dim) { 592 dim = map.getNumResults(); 593 continue; 594 } 595 return emitError() << "memref affine map dimension mismatch between " 596 << (it.index() == 0 ? Twine("memref rank") 597 : "affine map " + Twine(it.index())) 598 << " and affine map" << it.index() + 1 << ": " << dim 599 << " != " << map.getNumDims(); 600 } 601 602 if (!isSupportedMemorySpace(memorySpace)) { 603 return emitError() << "unsupported memory space Attribute"; 604 } 605 606 return success(); 607 } 608 609 //===----------------------------------------------------------------------===// 610 // UnrankedMemRefType 611 //===----------------------------------------------------------------------===// 612 613 unsigned UnrankedMemRefType::getMemorySpaceAsInt() const { 614 return detail::getMemorySpaceAsInt(getMemorySpace()); 615 } 616 617 LogicalResult 618 UnrankedMemRefType::verify(function_ref<InFlightDiagnostic()> emitError, 619 Type elementType, Attribute memorySpace) { 620 if (!BaseMemRefType::isValidElementType(elementType)) 621 return emitError() << "invalid memref element type"; 622 623 if (!isSupportedMemorySpace(memorySpace)) 624 return emitError() << "unsupported memory space Attribute"; 625 626 return success(); 627 } 628 629 // Fallback cases for terminal dim/sym/cst that are not part of a binary op ( 630 // i.e. single term). Accumulate the AffineExpr into the existing one. 
631 static void extractStridesFromTerm(AffineExpr e, 632 AffineExpr multiplicativeFactor, 633 MutableArrayRef<AffineExpr> strides, 634 AffineExpr &offset) { 635 if (auto dim = e.dyn_cast<AffineDimExpr>()) 636 strides[dim.getPosition()] = 637 strides[dim.getPosition()] + multiplicativeFactor; 638 else 639 offset = offset + e * multiplicativeFactor; 640 } 641 642 /// Takes a single AffineExpr `e` and populates the `strides` array with the 643 /// strides expressions for each dim position. 644 /// The convention is that the strides for dimensions d0, .. dn appear in 645 /// order to make indexing intuitive into the result. 646 static LogicalResult extractStrides(AffineExpr e, 647 AffineExpr multiplicativeFactor, 648 MutableArrayRef<AffineExpr> strides, 649 AffineExpr &offset) { 650 auto bin = e.dyn_cast<AffineBinaryOpExpr>(); 651 if (!bin) { 652 extractStridesFromTerm(e, multiplicativeFactor, strides, offset); 653 return success(); 654 } 655 656 if (bin.getKind() == AffineExprKind::CeilDiv || 657 bin.getKind() == AffineExprKind::FloorDiv || 658 bin.getKind() == AffineExprKind::Mod) 659 return failure(); 660 661 if (bin.getKind() == AffineExprKind::Mul) { 662 auto dim = bin.getLHS().dyn_cast<AffineDimExpr>(); 663 if (dim) { 664 strides[dim.getPosition()] = 665 strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor; 666 return success(); 667 } 668 // LHS and RHS may both contain complex expressions of dims. Try one path 669 // and if it fails try the other. This is guaranteed to succeed because 670 // only one path may have a `dim`, otherwise this is not an AffineExpr in 671 // the first place. 
672 if (bin.getLHS().isSymbolicOrConstant()) 673 return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(), 674 strides, offset); 675 return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(), 676 strides, offset); 677 } 678 679 if (bin.getKind() == AffineExprKind::Add) { 680 auto res1 = 681 extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset); 682 auto res2 = 683 extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset); 684 return success(succeeded(res1) && succeeded(res2)); 685 } 686 687 llvm_unreachable("unexpected binary operation"); 688 } 689 690 LogicalResult mlir::getStridesAndOffset(MemRefType t, 691 SmallVectorImpl<AffineExpr> &strides, 692 AffineExpr &offset) { 693 auto affineMaps = t.getAffineMaps(); 694 695 if (!affineMaps.empty() && affineMaps.back().getNumResults() != 1) 696 return failure(); 697 698 AffineMap m; 699 if (!affineMaps.empty()) { 700 m = affineMaps.back(); 701 for (size_t i = affineMaps.size() - 1; i > 0; --i) 702 m = m.compose(affineMaps[i - 1]); 703 assert(!m.isIdentity() && "unexpected identity map"); 704 } 705 706 auto zero = getAffineConstantExpr(0, t.getContext()); 707 auto one = getAffineConstantExpr(1, t.getContext()); 708 offset = zero; 709 strides.assign(t.getRank(), zero); 710 711 // Canonical case for empty map. 712 if (!m) { 713 // 0-D corner case, offset is already 0. 714 if (t.getRank() == 0) 715 return success(); 716 auto stridedExpr = 717 makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext()); 718 if (succeeded(extractStrides(stridedExpr, one, strides, offset))) 719 return success(); 720 assert(false && "unexpected failure: extract strides in canonical layout"); 721 } 722 723 // Non-canonical case requires more work. 
724 auto stridedExpr = 725 simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols()); 726 if (failed(extractStrides(stridedExpr, one, strides, offset))) { 727 offset = AffineExpr(); 728 strides.clear(); 729 return failure(); 730 } 731 732 // Simplify results to allow folding to constants and simple checks. 733 unsigned numDims = m.getNumDims(); 734 unsigned numSymbols = m.getNumSymbols(); 735 offset = simplifyAffineExpr(offset, numDims, numSymbols); 736 for (auto &stride : strides) 737 stride = simplifyAffineExpr(stride, numDims, numSymbols); 738 739 /// In practice, a strided memref must be internally non-aliasing. Test 740 /// against 0 as a proxy. 741 /// TODO: static cases can have more advanced checks. 742 /// TODO: dynamic cases would require a way to compare symbolic 743 /// expressions and would probably need an affine set context propagated 744 /// everywhere. 745 if (llvm::any_of(strides, [](AffineExpr e) { 746 return e == getAffineConstantExpr(0, e.getContext()); 747 })) { 748 offset = AffineExpr(); 749 strides.clear(); 750 return failure(); 751 } 752 753 return success(); 754 } 755 756 LogicalResult mlir::getStridesAndOffset(MemRefType t, 757 SmallVectorImpl<int64_t> &strides, 758 int64_t &offset) { 759 AffineExpr offsetExpr; 760 SmallVector<AffineExpr, 4> strideExprs; 761 if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr))) 762 return failure(); 763 if (auto cst = offsetExpr.dyn_cast<AffineConstantExpr>()) 764 offset = cst.getValue(); 765 else 766 offset = ShapedType::kDynamicStrideOrOffset; 767 for (auto e : strideExprs) { 768 if (auto c = e.dyn_cast<AffineConstantExpr>()) 769 strides.push_back(c.getValue()); 770 else 771 strides.push_back(ShapedType::kDynamicStrideOrOffset); 772 } 773 return success(); 774 } 775 776 //===----------------------------------------------------------------------===// 777 /// TupleType 778 //===----------------------------------------------------------------------===// 779 780 /// Return the elements 
types for this tuple. 781 ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); } 782 783 /// Accumulate the types contained in this tuple and tuples nested within it. 784 /// Note that this only flattens nested tuples, not any other container type, 785 /// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to 786 /// (i32, tensor<i32>, f32, i64) 787 void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) { 788 for (Type type : getTypes()) { 789 if (auto nestedTuple = type.dyn_cast<TupleType>()) 790 nestedTuple.getFlattenedTypes(types); 791 else 792 types.push_back(type); 793 } 794 } 795 796 /// Return the number of element types. 797 size_t TupleType::size() const { return getImpl()->size(); } 798 799 //===----------------------------------------------------------------------===// 800 // Type Utilities 801 //===----------------------------------------------------------------------===// 802 803 AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides, 804 int64_t offset, 805 MLIRContext *context) { 806 AffineExpr expr; 807 unsigned nSymbols = 0; 808 809 // AffineExpr for offset. 810 // Static case. 811 if (offset != MemRefType::getDynamicStrideOrOffset()) { 812 auto cst = getAffineConstantExpr(offset, context); 813 expr = cst; 814 } else { 815 // Dynamic case, new symbol for the offset. 816 auto sym = getAffineSymbolExpr(nSymbols++, context); 817 expr = sym; 818 } 819 820 // AffineExpr for strides. 821 for (auto en : llvm::enumerate(strides)) { 822 auto dim = en.index(); 823 auto stride = en.value(); 824 assert(stride != 0 && "Invalid stride specification"); 825 auto d = getAffineDimExpr(dim, context); 826 AffineExpr mult; 827 // Static case. 828 if (stride != MemRefType::getDynamicStrideOrOffset()) 829 mult = getAffineConstantExpr(stride, context); 830 else 831 // Dynamic case, new symbol for each new stride. 
832 mult = getAffineSymbolExpr(nSymbols++, context); 833 expr = expr + d * mult; 834 } 835 836 return AffineMap::get(strides.size(), nSymbols, expr); 837 } 838 839 /// Return a version of `t` with identity layout if it can be determined 840 /// statically that the layout is the canonical contiguous strided layout. 841 /// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of 842 /// `t` with simplified layout. 843 /// If `t` has multiple layout maps or a multi-result layout, just return `t`. 844 MemRefType mlir::canonicalizeStridedLayout(MemRefType t) { 845 auto affineMaps = t.getAffineMaps(); 846 // Already in canonical form. 847 if (affineMaps.empty()) 848 return t; 849 850 // Can't reduce to canonical identity form, return in canonical form. 851 if (affineMaps.size() > 1 || affineMaps[0].getNumResults() > 1) 852 return t; 853 854 // Corner-case for 0-D affine maps. 855 auto m = affineMaps[0]; 856 if (m.getNumDims() == 0 && m.getNumSymbols() == 0) { 857 if (auto cst = m.getResult(0).dyn_cast<AffineConstantExpr>()) 858 if (cst.getValue() == 0) 859 return MemRefType::Builder(t).setAffineMaps({}); 860 return t; 861 } 862 863 // 0-D corner case for empty shape that still have an affine map. Example: 864 // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a 1 element memref whose 865 // offset needs to remain, just return t. 866 if (t.getShape().empty()) 867 return t; 868 869 // If the canonical strided layout for the sizes of `t` is equal to the 870 // simplified layout of `t` we can just return an empty layout. Otherwise, 871 // just simplify the existing layout. 
872 AffineExpr expr = 873 makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext()); 874 auto simplifiedLayoutExpr = 875 simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols()); 876 if (expr != simplifiedLayoutExpr) 877 return MemRefType::Builder(t).setAffineMaps({AffineMap::get( 878 m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)}); 879 return MemRefType::Builder(t).setAffineMaps({}); 880 } 881 882 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes, 883 ArrayRef<AffineExpr> exprs, 884 MLIRContext *context) { 885 assert(!sizes.empty() && !exprs.empty() && 886 "expected non-empty sizes and exprs"); 887 888 // Size 0 corner case is useful for canonicalizations. 889 if (llvm::is_contained(sizes, 0)) 890 return getAffineConstantExpr(0, context); 891 892 auto maps = AffineMap::inferFromExprList(exprs); 893 assert(!maps.empty() && "Expected one non-empty map"); 894 unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols(); 895 896 AffineExpr expr; 897 bool dynamicPoisonBit = false; 898 int64_t runningSize = 1; 899 for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) { 900 int64_t size = std::get<1>(en); 901 // Degenerate case, no size =-> no stride 902 if (size == 0) 903 continue; 904 AffineExpr dimExpr = std::get<0>(en); 905 AffineExpr stride = dynamicPoisonBit 906 ? getAffineSymbolExpr(nSymbols++, context) 907 : getAffineConstantExpr(runningSize, context); 908 expr = expr ? expr + dimExpr * stride : dimExpr * stride; 909 if (size > 0) { 910 runningSize *= size; 911 assert(runningSize > 0 && "integer overflow in size computation"); 912 } else { 913 dynamicPoisonBit = true; 914 } 915 } 916 return simplifyAffineExpr(expr, numDims, nSymbols); 917 } 918 919 /// Return a version of `t` with a layout that has all dynamic offset and 920 /// strides. This is used to erase the static layout. 
921 MemRefType mlir::eraseStridedLayout(MemRefType t) { 922 auto val = ShapedType::kDynamicStrideOrOffset; 923 return MemRefType::Builder(t).setAffineMaps(makeStridedLinearLayoutMap( 924 SmallVector<int64_t, 4>(t.getRank(), val), val, t.getContext())); 925 } 926 927 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes, 928 MLIRContext *context) { 929 SmallVector<AffineExpr, 4> exprs; 930 exprs.reserve(sizes.size()); 931 for (auto dim : llvm::seq<unsigned>(0, sizes.size())) 932 exprs.push_back(getAffineDimExpr(dim, context)); 933 return makeCanonicalStridedLayoutExpr(sizes, exprs, context); 934 } 935 936 /// Return true if the layout for `t` is compatible with strided semantics. 937 bool mlir::isStrided(MemRefType t) { 938 int64_t offset; 939 SmallVector<int64_t, 4> strides; 940 auto res = getStridesAndOffset(t, strides, offset); 941 return succeeded(res); 942 } 943 944 /// Return the layout map in strided linear layout AffineMap form. 945 /// Return null if the layout is not compatible with a strided layout. 946 AffineMap mlir::getStridedLinearLayoutMap(MemRefType t) { 947 int64_t offset; 948 SmallVector<int64_t, 4> strides; 949 if (failed(getStridesAndOffset(t, strides, offset))) 950 return AffineMap(); 951 return makeStridedLinearLayoutMap(strides, offset, t.getContext()); 952 } 953