1 //===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/BuiltinTypes.h" 10 #include "TypeDetail.h" 11 #include "mlir/IR/AffineExpr.h" 12 #include "mlir/IR/AffineMap.h" 13 #include "mlir/IR/Diagnostics.h" 14 #include "mlir/IR/Dialect.h" 15 #include "llvm/ADT/APFloat.h" 16 #include "llvm/ADT/BitVector.h" 17 #include "llvm/ADT/Sequence.h" 18 #include "llvm/ADT/Twine.h" 19 20 using namespace mlir; 21 using namespace mlir::detail; 22 23 //===----------------------------------------------------------------------===// 24 /// Tablegen Type Definitions 25 //===----------------------------------------------------------------------===// 26 27 #define GET_TYPEDEF_CLASSES 28 #include "mlir/IR/BuiltinTypes.cpp.inc" 29 30 //===----------------------------------------------------------------------===// 31 /// ComplexType 32 //===----------------------------------------------------------------------===// 33 34 ComplexType ComplexType::get(Type elementType) { 35 return Base::get(elementType.getContext(), elementType); 36 } 37 38 ComplexType ComplexType::getChecked(Type elementType, Location location) { 39 return Base::getChecked(location, elementType); 40 } 41 42 /// Verify the construction of an integer type. 43 LogicalResult ComplexType::verifyConstructionInvariants(Location loc, 44 Type elementType) { 45 if (!elementType.isIntOrFloat()) 46 return emitError(loc, "invalid element type for complex"); 47 return success(); 48 } 49 50 Type ComplexType::getElementType() { return getImpl()->elementType; } 51 52 //===----------------------------------------------------------------------===// 53 // Integer Type 54 //===----------------------------------------------------------------------===// 55 56 // static constexpr must have a definition (until in C++17 and inline variable). 57 constexpr unsigned IntegerType::kMaxWidth; 58 59 /// Verify the construction of an integer type. 60 LogicalResult 61 IntegerType::verifyConstructionInvariants(Location loc, unsigned width, 62 SignednessSemantics signedness) { 63 if (width > IntegerType::kMaxWidth) { 64 return emitError(loc) << "integer bitwidth is limited to " 65 << IntegerType::kMaxWidth << " bits"; 66 } 67 return success(); 68 } 69 70 unsigned IntegerType::getWidth() const { return getImpl()->width; } 71 72 IntegerType::SignednessSemantics IntegerType::getSignedness() const { 73 return getImpl()->signedness; 74 } 75 76 IntegerType IntegerType::scaleElementBitwidth(unsigned scale) { 77 if (!scale) 78 return IntegerType(); 79 return IntegerType::get(scale * getWidth(), getSignedness(), getContext()); 80 } 81 82 //===----------------------------------------------------------------------===// 83 // Float Type 84 //===----------------------------------------------------------------------===// 85 86 unsigned FloatType::getWidth() { 87 if (isa<Float16Type, BFloat16Type>()) 88 return 16; 89 if (isa<Float32Type>()) 90 return 32; 91 if (isa<Float64Type>()) 92 return 64; 93 llvm_unreachable("unexpected float type"); 94 } 95 96 /// Returns the floating semantics for the given type. 97 const llvm::fltSemantics &FloatType::getFloatSemantics() { 98 if (isa<BFloat16Type>()) 99 return APFloat::BFloat(); 100 if (isa<Float16Type>()) 101 return APFloat::IEEEhalf(); 102 if (isa<Float32Type>()) 103 return APFloat::IEEEsingle(); 104 if (isa<Float64Type>()) 105 return APFloat::IEEEdouble(); 106 llvm_unreachable("non-floating point type used"); 107 } 108 109 FloatType FloatType::scaleElementBitwidth(unsigned scale) { 110 if (!scale) 111 return FloatType(); 112 MLIRContext *ctx = getContext(); 113 if (isF16() || isBF16()) { 114 if (scale == 2) 115 return FloatType::getF32(ctx); 116 if (scale == 4) 117 return FloatType::getF64(ctx); 118 } 119 if (isF32()) 120 if (scale == 2) 121 return FloatType::getF64(ctx); 122 return FloatType(); 123 } 124 125 //===----------------------------------------------------------------------===// 126 // FunctionType 127 //===----------------------------------------------------------------------===// 128 129 FunctionType FunctionType::get(TypeRange inputs, TypeRange results, 130 MLIRContext *context) { 131 return Base::get(context, inputs, results); 132 } 133 134 unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; } 135 136 ArrayRef<Type> FunctionType::getInputs() const { 137 return getImpl()->getInputs(); 138 } 139 140 unsigned FunctionType::getNumResults() const { return getImpl()->numResults; } 141 142 ArrayRef<Type> FunctionType::getResults() const { 143 return getImpl()->getResults(); 144 } 145 146 /// Helper to call a callback once on each index in the range 147 /// [0, `totalIndices`), *except* for the indices given in `indices`. 148 /// `indices` is allowed to have duplicates and can be in any order. 149 inline void iterateIndicesExcept(unsigned totalIndices, 150 ArrayRef<unsigned> indices, 151 function_ref<void(unsigned)> callback) { 152 llvm::BitVector skipIndices(totalIndices); 153 for (unsigned i : indices) 154 skipIndices.set(i); 155 156 for (unsigned i = 0; i < totalIndices; ++i) 157 if (!skipIndices.test(i)) 158 callback(i); 159 } 160 161 /// Returns a new function type without the specified arguments and results. 162 FunctionType 163 FunctionType::getWithoutArgsAndResults(ArrayRef<unsigned> argIndices, 164 ArrayRef<unsigned> resultIndices) { 165 ArrayRef<Type> newInputTypes = getInputs(); 166 SmallVector<Type, 4> newInputTypesBuffer; 167 if (!argIndices.empty()) { 168 unsigned originalNumArgs = getNumInputs(); 169 iterateIndicesExcept(originalNumArgs, argIndices, [&](unsigned i) { 170 newInputTypesBuffer.emplace_back(getInput(i)); 171 }); 172 newInputTypes = newInputTypesBuffer; 173 } 174 175 ArrayRef<Type> newResultTypes = getResults(); 176 SmallVector<Type, 4> newResultTypesBuffer; 177 if (!resultIndices.empty()) { 178 unsigned originalNumResults = getNumResults(); 179 iterateIndicesExcept(originalNumResults, resultIndices, [&](unsigned i) { 180 newResultTypesBuffer.emplace_back(getResult(i)); 181 }); 182 newResultTypes = newResultTypesBuffer; 183 } 184 185 return get(newInputTypes, newResultTypes, getContext()); 186 } 187 188 //===----------------------------------------------------------------------===// 189 // OpaqueType 190 //===----------------------------------------------------------------------===// 191 192 OpaqueType OpaqueType::get(Identifier dialect, StringRef typeData, 193 MLIRContext *context) { 194 return Base::get(context, dialect, typeData); 195 } 196 197 OpaqueType OpaqueType::getChecked(Identifier dialect, StringRef typeData, 198 MLIRContext *context, Location location) { 199 return Base::getChecked(location, dialect, typeData); 200 } 201 202 /// Returns the dialect namespace of the opaque type. 203 Identifier OpaqueType::getDialectNamespace() const { 204 return getImpl()->dialectNamespace; 205 } 206 207 /// Returns the raw type data of the opaque type. 208 StringRef OpaqueType::getTypeData() const { return getImpl()->typeData; } 209 210 /// Verify the construction of an opaque type. 211 LogicalResult OpaqueType::verifyConstructionInvariants(Location loc, 212 Identifier dialect, 213 StringRef typeData) { 214 if (!Dialect::isValidNamespace(dialect.strref())) 215 return emitError(loc, "invalid dialect namespace '") << dialect << "'"; 216 return success(); 217 } 218 219 //===----------------------------------------------------------------------===// 220 // ShapedType 221 //===----------------------------------------------------------------------===// 222 constexpr int64_t ShapedType::kDynamicSize; 223 constexpr int64_t ShapedType::kDynamicStrideOrOffset; 224 225 Type ShapedType::getElementType() const { 226 return static_cast<ImplType *>(impl)->elementType; 227 } 228 229 unsigned ShapedType::getElementTypeBitWidth() const { 230 return getElementType().getIntOrFloatBitWidth(); 231 } 232 233 int64_t ShapedType::getNumElements() const { 234 assert(hasStaticShape() && "cannot get element count of dynamic shaped type"); 235 auto shape = getShape(); 236 int64_t num = 1; 237 for (auto dim : shape) 238 num *= dim; 239 return num; 240 } 241 242 int64_t ShapedType::getRank() const { return getShape().size(); } 243 244 bool ShapedType::hasRank() const { 245 return !isa<UnrankedMemRefType, UnrankedTensorType>(); 246 } 247 248 int64_t ShapedType::getDimSize(unsigned idx) const { 249 assert(idx < getRank() && "invalid index for shaped type"); 250 return getShape()[idx]; 251 } 252 253 bool ShapedType::isDynamicDim(unsigned idx) const { 254 assert(idx < getRank() && "invalid index for shaped type"); 255 return isDynamic(getShape()[idx]); 256 } 257 258 unsigned ShapedType::getDynamicDimIndex(unsigned index) const { 259 assert(index < getRank() && "invalid index"); 260 assert(ShapedType::isDynamic(getDimSize(index)) && "invalid index"); 261 return llvm::count_if(getShape().take_front(index), ShapedType::isDynamic); 262 } 263 264 /// Get the number of bits require to store a value of the given shaped type. 265 /// Compute the value recursively since tensors are allowed to have vectors as 266 /// elements. 267 int64_t ShapedType::getSizeInBits() const { 268 assert(hasStaticShape() && 269 "cannot get the bit size of an aggregate with a dynamic shape"); 270 271 auto elementType = getElementType(); 272 if (elementType.isIntOrFloat()) 273 return elementType.getIntOrFloatBitWidth() * getNumElements(); 274 275 if (auto complexType = elementType.dyn_cast<ComplexType>()) { 276 elementType = complexType.getElementType(); 277 return elementType.getIntOrFloatBitWidth() * getNumElements() * 2; 278 } 279 280 // Tensors can have vectors and other tensors as elements, other shaped types 281 // cannot. 282 assert(isa<TensorType>() && "unsupported element type"); 283 assert((elementType.isa<VectorType, TensorType>()) && 284 "unsupported tensor element type"); 285 return getNumElements() * elementType.cast<ShapedType>().getSizeInBits(); 286 } 287 288 ArrayRef<int64_t> ShapedType::getShape() const { 289 if (auto vectorType = dyn_cast<VectorType>()) 290 return vectorType.getShape(); 291 if (auto tensorType = dyn_cast<RankedTensorType>()) 292 return tensorType.getShape(); 293 return cast<MemRefType>().getShape(); 294 } 295 296 int64_t ShapedType::getNumDynamicDims() const { 297 return llvm::count_if(getShape(), isDynamic); 298 } 299 300 bool ShapedType::hasStaticShape() const { 301 return hasRank() && llvm::none_of(getShape(), isDynamic); 302 } 303 304 bool ShapedType::hasStaticShape(ArrayRef<int64_t> shape) const { 305 return hasStaticShape() && getShape() == shape; 306 } 307 308 //===----------------------------------------------------------------------===// 309 // VectorType 310 //===----------------------------------------------------------------------===// 311 312 VectorType VectorType::get(ArrayRef<int64_t> shape, Type elementType) { 313 return Base::get(elementType.getContext(), shape, elementType); 314 } 315 316 VectorType VectorType::getChecked(ArrayRef<int64_t> shape, Type elementType, 317 Location location) { 318 return Base::getChecked(location, shape, elementType); 319 } 320 321 LogicalResult VectorType::verifyConstructionInvariants(Location loc, 322 ArrayRef<int64_t> shape, 323 Type elementType) { 324 if (shape.empty()) 325 return emitError(loc, "vector types must have at least one dimension"); 326 327 if (!isValidElementType(elementType)) 328 return emitError(loc, "vector elements must be int or float type"); 329 330 if (any_of(shape, [](int64_t i) { return i <= 0; })) 331 return emitError(loc, "vector types must have positive constant sizes"); 332 333 return success(); 334 } 335 336 ArrayRef<int64_t> VectorType::getShape() const { return getImpl()->getShape(); } 337 338 VectorType VectorType::scaleElementBitwidth(unsigned scale) { 339 if (!scale) 340 return VectorType(); 341 if (auto et = getElementType().dyn_cast<IntegerType>()) 342 if (auto scaledEt = et.scaleElementBitwidth(scale)) 343 return VectorType::get(getShape(), scaledEt); 344 if (auto et = getElementType().dyn_cast<FloatType>()) 345 if (auto scaledEt = et.scaleElementBitwidth(scale)) 346 return VectorType::get(getShape(), scaledEt); 347 return VectorType(); 348 } 349 350 //===----------------------------------------------------------------------===// 351 // TensorType 352 //===----------------------------------------------------------------------===// 353 354 // Check if "elementType" can be an element type of a tensor. Emit errors if 355 // location is not nullptr. Returns failure if check failed. 356 static LogicalResult checkTensorElementType(Location location, 357 Type elementType) { 358 if (!TensorType::isValidElementType(elementType)) 359 return emitError(location, "invalid tensor element type: ") << elementType; 360 return success(); 361 } 362 363 /// Return true if the specified element type is ok in a tensor. 364 bool TensorType::isValidElementType(Type type) { 365 // Note: Non standard/builtin types are allowed to exist within tensor 366 // types. Dialects are expected to verify that tensor types have a valid 367 // element type within that dialect. 368 return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType, 369 IndexType>() || 370 !type.getDialect().getNamespace().empty(); 371 } 372 373 //===----------------------------------------------------------------------===// 374 // RankedTensorType 375 //===----------------------------------------------------------------------===// 376 377 RankedTensorType RankedTensorType::get(ArrayRef<int64_t> shape, 378 Type elementType) { 379 return Base::get(elementType.getContext(), shape, elementType); 380 } 381 382 RankedTensorType RankedTensorType::getChecked(ArrayRef<int64_t> shape, 383 Type elementType, 384 Location location) { 385 return Base::getChecked(location, shape, elementType); 386 } 387 388 LogicalResult RankedTensorType::verifyConstructionInvariants( 389 Location loc, ArrayRef<int64_t> shape, Type elementType) { 390 for (int64_t s : shape) { 391 if (s < -1) 392 return emitError(loc, "invalid tensor dimension size"); 393 } 394 return checkTensorElementType(loc, elementType); 395 } 396 397 ArrayRef<int64_t> RankedTensorType::getShape() const { 398 return getImpl()->getShape(); 399 } 400 401 //===----------------------------------------------------------------------===// 402 // UnrankedTensorType 403 //===----------------------------------------------------------------------===// 404 405 UnrankedTensorType UnrankedTensorType::get(Type elementType) { 406 return Base::get(elementType.getContext(), elementType); 407 } 408 409 UnrankedTensorType UnrankedTensorType::getChecked(Type elementType, 410 Location location) { 411 return Base::getChecked(location, elementType); 412 } 413 414 LogicalResult 415 UnrankedTensorType::verifyConstructionInvariants(Location loc, 416 Type elementType) { 417 return checkTensorElementType(loc, elementType); 418 } 419 420 //===----------------------------------------------------------------------===// 421 // BaseMemRefType 422 //===----------------------------------------------------------------------===// 423 424 unsigned BaseMemRefType::getMemorySpace() const { 425 return static_cast<ImplType *>(impl)->memorySpace; 426 } 427 428 //===----------------------------------------------------------------------===// 429 // MemRefType 430 //===----------------------------------------------------------------------===// 431 432 /// Get or create a new MemRefType based on shape, element type, affine 433 /// map composition, and memory space. Assumes the arguments define a 434 /// well-formed MemRef type. Use getChecked to gracefully handle MemRefType 435 /// construction failures. 436 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType, 437 ArrayRef<AffineMap> affineMapComposition, 438 unsigned memorySpace) { 439 auto result = getImpl(shape, elementType, affineMapComposition, memorySpace, 440 /*location=*/llvm::None); 441 assert(result && "Failed to construct instance of MemRefType."); 442 return result; 443 } 444 445 /// Get or create a new MemRefType based on shape, element type, affine 446 /// map composition, and memory space declared at the given location. 447 /// If the location is unknown, the last argument should be an instance of 448 /// UnknownLoc. If the MemRefType defined by the arguments would be 449 /// ill-formed, emits errors (to the handler registered with the context or to 450 /// the error stream) and returns nullptr. 451 MemRefType MemRefType::getChecked(ArrayRef<int64_t> shape, Type elementType, 452 ArrayRef<AffineMap> affineMapComposition, 453 unsigned memorySpace, Location location) { 454 return getImpl(shape, elementType, affineMapComposition, memorySpace, 455 location); 456 } 457 458 /// Get or create a new MemRefType defined by the arguments. If the resulting 459 /// type would be ill-formed, return nullptr. If the location is provided, 460 /// emit detailed error messages. To emit errors when the location is unknown, 461 /// pass in an instance of UnknownLoc. 462 MemRefType MemRefType::getImpl(ArrayRef<int64_t> shape, Type elementType, 463 ArrayRef<AffineMap> affineMapComposition, 464 unsigned memorySpace, 465 Optional<Location> location) { 466 auto *context = elementType.getContext(); 467 468 if (!BaseMemRefType::isValidElementType(elementType)) 469 return emitOptionalError(location, "invalid memref element type"), 470 MemRefType(); 471 472 for (int64_t s : shape) { 473 // Negative sizes are not allowed except for `-1` that means dynamic size. 474 if (s < -1) 475 return emitOptionalError(location, "invalid memref size"), MemRefType(); 476 } 477 478 // Check that the structure of the composition is valid, i.e. that each 479 // subsequent affine map has as many inputs as the previous map has results. 480 // Take the dimensionality of the MemRef for the first map. 481 auto dim = shape.size(); 482 unsigned i = 0; 483 for (const auto &affineMap : affineMapComposition) { 484 if (affineMap.getNumDims() != dim) { 485 if (location) 486 emitError(*location) 487 << "memref affine map dimension mismatch between " 488 << (i == 0 ? Twine("memref rank") : "affine map " + Twine(i)) 489 << " and affine map" << i + 1 << ": " << dim 490 << " != " << affineMap.getNumDims(); 491 return nullptr; 492 } 493 494 dim = affineMap.getNumResults(); 495 ++i; 496 } 497 498 // Drop identity maps from the composition. 499 // This may lead to the composition becoming empty, which is interpreted as an 500 // implicit identity. 501 SmallVector<AffineMap, 2> cleanedAffineMapComposition; 502 for (const auto &map : affineMapComposition) { 503 if (map.isIdentity()) 504 continue; 505 cleanedAffineMapComposition.push_back(map); 506 } 507 508 return Base::get(context, shape, elementType, cleanedAffineMapComposition, 509 memorySpace); 510 } 511 512 ArrayRef<int64_t> MemRefType::getShape() const { return getImpl()->getShape(); } 513 514 ArrayRef<AffineMap> MemRefType::getAffineMaps() const { 515 return getImpl()->getAffineMaps(); 516 } 517 518 //===----------------------------------------------------------------------===// 519 // UnrankedMemRefType 520 //===----------------------------------------------------------------------===// 521 522 UnrankedMemRefType UnrankedMemRefType::get(Type elementType, 523 unsigned memorySpace) { 524 return Base::get(elementType.getContext(), elementType, memorySpace); 525 } 526 527 UnrankedMemRefType UnrankedMemRefType::getChecked(Type elementType, 528 unsigned memorySpace, 529 Location location) { 530 return Base::getChecked(location, elementType, memorySpace); 531 } 532 533 LogicalResult 534 UnrankedMemRefType::verifyConstructionInvariants(Location loc, Type elementType, 535 unsigned memorySpace) { 536 if (!BaseMemRefType::isValidElementType(elementType)) 537 return emitError(loc, "invalid memref element type"); 538 return success(); 539 } 540 541 // Fallback cases for terminal dim/sym/cst that are not part of a binary op ( 542 // i.e. single term). Accumulate the AffineExpr into the existing one. 543 static void extractStridesFromTerm(AffineExpr e, 544 AffineExpr multiplicativeFactor, 545 MutableArrayRef<AffineExpr> strides, 546 AffineExpr &offset) { 547 if (auto dim = e.dyn_cast<AffineDimExpr>()) 548 strides[dim.getPosition()] = 549 strides[dim.getPosition()] + multiplicativeFactor; 550 else 551 offset = offset + e * multiplicativeFactor; 552 } 553 554 /// Takes a single AffineExpr `e` and populates the `strides` array with the 555 /// strides expressions for each dim position. 556 /// The convention is that the strides for dimensions d0, .. dn appear in 557 /// order to make indexing intuitive into the result. 558 static LogicalResult extractStrides(AffineExpr e, 559 AffineExpr multiplicativeFactor, 560 MutableArrayRef<AffineExpr> strides, 561 AffineExpr &offset) { 562 auto bin = e.dyn_cast<AffineBinaryOpExpr>(); 563 if (!bin) { 564 extractStridesFromTerm(e, multiplicativeFactor, strides, offset); 565 return success(); 566 } 567 568 if (bin.getKind() == AffineExprKind::CeilDiv || 569 bin.getKind() == AffineExprKind::FloorDiv || 570 bin.getKind() == AffineExprKind::Mod) 571 return failure(); 572 573 if (bin.getKind() == AffineExprKind::Mul) { 574 auto dim = bin.getLHS().dyn_cast<AffineDimExpr>(); 575 if (dim) { 576 strides[dim.getPosition()] = 577 strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor; 578 return success(); 579 } 580 // LHS and RHS may both contain complex expressions of dims. Try one path 581 // and if it fails try the other. This is guaranteed to succeed because 582 // only one path may have a `dim`, otherwise this is not an AffineExpr in 583 // the first place. 584 if (bin.getLHS().isSymbolicOrConstant()) 585 return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(), 586 strides, offset); 587 return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(), 588 strides, offset); 589 } 590 591 if (bin.getKind() == AffineExprKind::Add) { 592 auto res1 = 593 extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset); 594 auto res2 = 595 extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset); 596 return success(succeeded(res1) && succeeded(res2)); 597 } 598 599 llvm_unreachable("unexpected binary operation"); 600 } 601 602 LogicalResult mlir::getStridesAndOffset(MemRefType t, 603 SmallVectorImpl<AffineExpr> &strides, 604 AffineExpr &offset) { 605 auto affineMaps = t.getAffineMaps(); 606 // For now strides are only computed on a single affine map with a single 607 // result (i.e. the closed subset of linearization maps that are compatible 608 // with striding semantics). 609 // TODO: support more forms on a per-need basis. 610 if (affineMaps.size() > 1) 611 return failure(); 612 if (affineMaps.size() == 1 && affineMaps[0].getNumResults() != 1) 613 return failure(); 614 615 auto zero = getAffineConstantExpr(0, t.getContext()); 616 auto one = getAffineConstantExpr(1, t.getContext()); 617 offset = zero; 618 strides.assign(t.getRank(), zero); 619 620 AffineMap m; 621 if (!affineMaps.empty()) { 622 m = affineMaps.front(); 623 assert(!m.isIdentity() && "unexpected identity map"); 624 } 625 626 // Canonical case for empty map. 627 if (!m) { 628 // 0-D corner case, offset is already 0. 629 if (t.getRank() == 0) 630 return success(); 631 auto stridedExpr = 632 makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext()); 633 if (succeeded(extractStrides(stridedExpr, one, strides, offset))) 634 return success(); 635 assert(false && "unexpected failure: extract strides in canonical layout"); 636 } 637 638 // Non-canonical case requires more work. 639 auto stridedExpr = 640 simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols()); 641 if (failed(extractStrides(stridedExpr, one, strides, offset))) { 642 offset = AffineExpr(); 643 strides.clear(); 644 return failure(); 645 } 646 647 // Simplify results to allow folding to constants and simple checks. 648 unsigned numDims = m.getNumDims(); 649 unsigned numSymbols = m.getNumSymbols(); 650 offset = simplifyAffineExpr(offset, numDims, numSymbols); 651 for (auto &stride : strides) 652 stride = simplifyAffineExpr(stride, numDims, numSymbols); 653 654 /// In practice, a strided memref must be internally non-aliasing. Test 655 /// against 0 as a proxy. 656 /// TODO: static cases can have more advanced checks. 657 /// TODO: dynamic cases would require a way to compare symbolic 658 /// expressions and would probably need an affine set context propagated 659 /// everywhere. 660 if (llvm::any_of(strides, [](AffineExpr e) { 661 return e == getAffineConstantExpr(0, e.getContext()); 662 })) { 663 offset = AffineExpr(); 664 strides.clear(); 665 return failure(); 666 } 667 668 return success(); 669 } 670 671 LogicalResult mlir::getStridesAndOffset(MemRefType t, 672 SmallVectorImpl<int64_t> &strides, 673 int64_t &offset) { 674 AffineExpr offsetExpr; 675 SmallVector<AffineExpr, 4> strideExprs; 676 if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr))) 677 return failure(); 678 if (auto cst = offsetExpr.dyn_cast<AffineConstantExpr>()) 679 offset = cst.getValue(); 680 else 681 offset = ShapedType::kDynamicStrideOrOffset; 682 for (auto e : strideExprs) { 683 if (auto c = e.dyn_cast<AffineConstantExpr>()) 684 strides.push_back(c.getValue()); 685 else 686 strides.push_back(ShapedType::kDynamicStrideOrOffset); 687 } 688 return success(); 689 } 690 691 //===----------------------------------------------------------------------===// 692 /// TupleType 693 //===----------------------------------------------------------------------===// 694 695 /// Get or create a new TupleType with the provided element types. Assumes the 696 /// arguments define a well-formed type. 697 TupleType TupleType::get(TypeRange elementTypes, MLIRContext *context) { 698 return Base::get(context, elementTypes); 699 } 700 701 /// Get or create an empty tuple type. 702 TupleType TupleType::get(MLIRContext *context) { return get({}, context); } 703 704 /// Return the elements types for this tuple. 705 ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); } 706 707 /// Accumulate the types contained in this tuple and tuples nested within it. 708 /// Note that this only flattens nested tuples, not any other container type, 709 /// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to 710 /// (i32, tensor<i32>, f32, i64) 711 void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) { 712 for (Type type : getTypes()) { 713 if (auto nestedTuple = type.dyn_cast<TupleType>()) 714 nestedTuple.getFlattenedTypes(types); 715 else 716 types.push_back(type); 717 } 718 } 719 720 /// Return the number of element types. 721 size_t TupleType::size() const { return getImpl()->size(); } 722 723 AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides, 724 int64_t offset, 725 MLIRContext *context) { 726 AffineExpr expr; 727 unsigned nSymbols = 0; 728 729 // AffineExpr for offset. 730 // Static case. 731 if (offset != MemRefType::getDynamicStrideOrOffset()) { 732 auto cst = getAffineConstantExpr(offset, context); 733 expr = cst; 734 } else { 735 // Dynamic case, new symbol for the offset. 736 auto sym = getAffineSymbolExpr(nSymbols++, context); 737 expr = sym; 738 } 739 740 // AffineExpr for strides. 741 for (auto en : llvm::enumerate(strides)) { 742 auto dim = en.index(); 743 auto stride = en.value(); 744 assert(stride != 0 && "Invalid stride specification"); 745 auto d = getAffineDimExpr(dim, context); 746 AffineExpr mult; 747 // Static case. 748 if (stride != MemRefType::getDynamicStrideOrOffset()) 749 mult = getAffineConstantExpr(stride, context); 750 else 751 // Dynamic case, new symbol for each new stride. 752 mult = getAffineSymbolExpr(nSymbols++, context); 753 expr = expr + d * mult; 754 } 755 756 return AffineMap::get(strides.size(), nSymbols, expr); 757 } 758 759 /// Return a version of `t` with identity layout if it can be determined 760 /// statically that the layout is the canonical contiguous strided layout. 761 /// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of 762 /// `t` with simplified layout. 763 /// If `t` has multiple layout maps or a multi-result layout, just return `t`. 764 MemRefType mlir::canonicalizeStridedLayout(MemRefType t) { 765 auto affineMaps = t.getAffineMaps(); 766 // Already in canonical form. 767 if (affineMaps.empty()) 768 return t; 769 770 // Can't reduce to canonical identity form, return in canonical form. 771 if (affineMaps.size() > 1 || affineMaps[0].getNumResults() > 1) 772 return t; 773 774 // If the canonical strided layout for the sizes of `t` is equal to the 775 // simplified layout of `t` we can just return an empty layout. Otherwise, 776 // just simplify the existing layout. 777 AffineExpr expr = 778 makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext()); 779 auto m = affineMaps[0]; 780 auto simplifiedLayoutExpr = 781 simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols()); 782 if (expr != simplifiedLayoutExpr) 783 return MemRefType::Builder(t).setAffineMaps({AffineMap::get( 784 m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)}); 785 return MemRefType::Builder(t).setAffineMaps({}); 786 } 787 788 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes, 789 ArrayRef<AffineExpr> exprs, 790 MLIRContext *context) { 791 // Size 0 corner case is useful for canonicalizations. 792 if (llvm::is_contained(sizes, 0)) 793 return getAffineConstantExpr(0, context); 794 795 auto maps = AffineMap::inferFromExprList(exprs); 796 assert(!maps.empty() && "Expected one non-empty map"); 797 unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols(); 798 799 AffineExpr expr; 800 bool dynamicPoisonBit = false; 801 int64_t runningSize = 1; 802 for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) { 803 int64_t size = std::get<1>(en); 804 // Degenerate case, no size =-> no stride 805 if (size == 0) 806 continue; 807 AffineExpr dimExpr = std::get<0>(en); 808 AffineExpr stride = dynamicPoisonBit 809 ? getAffineSymbolExpr(nSymbols++, context) 810 : getAffineConstantExpr(runningSize, context); 811 expr = expr ? expr + dimExpr * stride : dimExpr * stride; 812 if (size > 0) 813 runningSize *= size; 814 else 815 dynamicPoisonBit = true; 816 } 817 return simplifyAffineExpr(expr, numDims, nSymbols); 818 } 819 820 /// Return a version of `t` with a layout that has all dynamic offset and 821 /// strides. This is used to erase the static layout. 822 MemRefType mlir::eraseStridedLayout(MemRefType t) { 823 auto val = ShapedType::kDynamicStrideOrOffset; 824 return MemRefType::Builder(t).setAffineMaps(makeStridedLinearLayoutMap( 825 SmallVector<int64_t, 4>(t.getRank(), val), val, t.getContext())); 826 } 827 828 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes, 829 MLIRContext *context) { 830 SmallVector<AffineExpr, 4> exprs; 831 exprs.reserve(sizes.size()); 832 for (auto dim : llvm::seq<unsigned>(0, sizes.size())) 833 exprs.push_back(getAffineDimExpr(dim, context)); 834 return makeCanonicalStridedLayoutExpr(sizes, exprs, context); 835 } 836 837 /// Return true if the layout for `t` is compatible with strided semantics. 838 bool mlir::isStrided(MemRefType t) { 839 int64_t offset; 840 SmallVector<int64_t, 4> stridesAndOffset; 841 auto res = getStridesAndOffset(t, stridesAndOffset, offset); 842 return succeeded(res); 843 } 844