1 //===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/BuiltinTypes.h" 10 #include "TypeDetail.h" 11 #include "mlir/IR/AffineExpr.h" 12 #include "mlir/IR/AffineMap.h" 13 #include "mlir/IR/BuiltinAttributes.h" 14 #include "mlir/IR/BuiltinDialect.h" 15 #include "mlir/IR/Diagnostics.h" 16 #include "mlir/IR/Dialect.h" 17 #include "llvm/ADT/APFloat.h" 18 #include "llvm/ADT/BitVector.h" 19 #include "llvm/ADT/Sequence.h" 20 #include "llvm/ADT/Twine.h" 21 #include "llvm/ADT/TypeSwitch.h" 22 23 using namespace mlir; 24 using namespace mlir::detail; 25 26 //===----------------------------------------------------------------------===// 27 /// Tablegen Type Definitions 28 //===----------------------------------------------------------------------===// 29 30 #define GET_TYPEDEF_CLASSES 31 #include "mlir/IR/BuiltinTypes.cpp.inc" 32 33 //===----------------------------------------------------------------------===// 34 // BuiltinDialect 35 //===----------------------------------------------------------------------===// 36 37 void BuiltinDialect::registerTypes() { 38 addTypes< 39 #define GET_TYPEDEF_LIST 40 #include "mlir/IR/BuiltinTypes.cpp.inc" 41 >(); 42 } 43 44 //===----------------------------------------------------------------------===// 45 /// ComplexType 46 //===----------------------------------------------------------------------===// 47 48 /// Verify the construction of an integer type. 49 LogicalResult ComplexType::verify(function_ref<InFlightDiagnostic()> emitError, 50 Type elementType) { 51 if (!elementType.isIntOrFloat()) 52 return emitError() << "invalid element type for complex"; 53 return success(); 54 } 55 56 //===----------------------------------------------------------------------===// 57 // Integer Type 58 //===----------------------------------------------------------------------===// 59 60 // static constexpr must have a definition (until in C++17 and inline variable). 61 constexpr unsigned IntegerType::kMaxWidth; 62 63 /// Verify the construction of an integer type. 64 LogicalResult IntegerType::verify(function_ref<InFlightDiagnostic()> emitError, 65 unsigned width, 66 SignednessSemantics signedness) { 67 if (width > IntegerType::kMaxWidth) { 68 return emitError() << "integer bitwidth is limited to " 69 << IntegerType::kMaxWidth << " bits"; 70 } 71 return success(); 72 } 73 74 unsigned IntegerType::getWidth() const { return getImpl()->width; } 75 76 IntegerType::SignednessSemantics IntegerType::getSignedness() const { 77 return getImpl()->signedness; 78 } 79 80 IntegerType IntegerType::scaleElementBitwidth(unsigned scale) { 81 if (!scale) 82 return IntegerType(); 83 return IntegerType::get(getContext(), scale * getWidth(), getSignedness()); 84 } 85 86 //===----------------------------------------------------------------------===// 87 // Float Type 88 //===----------------------------------------------------------------------===// 89 90 unsigned FloatType::getWidth() { 91 if (isa<Float16Type, BFloat16Type>()) 92 return 16; 93 if (isa<Float32Type>()) 94 return 32; 95 if (isa<Float64Type>()) 96 return 64; 97 if (isa<Float80Type>()) 98 return 80; 99 if (isa<Float128Type>()) 100 return 128; 101 llvm_unreachable("unexpected float type"); 102 } 103 104 /// Returns the floating semantics for the given type. 105 const llvm::fltSemantics &FloatType::getFloatSemantics() { 106 if (isa<BFloat16Type>()) 107 return APFloat::BFloat(); 108 if (isa<Float16Type>()) 109 return APFloat::IEEEhalf(); 110 if (isa<Float32Type>()) 111 return APFloat::IEEEsingle(); 112 if (isa<Float64Type>()) 113 return APFloat::IEEEdouble(); 114 if (isa<Float80Type>()) 115 return APFloat::x87DoubleExtended(); 116 if (isa<Float128Type>()) 117 return APFloat::IEEEquad(); 118 llvm_unreachable("non-floating point type used"); 119 } 120 121 FloatType FloatType::scaleElementBitwidth(unsigned scale) { 122 if (!scale) 123 return FloatType(); 124 MLIRContext *ctx = getContext(); 125 if (isF16() || isBF16()) { 126 if (scale == 2) 127 return FloatType::getF32(ctx); 128 if (scale == 4) 129 return FloatType::getF64(ctx); 130 } 131 if (isF32()) 132 if (scale == 2) 133 return FloatType::getF64(ctx); 134 return FloatType(); 135 } 136 137 //===----------------------------------------------------------------------===// 138 // FunctionType 139 //===----------------------------------------------------------------------===// 140 141 unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; } 142 143 ArrayRef<Type> FunctionType::getInputs() const { 144 return getImpl()->getInputs(); 145 } 146 147 unsigned FunctionType::getNumResults() const { return getImpl()->numResults; } 148 149 ArrayRef<Type> FunctionType::getResults() const { 150 return getImpl()->getResults(); 151 } 152 153 /// Helper to call a callback once on each index in the range 154 /// [0, `totalIndices`), *except* for the indices given in `indices`. 155 /// `indices` is allowed to have duplicates and can be in any order. 156 inline void iterateIndicesExcept(unsigned totalIndices, 157 ArrayRef<unsigned> indices, 158 function_ref<void(unsigned)> callback) { 159 llvm::BitVector skipIndices(totalIndices); 160 for (unsigned i : indices) 161 skipIndices.set(i); 162 163 for (unsigned i = 0; i < totalIndices; ++i) 164 if (!skipIndices.test(i)) 165 callback(i); 166 } 167 168 /// Returns a new function type without the specified arguments and results. 169 FunctionType 170 FunctionType::getWithoutArgsAndResults(ArrayRef<unsigned> argIndices, 171 ArrayRef<unsigned> resultIndices) { 172 ArrayRef<Type> newInputTypes = getInputs(); 173 SmallVector<Type, 4> newInputTypesBuffer; 174 if (!argIndices.empty()) { 175 unsigned originalNumArgs = getNumInputs(); 176 iterateIndicesExcept(originalNumArgs, argIndices, [&](unsigned i) { 177 newInputTypesBuffer.emplace_back(getInput(i)); 178 }); 179 newInputTypes = newInputTypesBuffer; 180 } 181 182 ArrayRef<Type> newResultTypes = getResults(); 183 SmallVector<Type, 4> newResultTypesBuffer; 184 if (!resultIndices.empty()) { 185 unsigned originalNumResults = getNumResults(); 186 iterateIndicesExcept(originalNumResults, resultIndices, [&](unsigned i) { 187 newResultTypesBuffer.emplace_back(getResult(i)); 188 }); 189 newResultTypes = newResultTypesBuffer; 190 } 191 192 return get(getContext(), newInputTypes, newResultTypes); 193 } 194 195 //===----------------------------------------------------------------------===// 196 // OpaqueType 197 //===----------------------------------------------------------------------===// 198 199 /// Verify the construction of an opaque type. 200 LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError, 201 Identifier dialect, StringRef typeData) { 202 if (!Dialect::isValidNamespace(dialect.strref())) 203 return emitError() << "invalid dialect namespace '" << dialect << "'"; 204 205 // Check that the dialect is actually registered. 206 MLIRContext *context = dialect.getContext(); 207 if (!context->allowsUnregisteredDialects() && 208 !context->getLoadedDialect(dialect.strref())) { 209 return emitError() 210 << "`!" << dialect << "<\"" << typeData << "\">" 211 << "` type created with unregistered dialect. If this is " 212 "intended, please call allowUnregisteredDialects() on the " 213 "MLIRContext, or use -allow-unregistered-dialect with " 214 "mlir-opt"; 215 } 216 217 return success(); 218 } 219 220 //===----------------------------------------------------------------------===// 221 // ShapedType 222 //===----------------------------------------------------------------------===// 223 constexpr int64_t ShapedType::kDynamicSize; 224 constexpr int64_t ShapedType::kDynamicStrideOrOffset; 225 226 ShapedType ShapedType::clone(ArrayRef<int64_t> shape, Type elementType) { 227 if (auto other = dyn_cast<MemRefType>()) { 228 MemRefType::Builder b(other); 229 b.setShape(shape); 230 b.setElementType(elementType); 231 return b; 232 } 233 234 if (auto other = dyn_cast<UnrankedMemRefType>()) { 235 MemRefType::Builder b(shape, elementType); 236 b.setMemorySpace(other.getMemorySpace()); 237 return b; 238 } 239 240 if (isa<TensorType>()) 241 return RankedTensorType::get(shape, elementType); 242 243 if (isa<VectorType>()) 244 return VectorType::get(shape, elementType); 245 246 llvm_unreachable("Unhandled ShapedType clone case"); 247 } 248 249 ShapedType ShapedType::clone(ArrayRef<int64_t> shape) { 250 if (auto other = dyn_cast<MemRefType>()) { 251 MemRefType::Builder b(other); 252 b.setShape(shape); 253 return b; 254 } 255 256 if (auto other = dyn_cast<UnrankedMemRefType>()) { 257 MemRefType::Builder b(shape, other.getElementType()); 258 b.setShape(shape); 259 b.setMemorySpace(other.getMemorySpace()); 260 return b; 261 } 262 263 if (isa<TensorType>()) 264 return RankedTensorType::get(shape, getElementType()); 265 266 if (isa<VectorType>()) 267 return VectorType::get(shape, getElementType()); 268 269 llvm_unreachable("Unhandled ShapedType clone case"); 270 } 271 272 ShapedType ShapedType::clone(Type elementType) { 273 if (auto other = dyn_cast<MemRefType>()) { 274 MemRefType::Builder b(other); 275 b.setElementType(elementType); 276 return b; 277 } 278 279 if (auto other = dyn_cast<UnrankedMemRefType>()) { 280 return UnrankedMemRefType::get(elementType, other.getMemorySpace()); 281 } 282 283 if (isa<TensorType>()) { 284 if (hasRank()) 285 return RankedTensorType::get(getShape(), elementType); 286 return UnrankedTensorType::get(elementType); 287 } 288 289 if (isa<VectorType>()) 290 return VectorType::get(getShape(), elementType); 291 292 llvm_unreachable("Unhandled ShapedType clone hit"); 293 } 294 295 Type ShapedType::getElementType() const { 296 return TypeSwitch<Type, Type>(*this) 297 .Case<VectorType, RankedTensorType, UnrankedTensorType, MemRefType, 298 UnrankedMemRefType>([](auto ty) { return ty.getElementType(); }); 299 } 300 301 unsigned ShapedType::getElementTypeBitWidth() const { 302 return getElementType().getIntOrFloatBitWidth(); 303 } 304 305 int64_t ShapedType::getNumElements() const { 306 assert(hasStaticShape() && "cannot get element count of dynamic shaped type"); 307 auto shape = getShape(); 308 int64_t num = 1; 309 for (auto dim : shape) { 310 num *= dim; 311 assert(num >= 0 && "integer overflow in element count computation"); 312 } 313 return num; 314 } 315 316 int64_t ShapedType::getRank() const { 317 assert(hasRank() && "cannot query rank of unranked shaped type"); 318 return getShape().size(); 319 } 320 321 bool ShapedType::hasRank() const { 322 return !isa<UnrankedMemRefType, UnrankedTensorType>(); 323 } 324 325 int64_t ShapedType::getDimSize(unsigned idx) const { 326 assert(idx < getRank() && "invalid index for shaped type"); 327 return getShape()[idx]; 328 } 329 330 bool ShapedType::isDynamicDim(unsigned idx) const { 331 assert(idx < getRank() && "invalid index for shaped type"); 332 return isDynamic(getShape()[idx]); 333 } 334 335 unsigned ShapedType::getDynamicDimIndex(unsigned index) const { 336 assert(index < getRank() && "invalid index"); 337 assert(ShapedType::isDynamic(getDimSize(index)) && "invalid index"); 338 return llvm::count_if(getShape().take_front(index), ShapedType::isDynamic); 339 } 340 341 /// Get the number of bits require to store a value of the given shaped type. 342 /// Compute the value recursively since tensors are allowed to have vectors as 343 /// elements. 344 int64_t ShapedType::getSizeInBits() const { 345 assert(hasStaticShape() && 346 "cannot get the bit size of an aggregate with a dynamic shape"); 347 348 auto elementType = getElementType(); 349 if (elementType.isIntOrFloat()) 350 return elementType.getIntOrFloatBitWidth() * getNumElements(); 351 352 if (auto complexType = elementType.dyn_cast<ComplexType>()) { 353 elementType = complexType.getElementType(); 354 return elementType.getIntOrFloatBitWidth() * getNumElements() * 2; 355 } 356 357 // Tensors can have vectors and other tensors as elements, other shaped types 358 // cannot. 359 assert(isa<TensorType>() && "unsupported element type"); 360 assert((elementType.isa<VectorType, TensorType>()) && 361 "unsupported tensor element type"); 362 return getNumElements() * elementType.cast<ShapedType>().getSizeInBits(); 363 } 364 365 ArrayRef<int64_t> ShapedType::getShape() const { 366 if (auto vectorType = dyn_cast<VectorType>()) 367 return vectorType.getShape(); 368 if (auto tensorType = dyn_cast<RankedTensorType>()) 369 return tensorType.getShape(); 370 return cast<MemRefType>().getShape(); 371 } 372 373 int64_t ShapedType::getNumDynamicDims() const { 374 return llvm::count_if(getShape(), isDynamic); 375 } 376 377 bool ShapedType::hasStaticShape() const { 378 return hasRank() && llvm::none_of(getShape(), isDynamic); 379 } 380 381 bool ShapedType::hasStaticShape(ArrayRef<int64_t> shape) const { 382 return hasStaticShape() && getShape() == shape; 383 } 384 385 //===----------------------------------------------------------------------===// 386 // VectorType 387 //===----------------------------------------------------------------------===// 388 389 LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError, 390 ArrayRef<int64_t> shape, Type elementType) { 391 if (shape.empty()) 392 return emitError() << "vector types must have at least one dimension"; 393 394 if (!isValidElementType(elementType)) 395 return emitError() << "vector elements must be int/index/float type"; 396 397 if (any_of(shape, [](int64_t i) { return i <= 0; })) 398 return emitError() << "vector types must have positive constant sizes"; 399 400 return success(); 401 } 402 403 VectorType VectorType::scaleElementBitwidth(unsigned scale) { 404 if (!scale) 405 return VectorType(); 406 if (auto et = getElementType().dyn_cast<IntegerType>()) 407 if (auto scaledEt = et.scaleElementBitwidth(scale)) 408 return VectorType::get(getShape(), scaledEt); 409 if (auto et = getElementType().dyn_cast<FloatType>()) 410 if (auto scaledEt = et.scaleElementBitwidth(scale)) 411 return VectorType::get(getShape(), scaledEt); 412 return VectorType(); 413 } 414 415 //===----------------------------------------------------------------------===// 416 // TensorType 417 //===----------------------------------------------------------------------===// 418 419 // Check if "elementType" can be an element type of a tensor. 420 static LogicalResult 421 checkTensorElementType(function_ref<InFlightDiagnostic()> emitError, 422 Type elementType) { 423 if (!TensorType::isValidElementType(elementType)) 424 return emitError() << "invalid tensor element type: " << elementType; 425 return success(); 426 } 427 428 /// Return true if the specified element type is ok in a tensor. 429 bool TensorType::isValidElementType(Type type) { 430 // Note: Non standard/builtin types are allowed to exist within tensor 431 // types. Dialects are expected to verify that tensor types have a valid 432 // element type within that dialect. 433 return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType, 434 IndexType>() || 435 !type.getDialect().getNamespace().empty(); 436 } 437 438 //===----------------------------------------------------------------------===// 439 // RankedTensorType 440 //===----------------------------------------------------------------------===// 441 442 LogicalResult 443 RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError, 444 ArrayRef<int64_t> shape, Type elementType, 445 Attribute encoding) { 446 for (int64_t s : shape) 447 if (s < -1) 448 return emitError() << "invalid tensor dimension size"; 449 // TODO: verify contents of encoding attribute. 450 return checkTensorElementType(emitError, elementType); 451 } 452 453 //===----------------------------------------------------------------------===// 454 // UnrankedTensorType 455 //===----------------------------------------------------------------------===// 456 457 LogicalResult 458 UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError, 459 Type elementType) { 460 return checkTensorElementType(emitError, elementType); 461 } 462 463 //===----------------------------------------------------------------------===// 464 // BaseMemRefType 465 //===----------------------------------------------------------------------===// 466 467 Attribute BaseMemRefType::getMemorySpace() const { 468 if (auto rankedMemRefTy = dyn_cast<MemRefType>()) 469 return rankedMemRefTy.getMemorySpace(); 470 return cast<UnrankedMemRefType>().getMemorySpace(); 471 } 472 473 unsigned BaseMemRefType::getMemorySpaceAsInt() const { 474 if (auto rankedMemRefTy = dyn_cast<MemRefType>()) 475 return rankedMemRefTy.getMemorySpaceAsInt(); 476 return cast<UnrankedMemRefType>().getMemorySpaceAsInt(); 477 } 478 479 //===----------------------------------------------------------------------===// 480 // MemRefType 481 //===----------------------------------------------------------------------===// 482 483 /// Given an `originalShape` and a `reducedShape` assumed to be a subset of 484 /// `originalShape` with some `1` entries erased, return the set of indices 485 /// that specifies which of the entries of `originalShape` are dropped to obtain 486 /// `reducedShape`. The returned mask can be applied as a projection to 487 /// `originalShape` to obtain the `reducedShape`. This mask is useful to track 488 /// which dimensions must be kept when e.g. compute MemRef strides under 489 /// rank-reducing operations. Return None if reducedShape cannot be obtained 490 /// by dropping only `1` entries in `originalShape`. 491 llvm::Optional<llvm::SmallDenseSet<unsigned>> 492 mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape, 493 ArrayRef<int64_t> reducedShape) { 494 size_t originalRank = originalShape.size(), reducedRank = reducedShape.size(); 495 llvm::SmallDenseSet<unsigned> unusedDims; 496 unsigned reducedIdx = 0; 497 for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) { 498 // Greedily insert `originalIdx` if no match. 499 if (reducedIdx < reducedRank && 500 originalShape[originalIdx] == reducedShape[reducedIdx]) { 501 reducedIdx++; 502 continue; 503 } 504 505 unusedDims.insert(originalIdx); 506 // If no match on `originalIdx`, the `originalShape` at this dimension 507 // must be 1, otherwise we bail. 508 if (originalShape[originalIdx] != 1) 509 return llvm::None; 510 } 511 // The whole reducedShape must be scanned, otherwise we bail. 512 if (reducedIdx != reducedRank) 513 return llvm::None; 514 return unusedDims; 515 } 516 517 bool mlir::detail::isSupportedMemorySpace(Attribute memorySpace) { 518 // Empty attribute is allowed as default memory space. 519 if (!memorySpace) 520 return true; 521 522 // Supported built-in attributes. 523 if (memorySpace.isa<IntegerAttr, StringAttr, DictionaryAttr>()) 524 return true; 525 526 // Allow custom dialect attributes. 527 if (!::mlir::isa<BuiltinDialect>(memorySpace.getDialect())) 528 return true; 529 530 return false; 531 } 532 533 Attribute mlir::detail::wrapIntegerMemorySpace(unsigned memorySpace, 534 MLIRContext *ctx) { 535 if (memorySpace == 0) 536 return nullptr; 537 538 return IntegerAttr::get(IntegerType::get(ctx, 64), memorySpace); 539 } 540 541 Attribute mlir::detail::skipDefaultMemorySpace(Attribute memorySpace) { 542 IntegerAttr intMemorySpace = memorySpace.dyn_cast_or_null<IntegerAttr>(); 543 if (intMemorySpace && intMemorySpace.getValue() == 0) 544 return nullptr; 545 546 return memorySpace; 547 } 548 549 unsigned mlir::detail::getMemorySpaceAsInt(Attribute memorySpace) { 550 if (!memorySpace) 551 return 0; 552 553 assert(memorySpace.isa<IntegerAttr>() && 554 "Using `getMemorySpaceInteger` with non-Integer attribute"); 555 556 return static_cast<unsigned>(memorySpace.cast<IntegerAttr>().getInt()); 557 } 558 559 MemRefType::Builder & 560 MemRefType::Builder::setMemorySpace(unsigned newMemorySpace) { 561 memorySpace = 562 wrapIntegerMemorySpace(newMemorySpace, elementType.getContext()); 563 return *this; 564 } 565 566 unsigned MemRefType::getMemorySpaceAsInt() const { 567 return detail::getMemorySpaceAsInt(getMemorySpace()); 568 } 569 570 LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError, 571 ArrayRef<int64_t> shape, Type elementType, 572 ArrayRef<AffineMap> affineMapComposition, 573 Attribute memorySpace) { 574 if (!BaseMemRefType::isValidElementType(elementType)) 575 return emitError() << "invalid memref element type"; 576 577 // Negative sizes are not allowed except for `-1` that means dynamic size. 578 for (int64_t s : shape) 579 if (s < -1) 580 return emitError() << "invalid memref size"; 581 582 // Check that the structure of the composition is valid, i.e. that each 583 // subsequent affine map has as many inputs as the previous map has results. 584 // Take the dimensionality of the MemRef for the first map. 585 size_t dim = shape.size(); 586 for (auto it : llvm::enumerate(affineMapComposition)) { 587 AffineMap map = it.value(); 588 if (map.getNumDims() == dim) { 589 dim = map.getNumResults(); 590 continue; 591 } 592 return emitError() << "memref affine map dimension mismatch between " 593 << (it.index() == 0 ? Twine("memref rank") 594 : "affine map " + Twine(it.index())) 595 << " and affine map" << it.index() + 1 << ": " << dim 596 << " != " << map.getNumDims(); 597 } 598 599 if (!isSupportedMemorySpace(memorySpace)) { 600 return emitError() << "unsupported memory space Attribute"; 601 } 602 603 return success(); 604 } 605 606 //===----------------------------------------------------------------------===// 607 // UnrankedMemRefType 608 //===----------------------------------------------------------------------===// 609 610 unsigned UnrankedMemRefType::getMemorySpaceAsInt() const { 611 return detail::getMemorySpaceAsInt(getMemorySpace()); 612 } 613 614 LogicalResult 615 UnrankedMemRefType::verify(function_ref<InFlightDiagnostic()> emitError, 616 Type elementType, Attribute memorySpace) { 617 if (!BaseMemRefType::isValidElementType(elementType)) 618 return emitError() << "invalid memref element type"; 619 620 if (!isSupportedMemorySpace(memorySpace)) 621 return emitError() << "unsupported memory space Attribute"; 622 623 return success(); 624 } 625 626 // Fallback cases for terminal dim/sym/cst that are not part of a binary op ( 627 // i.e. single term). Accumulate the AffineExpr into the existing one. 628 static void extractStridesFromTerm(AffineExpr e, 629 AffineExpr multiplicativeFactor, 630 MutableArrayRef<AffineExpr> strides, 631 AffineExpr &offset) { 632 if (auto dim = e.dyn_cast<AffineDimExpr>()) 633 strides[dim.getPosition()] = 634 strides[dim.getPosition()] + multiplicativeFactor; 635 else 636 offset = offset + e * multiplicativeFactor; 637 } 638 639 /// Takes a single AffineExpr `e` and populates the `strides` array with the 640 /// strides expressions for each dim position. 641 /// The convention is that the strides for dimensions d0, .. dn appear in 642 /// order to make indexing intuitive into the result. 643 static LogicalResult extractStrides(AffineExpr e, 644 AffineExpr multiplicativeFactor, 645 MutableArrayRef<AffineExpr> strides, 646 AffineExpr &offset) { 647 auto bin = e.dyn_cast<AffineBinaryOpExpr>(); 648 if (!bin) { 649 extractStridesFromTerm(e, multiplicativeFactor, strides, offset); 650 return success(); 651 } 652 653 if (bin.getKind() == AffineExprKind::CeilDiv || 654 bin.getKind() == AffineExprKind::FloorDiv || 655 bin.getKind() == AffineExprKind::Mod) 656 return failure(); 657 658 if (bin.getKind() == AffineExprKind::Mul) { 659 auto dim = bin.getLHS().dyn_cast<AffineDimExpr>(); 660 if (dim) { 661 strides[dim.getPosition()] = 662 strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor; 663 return success(); 664 } 665 // LHS and RHS may both contain complex expressions of dims. Try one path 666 // and if it fails try the other. This is guaranteed to succeed because 667 // only one path may have a `dim`, otherwise this is not an AffineExpr in 668 // the first place. 669 if (bin.getLHS().isSymbolicOrConstant()) 670 return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(), 671 strides, offset); 672 return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(), 673 strides, offset); 674 } 675 676 if (bin.getKind() == AffineExprKind::Add) { 677 auto res1 = 678 extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset); 679 auto res2 = 680 extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset); 681 return success(succeeded(res1) && succeeded(res2)); 682 } 683 684 llvm_unreachable("unexpected binary operation"); 685 } 686 687 LogicalResult mlir::getStridesAndOffset(MemRefType t, 688 SmallVectorImpl<AffineExpr> &strides, 689 AffineExpr &offset) { 690 auto affineMaps = t.getAffineMaps(); 691 692 if (!affineMaps.empty() && affineMaps.back().getNumResults() != 1) 693 return failure(); 694 695 AffineMap m; 696 if (!affineMaps.empty()) { 697 m = affineMaps.back(); 698 for (size_t i = affineMaps.size() - 1; i > 0; --i) 699 m = m.compose(affineMaps[i - 1]); 700 assert(!m.isIdentity() && "unexpected identity map"); 701 } 702 703 auto zero = getAffineConstantExpr(0, t.getContext()); 704 auto one = getAffineConstantExpr(1, t.getContext()); 705 offset = zero; 706 strides.assign(t.getRank(), zero); 707 708 // Canonical case for empty map. 709 if (!m) { 710 // 0-D corner case, offset is already 0. 711 if (t.getRank() == 0) 712 return success(); 713 auto stridedExpr = 714 makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext()); 715 if (succeeded(extractStrides(stridedExpr, one, strides, offset))) 716 return success(); 717 assert(false && "unexpected failure: extract strides in canonical layout"); 718 } 719 720 // Non-canonical case requires more work. 721 auto stridedExpr = 722 simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols()); 723 if (failed(extractStrides(stridedExpr, one, strides, offset))) { 724 offset = AffineExpr(); 725 strides.clear(); 726 return failure(); 727 } 728 729 // Simplify results to allow folding to constants and simple checks. 730 unsigned numDims = m.getNumDims(); 731 unsigned numSymbols = m.getNumSymbols(); 732 offset = simplifyAffineExpr(offset, numDims, numSymbols); 733 for (auto &stride : strides) 734 stride = simplifyAffineExpr(stride, numDims, numSymbols); 735 736 /// In practice, a strided memref must be internally non-aliasing. Test 737 /// against 0 as a proxy. 738 /// TODO: static cases can have more advanced checks. 739 /// TODO: dynamic cases would require a way to compare symbolic 740 /// expressions and would probably need an affine set context propagated 741 /// everywhere. 742 if (llvm::any_of(strides, [](AffineExpr e) { 743 return e == getAffineConstantExpr(0, e.getContext()); 744 })) { 745 offset = AffineExpr(); 746 strides.clear(); 747 return failure(); 748 } 749 750 return success(); 751 } 752 753 LogicalResult mlir::getStridesAndOffset(MemRefType t, 754 SmallVectorImpl<int64_t> &strides, 755 int64_t &offset) { 756 AffineExpr offsetExpr; 757 SmallVector<AffineExpr, 4> strideExprs; 758 if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr))) 759 return failure(); 760 if (auto cst = offsetExpr.dyn_cast<AffineConstantExpr>()) 761 offset = cst.getValue(); 762 else 763 offset = ShapedType::kDynamicStrideOrOffset; 764 for (auto e : strideExprs) { 765 if (auto c = e.dyn_cast<AffineConstantExpr>()) 766 strides.push_back(c.getValue()); 767 else 768 strides.push_back(ShapedType::kDynamicStrideOrOffset); 769 } 770 return success(); 771 } 772 773 //===----------------------------------------------------------------------===// 774 /// TupleType 775 //===----------------------------------------------------------------------===// 776 777 /// Return the elements types for this tuple. 778 ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); } 779 780 /// Accumulate the types contained in this tuple and tuples nested within it. 781 /// Note that this only flattens nested tuples, not any other container type, 782 /// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to 783 /// (i32, tensor<i32>, f32, i64) 784 void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) { 785 for (Type type : getTypes()) { 786 if (auto nestedTuple = type.dyn_cast<TupleType>()) 787 nestedTuple.getFlattenedTypes(types); 788 else 789 types.push_back(type); 790 } 791 } 792 793 /// Return the number of element types. 794 size_t TupleType::size() const { return getImpl()->size(); } 795 796 //===----------------------------------------------------------------------===// 797 // Type Utilities 798 //===----------------------------------------------------------------------===// 799 800 AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides, 801 int64_t offset, 802 MLIRContext *context) { 803 AffineExpr expr; 804 unsigned nSymbols = 0; 805 806 // AffineExpr for offset. 807 // Static case. 808 if (offset != MemRefType::getDynamicStrideOrOffset()) { 809 auto cst = getAffineConstantExpr(offset, context); 810 expr = cst; 811 } else { 812 // Dynamic case, new symbol for the offset. 813 auto sym = getAffineSymbolExpr(nSymbols++, context); 814 expr = sym; 815 } 816 817 // AffineExpr for strides. 818 for (auto en : llvm::enumerate(strides)) { 819 auto dim = en.index(); 820 auto stride = en.value(); 821 assert(stride != 0 && "Invalid stride specification"); 822 auto d = getAffineDimExpr(dim, context); 823 AffineExpr mult; 824 // Static case. 825 if (stride != MemRefType::getDynamicStrideOrOffset()) 826 mult = getAffineConstantExpr(stride, context); 827 else 828 // Dynamic case, new symbol for each new stride. 829 mult = getAffineSymbolExpr(nSymbols++, context); 830 expr = expr + d * mult; 831 } 832 833 return AffineMap::get(strides.size(), nSymbols, expr); 834 } 835 836 /// Return a version of `t` with identity layout if it can be determined 837 /// statically that the layout is the canonical contiguous strided layout. 838 /// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of 839 /// `t` with simplified layout. 840 /// If `t` has multiple layout maps or a multi-result layout, just return `t`. 841 MemRefType mlir::canonicalizeStridedLayout(MemRefType t) { 842 auto affineMaps = t.getAffineMaps(); 843 // Already in canonical form. 844 if (affineMaps.empty()) 845 return t; 846 847 // Can't reduce to canonical identity form, return in canonical form. 848 if (affineMaps.size() > 1 || affineMaps[0].getNumResults() > 1) 849 return t; 850 851 // Corner-case for 0-D affine maps. 852 auto m = affineMaps[0]; 853 if (m.getNumDims() == 0 && m.getNumSymbols() == 0) { 854 if (auto cst = m.getResult(0).dyn_cast<AffineConstantExpr>()) 855 if (cst.getValue() == 0) 856 return MemRefType::Builder(t).setAffineMaps({}); 857 return t; 858 } 859 860 // 0-D corner case for empty shape that still have an affine map. Example: 861 // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a 1 element memref whose 862 // offset needs to remain, just return t. 863 if (t.getShape().empty()) 864 return t; 865 866 // If the canonical strided layout for the sizes of `t` is equal to the 867 // simplified layout of `t` we can just return an empty layout. Otherwise, 868 // just simplify the existing layout. 869 AffineExpr expr = 870 makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext()); 871 auto simplifiedLayoutExpr = 872 simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols()); 873 if (expr != simplifiedLayoutExpr) 874 return MemRefType::Builder(t).setAffineMaps({AffineMap::get( 875 m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)}); 876 return MemRefType::Builder(t).setAffineMaps({}); 877 } 878 879 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes, 880 ArrayRef<AffineExpr> exprs, 881 MLIRContext *context) { 882 assert(!sizes.empty() && !exprs.empty() && 883 "expected non-empty sizes and exprs"); 884 885 // Size 0 corner case is useful for canonicalizations. 886 if (llvm::is_contained(sizes, 0)) 887 return getAffineConstantExpr(0, context); 888 889 auto maps = AffineMap::inferFromExprList(exprs); 890 assert(!maps.empty() && "Expected one non-empty map"); 891 unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols(); 892 893 AffineExpr expr; 894 bool dynamicPoisonBit = false; 895 int64_t runningSize = 1; 896 for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) { 897 int64_t size = std::get<1>(en); 898 // Degenerate case, no size =-> no stride 899 if (size == 0) 900 continue; 901 AffineExpr dimExpr = std::get<0>(en); 902 AffineExpr stride = dynamicPoisonBit 903 ? getAffineSymbolExpr(nSymbols++, context) 904 : getAffineConstantExpr(runningSize, context); 905 expr = expr ? expr + dimExpr * stride : dimExpr * stride; 906 if (size > 0) { 907 runningSize *= size; 908 assert(runningSize > 0 && "integer overflow in size computation"); 909 } else { 910 dynamicPoisonBit = true; 911 } 912 } 913 return simplifyAffineExpr(expr, numDims, nSymbols); 914 } 915 916 /// Return a version of `t` with a layout that has all dynamic offset and 917 /// strides. This is used to erase the static layout. 918 MemRefType mlir::eraseStridedLayout(MemRefType t) { 919 auto val = ShapedType::kDynamicStrideOrOffset; 920 return MemRefType::Builder(t).setAffineMaps(makeStridedLinearLayoutMap( 921 SmallVector<int64_t, 4>(t.getRank(), val), val, t.getContext())); 922 } 923 924 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes, 925 MLIRContext *context) { 926 SmallVector<AffineExpr, 4> exprs; 927 exprs.reserve(sizes.size()); 928 for (auto dim : llvm::seq<unsigned>(0, sizes.size())) 929 exprs.push_back(getAffineDimExpr(dim, context)); 930 return makeCanonicalStridedLayoutExpr(sizes, exprs, context); 931 } 932 933 /// Return true if the layout for `t` is compatible with strided semantics. 934 bool mlir::isStrided(MemRefType t) { 935 int64_t offset; 936 SmallVector<int64_t, 4> strides; 937 auto res = getStridesAndOffset(t, strides, offset); 938 return succeeded(res); 939 } 940 941 /// Return the layout map in strided linear layout AffineMap form. 942 /// Return null if the layout is not compatible with a strided layout. 943 AffineMap mlir::getStridedLinearLayoutMap(MemRefType t) { 944 int64_t offset; 945 SmallVector<int64_t, 4> strides; 946 if (failed(getStridesAndOffset(t, strides, offset))) 947 return AffineMap(); 948 return makeStridedLinearLayoutMap(strides, offset, t.getContext()); 949 } 950