1 //===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/BuiltinTypes.h" 10 #include "TypeDetail.h" 11 #include "mlir/IR/AffineExpr.h" 12 #include "mlir/IR/AffineMap.h" 13 #include "mlir/IR/Diagnostics.h" 14 #include "mlir/IR/Dialect.h" 15 #include "llvm/ADT/APFloat.h" 16 #include "llvm/ADT/BitVector.h" 17 #include "llvm/ADT/Sequence.h" 18 #include "llvm/ADT/Twine.h" 19 20 using namespace mlir; 21 using namespace mlir::detail; 22 23 //===----------------------------------------------------------------------===// 24 /// Tablegen Type Definitions 25 //===----------------------------------------------------------------------===// 26 27 #define GET_TYPEDEF_CLASSES 28 #include "mlir/IR/BuiltinTypes.cpp.inc" 29 30 //===----------------------------------------------------------------------===// 31 /// ComplexType 32 //===----------------------------------------------------------------------===// 33 34 /// Verify the construction of an integer type. 35 LogicalResult ComplexType::verify(function_ref<InFlightDiagnostic()> emitError, 36 Type elementType) { 37 if (!elementType.isIntOrFloat()) 38 return emitError() << "invalid element type for complex"; 39 return success(); 40 } 41 42 //===----------------------------------------------------------------------===// 43 // Integer Type 44 //===----------------------------------------------------------------------===// 45 46 // static constexpr must have a definition (until in C++17 and inline variable). 47 constexpr unsigned IntegerType::kMaxWidth; 48 49 /// Verify the construction of an integer type. 50 LogicalResult IntegerType::verify(function_ref<InFlightDiagnostic()> emitError, 51 unsigned width, 52 SignednessSemantics signedness) { 53 if (width > IntegerType::kMaxWidth) { 54 return emitError() << "integer bitwidth is limited to " 55 << IntegerType::kMaxWidth << " bits"; 56 } 57 return success(); 58 } 59 60 unsigned IntegerType::getWidth() const { return getImpl()->width; } 61 62 IntegerType::SignednessSemantics IntegerType::getSignedness() const { 63 return getImpl()->signedness; 64 } 65 66 IntegerType IntegerType::scaleElementBitwidth(unsigned scale) { 67 if (!scale) 68 return IntegerType(); 69 return IntegerType::get(getContext(), scale * getWidth(), getSignedness()); 70 } 71 72 //===----------------------------------------------------------------------===// 73 // Float Type 74 //===----------------------------------------------------------------------===// 75 76 unsigned FloatType::getWidth() { 77 if (isa<Float16Type, BFloat16Type>()) 78 return 16; 79 if (isa<Float32Type>()) 80 return 32; 81 if (isa<Float64Type>()) 82 return 64; 83 if (isa<Float80Type>()) 84 return 80; 85 if (isa<Float128Type>()) 86 return 128; 87 llvm_unreachable("unexpected float type"); 88 } 89 90 /// Returns the floating semantics for the given type. 91 const llvm::fltSemantics &FloatType::getFloatSemantics() { 92 if (isa<BFloat16Type>()) 93 return APFloat::BFloat(); 94 if (isa<Float16Type>()) 95 return APFloat::IEEEhalf(); 96 if (isa<Float32Type>()) 97 return APFloat::IEEEsingle(); 98 if (isa<Float64Type>()) 99 return APFloat::IEEEdouble(); 100 if (isa<Float80Type>()) 101 return APFloat::x87DoubleExtended(); 102 if (isa<Float128Type>()) 103 return APFloat::IEEEquad(); 104 llvm_unreachable("non-floating point type used"); 105 } 106 107 FloatType FloatType::scaleElementBitwidth(unsigned scale) { 108 if (!scale) 109 return FloatType(); 110 MLIRContext *ctx = getContext(); 111 if (isF16() || isBF16()) { 112 if (scale == 2) 113 return FloatType::getF32(ctx); 114 if (scale == 4) 115 return FloatType::getF64(ctx); 116 } 117 if (isF32()) 118 if (scale == 2) 119 return FloatType::getF64(ctx); 120 return FloatType(); 121 } 122 123 //===----------------------------------------------------------------------===// 124 // FunctionType 125 //===----------------------------------------------------------------------===// 126 127 unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; } 128 129 ArrayRef<Type> FunctionType::getInputs() const { 130 return getImpl()->getInputs(); 131 } 132 133 unsigned FunctionType::getNumResults() const { return getImpl()->numResults; } 134 135 ArrayRef<Type> FunctionType::getResults() const { 136 return getImpl()->getResults(); 137 } 138 139 /// Helper to call a callback once on each index in the range 140 /// [0, `totalIndices`), *except* for the indices given in `indices`. 141 /// `indices` is allowed to have duplicates and can be in any order. 142 inline void iterateIndicesExcept(unsigned totalIndices, 143 ArrayRef<unsigned> indices, 144 function_ref<void(unsigned)> callback) { 145 llvm::BitVector skipIndices(totalIndices); 146 for (unsigned i : indices) 147 skipIndices.set(i); 148 149 for (unsigned i = 0; i < totalIndices; ++i) 150 if (!skipIndices.test(i)) 151 callback(i); 152 } 153 154 /// Returns a new function type without the specified arguments and results. 155 FunctionType 156 FunctionType::getWithoutArgsAndResults(ArrayRef<unsigned> argIndices, 157 ArrayRef<unsigned> resultIndices) { 158 ArrayRef<Type> newInputTypes = getInputs(); 159 SmallVector<Type, 4> newInputTypesBuffer; 160 if (!argIndices.empty()) { 161 unsigned originalNumArgs = getNumInputs(); 162 iterateIndicesExcept(originalNumArgs, argIndices, [&](unsigned i) { 163 newInputTypesBuffer.emplace_back(getInput(i)); 164 }); 165 newInputTypes = newInputTypesBuffer; 166 } 167 168 ArrayRef<Type> newResultTypes = getResults(); 169 SmallVector<Type, 4> newResultTypesBuffer; 170 if (!resultIndices.empty()) { 171 unsigned originalNumResults = getNumResults(); 172 iterateIndicesExcept(originalNumResults, resultIndices, [&](unsigned i) { 173 newResultTypesBuffer.emplace_back(getResult(i)); 174 }); 175 newResultTypes = newResultTypesBuffer; 176 } 177 178 return get(getContext(), newInputTypes, newResultTypes); 179 } 180 181 //===----------------------------------------------------------------------===// 182 // OpaqueType 183 //===----------------------------------------------------------------------===// 184 185 /// Verify the construction of an opaque type. 186 LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError, 187 Identifier dialect, StringRef typeData) { 188 if (!Dialect::isValidNamespace(dialect.strref())) 189 return emitError() << "invalid dialect namespace '" << dialect << "'"; 190 return success(); 191 } 192 193 //===----------------------------------------------------------------------===// 194 // ShapedType 195 //===----------------------------------------------------------------------===// 196 constexpr int64_t ShapedType::kDynamicSize; 197 constexpr int64_t ShapedType::kDynamicStrideOrOffset; 198 199 ShapedType ShapedType::clone(ArrayRef<int64_t> shape, Type elementType) { 200 if (auto other = dyn_cast<MemRefType>()) { 201 MemRefType::Builder b(other); 202 b.setShape(shape); 203 b.setElementType(elementType); 204 return b; 205 } 206 207 if (auto other = dyn_cast<UnrankedMemRefType>()) { 208 MemRefType::Builder b(shape, elementType); 209 b.setMemorySpace(other.getMemorySpace()); 210 return b; 211 } 212 213 if (isa<TensorType>()) 214 return RankedTensorType::get(shape, elementType); 215 216 if (isa<VectorType>()) 217 return VectorType::get(shape, elementType); 218 219 llvm_unreachable("Unhandled ShapedType clone case"); 220 } 221 222 ShapedType ShapedType::clone(ArrayRef<int64_t> shape) { 223 if (auto other = dyn_cast<MemRefType>()) { 224 MemRefType::Builder b(other); 225 b.setShape(shape); 226 return b; 227 } 228 229 if (auto other = dyn_cast<UnrankedMemRefType>()) { 230 MemRefType::Builder b(shape, other.getElementType()); 231 b.setShape(shape); 232 b.setMemorySpace(other.getMemorySpace()); 233 return b; 234 } 235 236 if (isa<TensorType>()) 237 return RankedTensorType::get(shape, getElementType()); 238 239 if (isa<VectorType>()) 240 return VectorType::get(shape, getElementType()); 241 242 llvm_unreachable("Unhandled ShapedType clone case"); 243 } 244 245 ShapedType ShapedType::clone(Type elementType) { 246 if (auto other = dyn_cast<MemRefType>()) { 247 MemRefType::Builder b(other); 248 b.setElementType(elementType); 249 return b; 250 } 251 252 if (auto other = dyn_cast<UnrankedMemRefType>()) { 253 return UnrankedMemRefType::get(elementType, other.getMemorySpace()); 254 } 255 256 if (isa<TensorType>()) { 257 if (hasRank()) 258 return RankedTensorType::get(getShape(), elementType); 259 return UnrankedTensorType::get(elementType); 260 } 261 262 if (isa<VectorType>()) 263 return VectorType::get(getShape(), elementType); 264 265 llvm_unreachable("Unhandled ShapedType clone hit"); 266 } 267 268 Type ShapedType::getElementType() const { 269 return static_cast<ImplType *>(impl)->elementType; 270 } 271 272 unsigned ShapedType::getElementTypeBitWidth() const { 273 return getElementType().getIntOrFloatBitWidth(); 274 } 275 276 int64_t ShapedType::getNumElements() const { 277 assert(hasStaticShape() && "cannot get element count of dynamic shaped type"); 278 auto shape = getShape(); 279 int64_t num = 1; 280 for (auto dim : shape) { 281 num *= dim; 282 assert(num >= 0 && "integer overflow in element count computation"); 283 } 284 return num; 285 } 286 287 int64_t ShapedType::getRank() const { 288 assert(hasRank() && "cannot query rank of unranked shaped type"); 289 return getShape().size(); 290 } 291 292 bool ShapedType::hasRank() const { 293 return !isa<UnrankedMemRefType, UnrankedTensorType>(); 294 } 295 296 int64_t ShapedType::getDimSize(unsigned idx) const { 297 assert(idx < getRank() && "invalid index for shaped type"); 298 return getShape()[idx]; 299 } 300 301 bool ShapedType::isDynamicDim(unsigned idx) const { 302 assert(idx < getRank() && "invalid index for shaped type"); 303 return isDynamic(getShape()[idx]); 304 } 305 306 unsigned ShapedType::getDynamicDimIndex(unsigned index) const { 307 assert(index < getRank() && "invalid index"); 308 assert(ShapedType::isDynamic(getDimSize(index)) && "invalid index"); 309 return llvm::count_if(getShape().take_front(index), ShapedType::isDynamic); 310 } 311 312 /// Get the number of bits require to store a value of the given shaped type. 313 /// Compute the value recursively since tensors are allowed to have vectors as 314 /// elements. 315 int64_t ShapedType::getSizeInBits() const { 316 assert(hasStaticShape() && 317 "cannot get the bit size of an aggregate with a dynamic shape"); 318 319 auto elementType = getElementType(); 320 if (elementType.isIntOrFloat()) 321 return elementType.getIntOrFloatBitWidth() * getNumElements(); 322 323 if (auto complexType = elementType.dyn_cast<ComplexType>()) { 324 elementType = complexType.getElementType(); 325 return elementType.getIntOrFloatBitWidth() * getNumElements() * 2; 326 } 327 328 // Tensors can have vectors and other tensors as elements, other shaped types 329 // cannot. 330 assert(isa<TensorType>() && "unsupported element type"); 331 assert((elementType.isa<VectorType, TensorType>()) && 332 "unsupported tensor element type"); 333 return getNumElements() * elementType.cast<ShapedType>().getSizeInBits(); 334 } 335 336 ArrayRef<int64_t> ShapedType::getShape() const { 337 if (auto vectorType = dyn_cast<VectorType>()) 338 return vectorType.getShape(); 339 if (auto tensorType = dyn_cast<RankedTensorType>()) 340 return tensorType.getShape(); 341 return cast<MemRefType>().getShape(); 342 } 343 344 int64_t ShapedType::getNumDynamicDims() const { 345 return llvm::count_if(getShape(), isDynamic); 346 } 347 348 bool ShapedType::hasStaticShape() const { 349 return hasRank() && llvm::none_of(getShape(), isDynamic); 350 } 351 352 bool ShapedType::hasStaticShape(ArrayRef<int64_t> shape) const { 353 return hasStaticShape() && getShape() == shape; 354 } 355 356 //===----------------------------------------------------------------------===// 357 // VectorType 358 //===----------------------------------------------------------------------===// 359 360 VectorType VectorType::get(ArrayRef<int64_t> shape, Type elementType) { 361 return Base::get(elementType.getContext(), shape, elementType); 362 } 363 364 VectorType VectorType::getChecked(function_ref<InFlightDiagnostic()> emitError, 365 ArrayRef<int64_t> shape, Type elementType) { 366 return Base::getChecked(emitError, elementType.getContext(), shape, 367 elementType); 368 } 369 370 LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError, 371 ArrayRef<int64_t> shape, Type elementType) { 372 if (shape.empty()) 373 return emitError() << "vector types must have at least one dimension"; 374 375 if (!isValidElementType(elementType)) 376 return emitError() << "vector elements must be int or float type"; 377 378 if (any_of(shape, [](int64_t i) { return i <= 0; })) 379 return emitError() << "vector types must have positive constant sizes"; 380 381 return success(); 382 } 383 384 ArrayRef<int64_t> VectorType::getShape() const { return getImpl()->getShape(); } 385 386 VectorType VectorType::scaleElementBitwidth(unsigned scale) { 387 if (!scale) 388 return VectorType(); 389 if (auto et = getElementType().dyn_cast<IntegerType>()) 390 if (auto scaledEt = et.scaleElementBitwidth(scale)) 391 return VectorType::get(getShape(), scaledEt); 392 if (auto et = getElementType().dyn_cast<FloatType>()) 393 if (auto scaledEt = et.scaleElementBitwidth(scale)) 394 return VectorType::get(getShape(), scaledEt); 395 return VectorType(); 396 } 397 398 //===----------------------------------------------------------------------===// 399 // TensorType 400 //===----------------------------------------------------------------------===// 401 402 // Check if "elementType" can be an element type of a tensor. 403 static LogicalResult 404 checkTensorElementType(function_ref<InFlightDiagnostic()> emitError, 405 Type elementType) { 406 if (!TensorType::isValidElementType(elementType)) 407 return emitError() << "invalid tensor element type: " << elementType; 408 return success(); 409 } 410 411 /// Return true if the specified element type is ok in a tensor. 412 bool TensorType::isValidElementType(Type type) { 413 // Note: Non standard/builtin types are allowed to exist within tensor 414 // types. Dialects are expected to verify that tensor types have a valid 415 // element type within that dialect. 416 return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType, 417 IndexType>() || 418 !type.getDialect().getNamespace().empty(); 419 } 420 421 //===----------------------------------------------------------------------===// 422 // RankedTensorType 423 //===----------------------------------------------------------------------===// 424 425 RankedTensorType RankedTensorType::get(ArrayRef<int64_t> shape, 426 Type elementType) { 427 return Base::get(elementType.getContext(), shape, elementType); 428 } 429 430 RankedTensorType 431 RankedTensorType::getChecked(function_ref<InFlightDiagnostic()> emitError, 432 ArrayRef<int64_t> shape, Type elementType) { 433 return Base::getChecked(emitError, elementType.getContext(), shape, 434 elementType); 435 } 436 437 LogicalResult 438 RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError, 439 ArrayRef<int64_t> shape, Type elementType) { 440 for (int64_t s : shape) { 441 if (s < -1) 442 return emitError() << "invalid tensor dimension size"; 443 } 444 return checkTensorElementType(emitError, elementType); 445 } 446 447 ArrayRef<int64_t> RankedTensorType::getShape() const { 448 return getImpl()->getShape(); 449 } 450 451 //===----------------------------------------------------------------------===// 452 // UnrankedTensorType 453 //===----------------------------------------------------------------------===// 454 455 UnrankedTensorType UnrankedTensorType::get(Type elementType) { 456 return Base::get(elementType.getContext(), elementType); 457 } 458 459 UnrankedTensorType 460 UnrankedTensorType::getChecked(function_ref<InFlightDiagnostic()> emitError, 461 Type elementType) { 462 return Base::getChecked(emitError, elementType.getContext(), elementType); 463 } 464 465 LogicalResult 466 UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError, 467 Type elementType) { 468 return checkTensorElementType(emitError, elementType); 469 } 470 471 //===----------------------------------------------------------------------===// 472 // BaseMemRefType 473 //===----------------------------------------------------------------------===// 474 475 unsigned BaseMemRefType::getMemorySpace() const { 476 return static_cast<ImplType *>(impl)->memorySpace; 477 } 478 479 //===----------------------------------------------------------------------===// 480 // MemRefType 481 //===----------------------------------------------------------------------===// 482 483 /// Get or create a new MemRefType based on shape, element type, affine 484 /// map composition, and memory space. Assumes the arguments define a 485 /// well-formed MemRef type. Use getChecked to gracefully handle MemRefType 486 /// construction failures. 487 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType, 488 ArrayRef<AffineMap> affineMapComposition, 489 unsigned memorySpace) { 490 auto result = 491 getImpl(shape, elementType, affineMapComposition, memorySpace, [=] { 492 return emitError(UnknownLoc::get(elementType.getContext())); 493 }); 494 assert(result && "Failed to construct instance of MemRefType."); 495 return result; 496 } 497 498 /// Get or create a new MemRefType based on shape, element type, affine 499 /// map composition, and memory space declared at the given location. 500 /// If the location is unknown, the last argument should be an instance of 501 /// UnknownLoc. If the MemRefType defined by the arguments would be 502 /// ill-formed, emits errors (to the handler registered with the context or to 503 /// the error stream) and returns nullptr. 504 MemRefType MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitError, 505 ArrayRef<int64_t> shape, Type elementType, 506 ArrayRef<AffineMap> affineMapComposition, 507 unsigned memorySpace) { 508 return getImpl(shape, elementType, affineMapComposition, memorySpace, 509 emitError); 510 } 511 512 /// Get or create a new MemRefType defined by the arguments. If the resulting 513 /// type would be ill-formed, return nullptr. If the location is provided, 514 /// emit detailed error messages. To emit errors when the location is unknown, 515 /// pass in an instance of UnknownLoc. 516 MemRefType MemRefType::getImpl(ArrayRef<int64_t> shape, Type elementType, 517 ArrayRef<AffineMap> affineMapComposition, 518 unsigned memorySpace, 519 function_ref<InFlightDiagnostic()> emitError) { 520 auto *context = elementType.getContext(); 521 522 if (!BaseMemRefType::isValidElementType(elementType)) 523 return (emitError() << "invalid memref element type", MemRefType()); 524 525 for (int64_t s : shape) { 526 // Negative sizes are not allowed except for `-1` that means dynamic size. 527 if (s < -1) 528 return (emitError() << "invalid memref size", MemRefType()); 529 } 530 531 // Check that the structure of the composition is valid, i.e. that each 532 // subsequent affine map has as many inputs as the previous map has results. 533 // Take the dimensionality of the MemRef for the first map. 534 auto dim = shape.size(); 535 unsigned i = 0; 536 for (const auto &affineMap : affineMapComposition) { 537 if (affineMap.getNumDims() != dim) { 538 emitError() << "memref affine map dimension mismatch between " 539 << (i == 0 ? Twine("memref rank") : "affine map " + Twine(i)) 540 << " and affine map" << i + 1 << ": " << dim 541 << " != " << affineMap.getNumDims(); 542 return nullptr; 543 } 544 545 dim = affineMap.getNumResults(); 546 ++i; 547 } 548 549 // Drop identity maps from the composition. 550 // This may lead to the composition becoming empty, which is interpreted as an 551 // implicit identity. 552 SmallVector<AffineMap, 2> cleanedAffineMapComposition; 553 for (const auto &map : affineMapComposition) { 554 if (map.isIdentity()) 555 continue; 556 cleanedAffineMapComposition.push_back(map); 557 } 558 559 return Base::get(context, shape, elementType, cleanedAffineMapComposition, 560 memorySpace); 561 } 562 563 ArrayRef<int64_t> MemRefType::getShape() const { return getImpl()->getShape(); } 564 565 ArrayRef<AffineMap> MemRefType::getAffineMaps() const { 566 return getImpl()->getAffineMaps(); 567 } 568 569 //===----------------------------------------------------------------------===// 570 // UnrankedMemRefType 571 //===----------------------------------------------------------------------===// 572 573 UnrankedMemRefType UnrankedMemRefType::get(Type elementType, 574 unsigned memorySpace) { 575 return Base::get(elementType.getContext(), elementType, memorySpace); 576 } 577 578 UnrankedMemRefType 579 UnrankedMemRefType::getChecked(function_ref<InFlightDiagnostic()> emitError, 580 Type elementType, unsigned memorySpace) { 581 return Base::getChecked(emitError, elementType.getContext(), elementType, 582 memorySpace); 583 } 584 585 LogicalResult 586 UnrankedMemRefType::verify(function_ref<InFlightDiagnostic()> emitError, 587 Type elementType, unsigned memorySpace) { 588 if (!BaseMemRefType::isValidElementType(elementType)) 589 return emitError() << "invalid memref element type"; 590 return success(); 591 } 592 593 // Fallback cases for terminal dim/sym/cst that are not part of a binary op ( 594 // i.e. single term). Accumulate the AffineExpr into the existing one. 595 static void extractStridesFromTerm(AffineExpr e, 596 AffineExpr multiplicativeFactor, 597 MutableArrayRef<AffineExpr> strides, 598 AffineExpr &offset) { 599 if (auto dim = e.dyn_cast<AffineDimExpr>()) 600 strides[dim.getPosition()] = 601 strides[dim.getPosition()] + multiplicativeFactor; 602 else 603 offset = offset + e * multiplicativeFactor; 604 } 605 606 /// Takes a single AffineExpr `e` and populates the `strides` array with the 607 /// strides expressions for each dim position. 608 /// The convention is that the strides for dimensions d0, .. dn appear in 609 /// order to make indexing intuitive into the result. 610 static LogicalResult extractStrides(AffineExpr e, 611 AffineExpr multiplicativeFactor, 612 MutableArrayRef<AffineExpr> strides, 613 AffineExpr &offset) { 614 auto bin = e.dyn_cast<AffineBinaryOpExpr>(); 615 if (!bin) { 616 extractStridesFromTerm(e, multiplicativeFactor, strides, offset); 617 return success(); 618 } 619 620 if (bin.getKind() == AffineExprKind::CeilDiv || 621 bin.getKind() == AffineExprKind::FloorDiv || 622 bin.getKind() == AffineExprKind::Mod) 623 return failure(); 624 625 if (bin.getKind() == AffineExprKind::Mul) { 626 auto dim = bin.getLHS().dyn_cast<AffineDimExpr>(); 627 if (dim) { 628 strides[dim.getPosition()] = 629 strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor; 630 return success(); 631 } 632 // LHS and RHS may both contain complex expressions of dims. Try one path 633 // and if it fails try the other. This is guaranteed to succeed because 634 // only one path may have a `dim`, otherwise this is not an AffineExpr in 635 // the first place. 636 if (bin.getLHS().isSymbolicOrConstant()) 637 return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(), 638 strides, offset); 639 return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(), 640 strides, offset); 641 } 642 643 if (bin.getKind() == AffineExprKind::Add) { 644 auto res1 = 645 extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset); 646 auto res2 = 647 extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset); 648 return success(succeeded(res1) && succeeded(res2)); 649 } 650 651 llvm_unreachable("unexpected binary operation"); 652 } 653 654 LogicalResult mlir::getStridesAndOffset(MemRefType t, 655 SmallVectorImpl<AffineExpr> &strides, 656 AffineExpr &offset) { 657 auto affineMaps = t.getAffineMaps(); 658 // For now strides are only computed on a single affine map with a single 659 // result (i.e. the closed subset of linearization maps that are compatible 660 // with striding semantics). 661 // TODO: support more forms on a per-need basis. 662 if (affineMaps.size() > 1) 663 return failure(); 664 if (affineMaps.size() == 1 && affineMaps[0].getNumResults() != 1) 665 return failure(); 666 667 auto zero = getAffineConstantExpr(0, t.getContext()); 668 auto one = getAffineConstantExpr(1, t.getContext()); 669 offset = zero; 670 strides.assign(t.getRank(), zero); 671 672 AffineMap m; 673 if (!affineMaps.empty()) { 674 m = affineMaps.front(); 675 assert(!m.isIdentity() && "unexpected identity map"); 676 } 677 678 // Canonical case for empty map. 679 if (!m) { 680 // 0-D corner case, offset is already 0. 681 if (t.getRank() == 0) 682 return success(); 683 auto stridedExpr = 684 makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext()); 685 if (succeeded(extractStrides(stridedExpr, one, strides, offset))) 686 return success(); 687 assert(false && "unexpected failure: extract strides in canonical layout"); 688 } 689 690 // Non-canonical case requires more work. 691 auto stridedExpr = 692 simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols()); 693 if (failed(extractStrides(stridedExpr, one, strides, offset))) { 694 offset = AffineExpr(); 695 strides.clear(); 696 return failure(); 697 } 698 699 // Simplify results to allow folding to constants and simple checks. 700 unsigned numDims = m.getNumDims(); 701 unsigned numSymbols = m.getNumSymbols(); 702 offset = simplifyAffineExpr(offset, numDims, numSymbols); 703 for (auto &stride : strides) 704 stride = simplifyAffineExpr(stride, numDims, numSymbols); 705 706 /// In practice, a strided memref must be internally non-aliasing. Test 707 /// against 0 as a proxy. 708 /// TODO: static cases can have more advanced checks. 709 /// TODO: dynamic cases would require a way to compare symbolic 710 /// expressions and would probably need an affine set context propagated 711 /// everywhere. 712 if (llvm::any_of(strides, [](AffineExpr e) { 713 return e == getAffineConstantExpr(0, e.getContext()); 714 })) { 715 offset = AffineExpr(); 716 strides.clear(); 717 return failure(); 718 } 719 720 return success(); 721 } 722 723 LogicalResult mlir::getStridesAndOffset(MemRefType t, 724 SmallVectorImpl<int64_t> &strides, 725 int64_t &offset) { 726 AffineExpr offsetExpr; 727 SmallVector<AffineExpr, 4> strideExprs; 728 if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr))) 729 return failure(); 730 if (auto cst = offsetExpr.dyn_cast<AffineConstantExpr>()) 731 offset = cst.getValue(); 732 else 733 offset = ShapedType::kDynamicStrideOrOffset; 734 for (auto e : strideExprs) { 735 if (auto c = e.dyn_cast<AffineConstantExpr>()) 736 strides.push_back(c.getValue()); 737 else 738 strides.push_back(ShapedType::kDynamicStrideOrOffset); 739 } 740 return success(); 741 } 742 743 //===----------------------------------------------------------------------===// 744 /// TupleType 745 //===----------------------------------------------------------------------===// 746 747 /// Return the elements types for this tuple. 748 ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); } 749 750 /// Accumulate the types contained in this tuple and tuples nested within it. 751 /// Note that this only flattens nested tuples, not any other container type, 752 /// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to 753 /// (i32, tensor<i32>, f32, i64) 754 void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) { 755 for (Type type : getTypes()) { 756 if (auto nestedTuple = type.dyn_cast<TupleType>()) 757 nestedTuple.getFlattenedTypes(types); 758 else 759 types.push_back(type); 760 } 761 } 762 763 /// Return the number of element types. 764 size_t TupleType::size() const { return getImpl()->size(); } 765 766 //===----------------------------------------------------------------------===// 767 // Type Utilities 768 //===----------------------------------------------------------------------===// 769 770 AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides, 771 int64_t offset, 772 MLIRContext *context) { 773 AffineExpr expr; 774 unsigned nSymbols = 0; 775 776 // AffineExpr for offset. 777 // Static case. 778 if (offset != MemRefType::getDynamicStrideOrOffset()) { 779 auto cst = getAffineConstantExpr(offset, context); 780 expr = cst; 781 } else { 782 // Dynamic case, new symbol for the offset. 783 auto sym = getAffineSymbolExpr(nSymbols++, context); 784 expr = sym; 785 } 786 787 // AffineExpr for strides. 788 for (auto en : llvm::enumerate(strides)) { 789 auto dim = en.index(); 790 auto stride = en.value(); 791 assert(stride != 0 && "Invalid stride specification"); 792 auto d = getAffineDimExpr(dim, context); 793 AffineExpr mult; 794 // Static case. 795 if (stride != MemRefType::getDynamicStrideOrOffset()) 796 mult = getAffineConstantExpr(stride, context); 797 else 798 // Dynamic case, new symbol for each new stride. 799 mult = getAffineSymbolExpr(nSymbols++, context); 800 expr = expr + d * mult; 801 } 802 803 return AffineMap::get(strides.size(), nSymbols, expr); 804 } 805 806 /// Return a version of `t` with identity layout if it can be determined 807 /// statically that the layout is the canonical contiguous strided layout. 808 /// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of 809 /// `t` with simplified layout. 810 /// If `t` has multiple layout maps or a multi-result layout, just return `t`. 811 MemRefType mlir::canonicalizeStridedLayout(MemRefType t) { 812 auto affineMaps = t.getAffineMaps(); 813 // Already in canonical form. 814 if (affineMaps.empty()) 815 return t; 816 817 // Can't reduce to canonical identity form, return in canonical form. 818 if (affineMaps.size() > 1 || affineMaps[0].getNumResults() > 1) 819 return t; 820 821 // Corner-case for 0-D affine maps. 822 auto m = affineMaps[0]; 823 if (m.getNumDims() == 0 && m.getNumSymbols() == 0) { 824 if (auto cst = m.getResult(0).dyn_cast<AffineConstantExpr>()) 825 if (cst.getValue() == 0) 826 return MemRefType::Builder(t).setAffineMaps({}); 827 return t; 828 } 829 830 // 0-D corner case for empty shape that still have an affine map. Example: 831 // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a 1 element memref whose 832 // offset needs to remain, just return t. 833 if (t.getShape().empty()) 834 return t; 835 836 // If the canonical strided layout for the sizes of `t` is equal to the 837 // simplified layout of `t` we can just return an empty layout. Otherwise, 838 // just simplify the existing layout. 839 AffineExpr expr = 840 makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext()); 841 auto simplifiedLayoutExpr = 842 simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols()); 843 if (expr != simplifiedLayoutExpr) 844 return MemRefType::Builder(t).setAffineMaps({AffineMap::get( 845 m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)}); 846 return MemRefType::Builder(t).setAffineMaps({}); 847 } 848 849 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes, 850 ArrayRef<AffineExpr> exprs, 851 MLIRContext *context) { 852 assert(!sizes.empty() && !exprs.empty() && 853 "expected non-empty sizes and exprs"); 854 855 // Size 0 corner case is useful for canonicalizations. 856 if (llvm::is_contained(sizes, 0)) 857 return getAffineConstantExpr(0, context); 858 859 auto maps = AffineMap::inferFromExprList(exprs); 860 assert(!maps.empty() && "Expected one non-empty map"); 861 unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols(); 862 863 AffineExpr expr; 864 bool dynamicPoisonBit = false; 865 int64_t runningSize = 1; 866 for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) { 867 int64_t size = std::get<1>(en); 868 // Degenerate case, no size =-> no stride 869 if (size == 0) 870 continue; 871 AffineExpr dimExpr = std::get<0>(en); 872 AffineExpr stride = dynamicPoisonBit 873 ? getAffineSymbolExpr(nSymbols++, context) 874 : getAffineConstantExpr(runningSize, context); 875 expr = expr ? expr + dimExpr * stride : dimExpr * stride; 876 if (size > 0) { 877 runningSize *= size; 878 assert(runningSize > 0 && "integer overflow in size computation"); 879 } else { 880 dynamicPoisonBit = true; 881 } 882 } 883 return simplifyAffineExpr(expr, numDims, nSymbols); 884 } 885 886 /// Return a version of `t` with a layout that has all dynamic offset and 887 /// strides. This is used to erase the static layout. 888 MemRefType mlir::eraseStridedLayout(MemRefType t) { 889 auto val = ShapedType::kDynamicStrideOrOffset; 890 return MemRefType::Builder(t).setAffineMaps(makeStridedLinearLayoutMap( 891 SmallVector<int64_t, 4>(t.getRank(), val), val, t.getContext())); 892 } 893 894 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes, 895 MLIRContext *context) { 896 SmallVector<AffineExpr, 4> exprs; 897 exprs.reserve(sizes.size()); 898 for (auto dim : llvm::seq<unsigned>(0, sizes.size())) 899 exprs.push_back(getAffineDimExpr(dim, context)); 900 return makeCanonicalStridedLayoutExpr(sizes, exprs, context); 901 } 902 903 /// Return true if the layout for `t` is compatible with strided semantics. 904 bool mlir::isStrided(MemRefType t) { 905 int64_t offset; 906 SmallVector<int64_t, 4> strides; 907 auto res = getStridesAndOffset(t, strides, offset); 908 return succeeded(res); 909 } 910 911 /// Return the layout map in strided linear layout AffineMap form. 912 /// Return null if the layout is not compatible with a strided layout. 913 AffineMap mlir::getStridedLinearLayoutMap(MemRefType t) { 914 int64_t offset; 915 SmallVector<int64_t, 4> strides; 916 if (failed(getStridesAndOffset(t, strides, offset))) 917 return AffineMap(); 918 return makeStridedLinearLayoutMap(strides, offset, t.getContext()); 919 } 920