1 //===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/BuiltinTypes.h" 10 #include "TypeDetail.h" 11 #include "mlir/IR/AffineExpr.h" 12 #include "mlir/IR/AffineMap.h" 13 #include "mlir/IR/Diagnostics.h" 14 #include "mlir/IR/Dialect.h" 15 #include "llvm/ADT/APFloat.h" 16 #include "llvm/ADT/BitVector.h" 17 #include "llvm/ADT/Sequence.h" 18 #include "llvm/ADT/Twine.h" 19 20 using namespace mlir; 21 using namespace mlir::detail; 22 23 //===----------------------------------------------------------------------===// 24 /// ComplexType 25 //===----------------------------------------------------------------------===// 26 27 ComplexType ComplexType::get(Type elementType) { 28 return Base::get(elementType.getContext(), elementType); 29 } 30 31 ComplexType ComplexType::getChecked(Type elementType, Location location) { 32 return Base::getChecked(location, elementType); 33 } 34 35 /// Verify the construction of an integer type. 36 LogicalResult ComplexType::verifyConstructionInvariants(Location loc, 37 Type elementType) { 38 if (!elementType.isIntOrFloat()) 39 return emitError(loc, "invalid element type for complex"); 40 return success(); 41 } 42 43 Type ComplexType::getElementType() { return getImpl()->elementType; } 44 45 //===----------------------------------------------------------------------===// 46 // Integer Type 47 //===----------------------------------------------------------------------===// 48 49 // static constexpr must have a definition (until in C++17 and inline variable). 50 constexpr unsigned IntegerType::kMaxWidth; 51 52 /// Verify the construction of an integer type. 53 LogicalResult 54 IntegerType::verifyConstructionInvariants(Location loc, unsigned width, 55 SignednessSemantics signedness) { 56 if (width > IntegerType::kMaxWidth) { 57 return emitError(loc) << "integer bitwidth is limited to " 58 << IntegerType::kMaxWidth << " bits"; 59 } 60 return success(); 61 } 62 63 unsigned IntegerType::getWidth() const { return getImpl()->width; } 64 65 IntegerType::SignednessSemantics IntegerType::getSignedness() const { 66 return getImpl()->signedness; 67 } 68 69 IntegerType IntegerType::scaleElementBitwidth(unsigned scale) { 70 if (!scale) 71 return IntegerType(); 72 return IntegerType::get(scale * getWidth(), getSignedness(), getContext()); 73 } 74 75 //===----------------------------------------------------------------------===// 76 // Float Type 77 //===----------------------------------------------------------------------===// 78 79 unsigned FloatType::getWidth() { 80 if (isa<Float16Type, BFloat16Type>()) 81 return 16; 82 if (isa<Float32Type>()) 83 return 32; 84 if (isa<Float64Type>()) 85 return 64; 86 llvm_unreachable("unexpected float type"); 87 } 88 89 /// Returns the floating semantics for the given type. 90 const llvm::fltSemantics &FloatType::getFloatSemantics() { 91 if (isa<BFloat16Type>()) 92 return APFloat::BFloat(); 93 if (isa<Float16Type>()) 94 return APFloat::IEEEhalf(); 95 if (isa<Float32Type>()) 96 return APFloat::IEEEsingle(); 97 if (isa<Float64Type>()) 98 return APFloat::IEEEdouble(); 99 llvm_unreachable("non-floating point type used"); 100 } 101 102 FloatType FloatType::scaleElementBitwidth(unsigned scale) { 103 if (!scale) 104 return FloatType(); 105 MLIRContext *ctx = getContext(); 106 if (isF16() || isBF16()) { 107 if (scale == 2) 108 return FloatType::getF32(ctx); 109 if (scale == 4) 110 return FloatType::getF64(ctx); 111 } 112 if (isF32()) 113 if (scale == 2) 114 return FloatType::getF64(ctx); 115 return FloatType(); 116 } 117 118 //===----------------------------------------------------------------------===// 119 // FunctionType 120 //===----------------------------------------------------------------------===// 121 122 FunctionType FunctionType::get(TypeRange inputs, TypeRange results, 123 MLIRContext *context) { 124 return Base::get(context, inputs, results); 125 } 126 127 unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; } 128 129 ArrayRef<Type> FunctionType::getInputs() const { 130 return getImpl()->getInputs(); 131 } 132 133 unsigned FunctionType::getNumResults() const { return getImpl()->numResults; } 134 135 ArrayRef<Type> FunctionType::getResults() const { 136 return getImpl()->getResults(); 137 } 138 139 /// Helper to call a callback once on each index in the range 140 /// [0, `totalIndices`), *except* for the indices given in `indices`. 141 /// `indices` is allowed to have duplicates and can be in any order. 142 inline void iterateIndicesExcept(unsigned totalIndices, 143 ArrayRef<unsigned> indices, 144 function_ref<void(unsigned)> callback) { 145 llvm::BitVector skipIndices(totalIndices); 146 for (unsigned i : indices) 147 skipIndices.set(i); 148 149 for (unsigned i = 0; i < totalIndices; ++i) 150 if (!skipIndices.test(i)) 151 callback(i); 152 } 153 154 /// Returns a new function type without the specified arguments and results. 155 FunctionType 156 FunctionType::getWithoutArgsAndResults(ArrayRef<unsigned> argIndices, 157 ArrayRef<unsigned> resultIndices) { 158 ArrayRef<Type> newInputTypes = getInputs(); 159 SmallVector<Type, 4> newInputTypesBuffer; 160 if (!argIndices.empty()) { 161 unsigned originalNumArgs = getNumInputs(); 162 iterateIndicesExcept(originalNumArgs, argIndices, [&](unsigned i) { 163 newInputTypesBuffer.emplace_back(getInput(i)); 164 }); 165 newInputTypes = newInputTypesBuffer; 166 } 167 168 ArrayRef<Type> newResultTypes = getResults(); 169 SmallVector<Type, 4> newResultTypesBuffer; 170 if (!resultIndices.empty()) { 171 unsigned originalNumResults = getNumResults(); 172 iterateIndicesExcept(originalNumResults, resultIndices, [&](unsigned i) { 173 newResultTypesBuffer.emplace_back(getResult(i)); 174 }); 175 newResultTypes = newResultTypesBuffer; 176 } 177 178 return get(newInputTypes, newResultTypes, getContext()); 179 } 180 181 //===----------------------------------------------------------------------===// 182 // OpaqueType 183 //===----------------------------------------------------------------------===// 184 185 OpaqueType OpaqueType::get(Identifier dialect, StringRef typeData, 186 MLIRContext *context) { 187 return Base::get(context, dialect, typeData); 188 } 189 190 OpaqueType OpaqueType::getChecked(Identifier dialect, StringRef typeData, 191 MLIRContext *context, Location location) { 192 return Base::getChecked(location, dialect, typeData); 193 } 194 195 /// Returns the dialect namespace of the opaque type. 196 Identifier OpaqueType::getDialectNamespace() const { 197 return getImpl()->dialectNamespace; 198 } 199 200 /// Returns the raw type data of the opaque type. 201 StringRef OpaqueType::getTypeData() const { return getImpl()->typeData; } 202 203 /// Verify the construction of an opaque type. 204 LogicalResult OpaqueType::verifyConstructionInvariants(Location loc, 205 Identifier dialect, 206 StringRef typeData) { 207 if (!Dialect::isValidNamespace(dialect.strref())) 208 return emitError(loc, "invalid dialect namespace '") << dialect << "'"; 209 return success(); 210 } 211 212 //===----------------------------------------------------------------------===// 213 // ShapedType 214 //===----------------------------------------------------------------------===// 215 constexpr int64_t ShapedType::kDynamicSize; 216 constexpr int64_t ShapedType::kDynamicStrideOrOffset; 217 218 Type ShapedType::getElementType() const { 219 return static_cast<ImplType *>(impl)->elementType; 220 } 221 222 unsigned ShapedType::getElementTypeBitWidth() const { 223 return getElementType().getIntOrFloatBitWidth(); 224 } 225 226 int64_t ShapedType::getNumElements() const { 227 assert(hasStaticShape() && "cannot get element count of dynamic shaped type"); 228 auto shape = getShape(); 229 int64_t num = 1; 230 for (auto dim : shape) 231 num *= dim; 232 return num; 233 } 234 235 int64_t ShapedType::getRank() const { return getShape().size(); } 236 237 bool ShapedType::hasRank() const { 238 return !isa<UnrankedMemRefType, UnrankedTensorType>(); 239 } 240 241 int64_t ShapedType::getDimSize(unsigned idx) const { 242 assert(idx < getRank() && "invalid index for shaped type"); 243 return getShape()[idx]; 244 } 245 246 bool ShapedType::isDynamicDim(unsigned idx) const { 247 assert(idx < getRank() && "invalid index for shaped type"); 248 return isDynamic(getShape()[idx]); 249 } 250 251 unsigned ShapedType::getDynamicDimIndex(unsigned index) const { 252 assert(index < getRank() && "invalid index"); 253 assert(ShapedType::isDynamic(getDimSize(index)) && "invalid index"); 254 return llvm::count_if(getShape().take_front(index), ShapedType::isDynamic); 255 } 256 257 /// Get the number of bits require to store a value of the given shaped type. 258 /// Compute the value recursively since tensors are allowed to have vectors as 259 /// elements. 260 int64_t ShapedType::getSizeInBits() const { 261 assert(hasStaticShape() && 262 "cannot get the bit size of an aggregate with a dynamic shape"); 263 264 auto elementType = getElementType(); 265 if (elementType.isIntOrFloat()) 266 return elementType.getIntOrFloatBitWidth() * getNumElements(); 267 268 if (auto complexType = elementType.dyn_cast<ComplexType>()) { 269 elementType = complexType.getElementType(); 270 return elementType.getIntOrFloatBitWidth() * getNumElements() * 2; 271 } 272 273 // Tensors can have vectors and other tensors as elements, other shaped types 274 // cannot. 275 assert(isa<TensorType>() && "unsupported element type"); 276 assert((elementType.isa<VectorType, TensorType>()) && 277 "unsupported tensor element type"); 278 return getNumElements() * elementType.cast<ShapedType>().getSizeInBits(); 279 } 280 281 ArrayRef<int64_t> ShapedType::getShape() const { 282 if (auto vectorType = dyn_cast<VectorType>()) 283 return vectorType.getShape(); 284 if (auto tensorType = dyn_cast<RankedTensorType>()) 285 return tensorType.getShape(); 286 return cast<MemRefType>().getShape(); 287 } 288 289 int64_t ShapedType::getNumDynamicDims() const { 290 return llvm::count_if(getShape(), isDynamic); 291 } 292 293 bool ShapedType::hasStaticShape() const { 294 return hasRank() && llvm::none_of(getShape(), isDynamic); 295 } 296 297 bool ShapedType::hasStaticShape(ArrayRef<int64_t> shape) const { 298 return hasStaticShape() && getShape() == shape; 299 } 300 301 //===----------------------------------------------------------------------===// 302 // VectorType 303 //===----------------------------------------------------------------------===// 304 305 VectorType VectorType::get(ArrayRef<int64_t> shape, Type elementType) { 306 return Base::get(elementType.getContext(), shape, elementType); 307 } 308 309 VectorType VectorType::getChecked(ArrayRef<int64_t> shape, Type elementType, 310 Location location) { 311 return Base::getChecked(location, shape, elementType); 312 } 313 314 LogicalResult VectorType::verifyConstructionInvariants(Location loc, 315 ArrayRef<int64_t> shape, 316 Type elementType) { 317 if (shape.empty()) 318 return emitError(loc, "vector types must have at least one dimension"); 319 320 if (!isValidElementType(elementType)) 321 return emitError(loc, "vector elements must be int or float type"); 322 323 if (any_of(shape, [](int64_t i) { return i <= 0; })) 324 return emitError(loc, "vector types must have positive constant sizes"); 325 326 return success(); 327 } 328 329 ArrayRef<int64_t> VectorType::getShape() const { return getImpl()->getShape(); } 330 331 VectorType VectorType::scaleElementBitwidth(unsigned scale) { 332 if (!scale) 333 return VectorType(); 334 if (auto et = getElementType().dyn_cast<IntegerType>()) 335 if (auto scaledEt = et.scaleElementBitwidth(scale)) 336 return VectorType::get(getShape(), scaledEt); 337 if (auto et = getElementType().dyn_cast<FloatType>()) 338 if (auto scaledEt = et.scaleElementBitwidth(scale)) 339 return VectorType::get(getShape(), scaledEt); 340 return VectorType(); 341 } 342 343 //===----------------------------------------------------------------------===// 344 // TensorType 345 //===----------------------------------------------------------------------===// 346 347 // Check if "elementType" can be an element type of a tensor. Emit errors if 348 // location is not nullptr. Returns failure if check failed. 349 static LogicalResult checkTensorElementType(Location location, 350 Type elementType) { 351 if (!TensorType::isValidElementType(elementType)) 352 return emitError(location, "invalid tensor element type: ") << elementType; 353 return success(); 354 } 355 356 /// Return true if the specified element type is ok in a tensor. 357 bool TensorType::isValidElementType(Type type) { 358 // Note: Non standard/builtin types are allowed to exist within tensor 359 // types. Dialects are expected to verify that tensor types have a valid 360 // element type within that dialect. 361 return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType, 362 IndexType>() || 363 !type.getDialect().getNamespace().empty(); 364 } 365 366 //===----------------------------------------------------------------------===// 367 // RankedTensorType 368 //===----------------------------------------------------------------------===// 369 370 RankedTensorType RankedTensorType::get(ArrayRef<int64_t> shape, 371 Type elementType) { 372 return Base::get(elementType.getContext(), shape, elementType); 373 } 374 375 RankedTensorType RankedTensorType::getChecked(ArrayRef<int64_t> shape, 376 Type elementType, 377 Location location) { 378 return Base::getChecked(location, shape, elementType); 379 } 380 381 LogicalResult RankedTensorType::verifyConstructionInvariants( 382 Location loc, ArrayRef<int64_t> shape, Type elementType) { 383 for (int64_t s : shape) { 384 if (s < -1) 385 return emitError(loc, "invalid tensor dimension size"); 386 } 387 return checkTensorElementType(loc, elementType); 388 } 389 390 ArrayRef<int64_t> RankedTensorType::getShape() const { 391 return getImpl()->getShape(); 392 } 393 394 //===----------------------------------------------------------------------===// 395 // UnrankedTensorType 396 //===----------------------------------------------------------------------===// 397 398 UnrankedTensorType UnrankedTensorType::get(Type elementType) { 399 return Base::get(elementType.getContext(), elementType); 400 } 401 402 UnrankedTensorType UnrankedTensorType::getChecked(Type elementType, 403 Location location) { 404 return Base::getChecked(location, elementType); 405 } 406 407 LogicalResult 408 UnrankedTensorType::verifyConstructionInvariants(Location loc, 409 Type elementType) { 410 return checkTensorElementType(loc, elementType); 411 } 412 413 //===----------------------------------------------------------------------===// 414 // BaseMemRefType 415 //===----------------------------------------------------------------------===// 416 417 unsigned BaseMemRefType::getMemorySpace() const { 418 return static_cast<ImplType *>(impl)->memorySpace; 419 } 420 421 //===----------------------------------------------------------------------===// 422 // MemRefType 423 //===----------------------------------------------------------------------===// 424 425 /// Get or create a new MemRefType based on shape, element type, affine 426 /// map composition, and memory space. Assumes the arguments define a 427 /// well-formed MemRef type. Use getChecked to gracefully handle MemRefType 428 /// construction failures. 429 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType, 430 ArrayRef<AffineMap> affineMapComposition, 431 unsigned memorySpace) { 432 auto result = getImpl(shape, elementType, affineMapComposition, memorySpace, 433 /*location=*/llvm::None); 434 assert(result && "Failed to construct instance of MemRefType."); 435 return result; 436 } 437 438 /// Get or create a new MemRefType based on shape, element type, affine 439 /// map composition, and memory space declared at the given location. 440 /// If the location is unknown, the last argument should be an instance of 441 /// UnknownLoc. If the MemRefType defined by the arguments would be 442 /// ill-formed, emits errors (to the handler registered with the context or to 443 /// the error stream) and returns nullptr. 444 MemRefType MemRefType::getChecked(ArrayRef<int64_t> shape, Type elementType, 445 ArrayRef<AffineMap> affineMapComposition, 446 unsigned memorySpace, Location location) { 447 return getImpl(shape, elementType, affineMapComposition, memorySpace, 448 location); 449 } 450 451 /// Get or create a new MemRefType defined by the arguments. If the resulting 452 /// type would be ill-formed, return nullptr. If the location is provided, 453 /// emit detailed error messages. To emit errors when the location is unknown, 454 /// pass in an instance of UnknownLoc. 455 MemRefType MemRefType::getImpl(ArrayRef<int64_t> shape, Type elementType, 456 ArrayRef<AffineMap> affineMapComposition, 457 unsigned memorySpace, 458 Optional<Location> location) { 459 auto *context = elementType.getContext(); 460 461 if (!BaseMemRefType::isValidElementType(elementType)) 462 return emitOptionalError(location, "invalid memref element type"), 463 MemRefType(); 464 465 for (int64_t s : shape) { 466 // Negative sizes are not allowed except for `-1` that means dynamic size. 467 if (s < -1) 468 return emitOptionalError(location, "invalid memref size"), MemRefType(); 469 } 470 471 // Check that the structure of the composition is valid, i.e. that each 472 // subsequent affine map has as many inputs as the previous map has results. 473 // Take the dimensionality of the MemRef for the first map. 474 auto dim = shape.size(); 475 unsigned i = 0; 476 for (const auto &affineMap : affineMapComposition) { 477 if (affineMap.getNumDims() != dim) { 478 if (location) 479 emitError(*location) 480 << "memref affine map dimension mismatch between " 481 << (i == 0 ? Twine("memref rank") : "affine map " + Twine(i)) 482 << " and affine map" << i + 1 << ": " << dim 483 << " != " << affineMap.getNumDims(); 484 return nullptr; 485 } 486 487 dim = affineMap.getNumResults(); 488 ++i; 489 } 490 491 // Drop identity maps from the composition. 492 // This may lead to the composition becoming empty, which is interpreted as an 493 // implicit identity. 494 SmallVector<AffineMap, 2> cleanedAffineMapComposition; 495 for (const auto &map : affineMapComposition) { 496 if (map.isIdentity()) 497 continue; 498 cleanedAffineMapComposition.push_back(map); 499 } 500 501 return Base::get(context, shape, elementType, cleanedAffineMapComposition, 502 memorySpace); 503 } 504 505 ArrayRef<int64_t> MemRefType::getShape() const { return getImpl()->getShape(); } 506 507 ArrayRef<AffineMap> MemRefType::getAffineMaps() const { 508 return getImpl()->getAffineMaps(); 509 } 510 511 //===----------------------------------------------------------------------===// 512 // UnrankedMemRefType 513 //===----------------------------------------------------------------------===// 514 515 UnrankedMemRefType UnrankedMemRefType::get(Type elementType, 516 unsigned memorySpace) { 517 return Base::get(elementType.getContext(), elementType, memorySpace); 518 } 519 520 UnrankedMemRefType UnrankedMemRefType::getChecked(Type elementType, 521 unsigned memorySpace, 522 Location location) { 523 return Base::getChecked(location, elementType, memorySpace); 524 } 525 526 LogicalResult 527 UnrankedMemRefType::verifyConstructionInvariants(Location loc, Type elementType, 528 unsigned memorySpace) { 529 if (!BaseMemRefType::isValidElementType(elementType)) 530 return emitError(loc, "invalid memref element type"); 531 return success(); 532 } 533 534 // Fallback cases for terminal dim/sym/cst that are not part of a binary op ( 535 // i.e. single term). Accumulate the AffineExpr into the existing one. 536 static void extractStridesFromTerm(AffineExpr e, 537 AffineExpr multiplicativeFactor, 538 MutableArrayRef<AffineExpr> strides, 539 AffineExpr &offset) { 540 if (auto dim = e.dyn_cast<AffineDimExpr>()) 541 strides[dim.getPosition()] = 542 strides[dim.getPosition()] + multiplicativeFactor; 543 else 544 offset = offset + e * multiplicativeFactor; 545 } 546 547 /// Takes a single AffineExpr `e` and populates the `strides` array with the 548 /// strides expressions for each dim position. 549 /// The convention is that the strides for dimensions d0, .. dn appear in 550 /// order to make indexing intuitive into the result. 551 static LogicalResult extractStrides(AffineExpr e, 552 AffineExpr multiplicativeFactor, 553 MutableArrayRef<AffineExpr> strides, 554 AffineExpr &offset) { 555 auto bin = e.dyn_cast<AffineBinaryOpExpr>(); 556 if (!bin) { 557 extractStridesFromTerm(e, multiplicativeFactor, strides, offset); 558 return success(); 559 } 560 561 if (bin.getKind() == AffineExprKind::CeilDiv || 562 bin.getKind() == AffineExprKind::FloorDiv || 563 bin.getKind() == AffineExprKind::Mod) 564 return failure(); 565 566 if (bin.getKind() == AffineExprKind::Mul) { 567 auto dim = bin.getLHS().dyn_cast<AffineDimExpr>(); 568 if (dim) { 569 strides[dim.getPosition()] = 570 strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor; 571 return success(); 572 } 573 // LHS and RHS may both contain complex expressions of dims. Try one path 574 // and if it fails try the other. This is guaranteed to succeed because 575 // only one path may have a `dim`, otherwise this is not an AffineExpr in 576 // the first place. 577 if (bin.getLHS().isSymbolicOrConstant()) 578 return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(), 579 strides, offset); 580 return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(), 581 strides, offset); 582 } 583 584 if (bin.getKind() == AffineExprKind::Add) { 585 auto res1 = 586 extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset); 587 auto res2 = 588 extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset); 589 return success(succeeded(res1) && succeeded(res2)); 590 } 591 592 llvm_unreachable("unexpected binary operation"); 593 } 594 595 LogicalResult mlir::getStridesAndOffset(MemRefType t, 596 SmallVectorImpl<AffineExpr> &strides, 597 AffineExpr &offset) { 598 auto affineMaps = t.getAffineMaps(); 599 // For now strides are only computed on a single affine map with a single 600 // result (i.e. the closed subset of linearization maps that are compatible 601 // with striding semantics). 602 // TODO: support more forms on a per-need basis. 603 if (affineMaps.size() > 1) 604 return failure(); 605 if (affineMaps.size() == 1 && affineMaps[0].getNumResults() != 1) 606 return failure(); 607 608 auto zero = getAffineConstantExpr(0, t.getContext()); 609 auto one = getAffineConstantExpr(1, t.getContext()); 610 offset = zero; 611 strides.assign(t.getRank(), zero); 612 613 AffineMap m; 614 if (!affineMaps.empty()) { 615 m = affineMaps.front(); 616 assert(!m.isIdentity() && "unexpected identity map"); 617 } 618 619 // Canonical case for empty map. 620 if (!m) { 621 // 0-D corner case, offset is already 0. 622 if (t.getRank() == 0) 623 return success(); 624 auto stridedExpr = 625 makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext()); 626 if (succeeded(extractStrides(stridedExpr, one, strides, offset))) 627 return success(); 628 assert(false && "unexpected failure: extract strides in canonical layout"); 629 } 630 631 // Non-canonical case requires more work. 632 auto stridedExpr = 633 simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols()); 634 if (failed(extractStrides(stridedExpr, one, strides, offset))) { 635 offset = AffineExpr(); 636 strides.clear(); 637 return failure(); 638 } 639 640 // Simplify results to allow folding to constants and simple checks. 641 unsigned numDims = m.getNumDims(); 642 unsigned numSymbols = m.getNumSymbols(); 643 offset = simplifyAffineExpr(offset, numDims, numSymbols); 644 for (auto &stride : strides) 645 stride = simplifyAffineExpr(stride, numDims, numSymbols); 646 647 /// In practice, a strided memref must be internally non-aliasing. Test 648 /// against 0 as a proxy. 649 /// TODO: static cases can have more advanced checks. 650 /// TODO: dynamic cases would require a way to compare symbolic 651 /// expressions and would probably need an affine set context propagated 652 /// everywhere. 653 if (llvm::any_of(strides, [](AffineExpr e) { 654 return e == getAffineConstantExpr(0, e.getContext()); 655 })) { 656 offset = AffineExpr(); 657 strides.clear(); 658 return failure(); 659 } 660 661 return success(); 662 } 663 664 LogicalResult mlir::getStridesAndOffset(MemRefType t, 665 SmallVectorImpl<int64_t> &strides, 666 int64_t &offset) { 667 AffineExpr offsetExpr; 668 SmallVector<AffineExpr, 4> strideExprs; 669 if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr))) 670 return failure(); 671 if (auto cst = offsetExpr.dyn_cast<AffineConstantExpr>()) 672 offset = cst.getValue(); 673 else 674 offset = ShapedType::kDynamicStrideOrOffset; 675 for (auto e : strideExprs) { 676 if (auto c = e.dyn_cast<AffineConstantExpr>()) 677 strides.push_back(c.getValue()); 678 else 679 strides.push_back(ShapedType::kDynamicStrideOrOffset); 680 } 681 return success(); 682 } 683 684 //===----------------------------------------------------------------------===// 685 /// TupleType 686 //===----------------------------------------------------------------------===// 687 688 /// Get or create a new TupleType with the provided element types. Assumes the 689 /// arguments define a well-formed type. 690 TupleType TupleType::get(TypeRange elementTypes, MLIRContext *context) { 691 return Base::get(context, elementTypes); 692 } 693 694 /// Get or create an empty tuple type. 695 TupleType TupleType::get(MLIRContext *context) { return get({}, context); } 696 697 /// Return the elements types for this tuple. 698 ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); } 699 700 /// Accumulate the types contained in this tuple and tuples nested within it. 701 /// Note that this only flattens nested tuples, not any other container type, 702 /// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to 703 /// (i32, tensor<i32>, f32, i64) 704 void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) { 705 for (Type type : getTypes()) { 706 if (auto nestedTuple = type.dyn_cast<TupleType>()) 707 nestedTuple.getFlattenedTypes(types); 708 else 709 types.push_back(type); 710 } 711 } 712 713 /// Return the number of element types. 714 size_t TupleType::size() const { return getImpl()->size(); } 715 716 AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides, 717 int64_t offset, 718 MLIRContext *context) { 719 AffineExpr expr; 720 unsigned nSymbols = 0; 721 722 // AffineExpr for offset. 723 // Static case. 724 if (offset != MemRefType::getDynamicStrideOrOffset()) { 725 auto cst = getAffineConstantExpr(offset, context); 726 expr = cst; 727 } else { 728 // Dynamic case, new symbol for the offset. 729 auto sym = getAffineSymbolExpr(nSymbols++, context); 730 expr = sym; 731 } 732 733 // AffineExpr for strides. 734 for (auto en : llvm::enumerate(strides)) { 735 auto dim = en.index(); 736 auto stride = en.value(); 737 assert(stride != 0 && "Invalid stride specification"); 738 auto d = getAffineDimExpr(dim, context); 739 AffineExpr mult; 740 // Static case. 741 if (stride != MemRefType::getDynamicStrideOrOffset()) 742 mult = getAffineConstantExpr(stride, context); 743 else 744 // Dynamic case, new symbol for each new stride. 745 mult = getAffineSymbolExpr(nSymbols++, context); 746 expr = expr + d * mult; 747 } 748 749 return AffineMap::get(strides.size(), nSymbols, expr); 750 } 751 752 /// Return a version of `t` with identity layout if it can be determined 753 /// statically that the layout is the canonical contiguous strided layout. 754 /// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of 755 /// `t` with simplified layout. 756 /// If `t` has multiple layout maps or a multi-result layout, just return `t`. 757 MemRefType mlir::canonicalizeStridedLayout(MemRefType t) { 758 auto affineMaps = t.getAffineMaps(); 759 // Already in canonical form. 760 if (affineMaps.empty()) 761 return t; 762 763 // Can't reduce to canonical identity form, return in canonical form. 764 if (affineMaps.size() > 1 || affineMaps[0].getNumResults() > 1) 765 return t; 766 767 // If the canonical strided layout for the sizes of `t` is equal to the 768 // simplified layout of `t` we can just return an empty layout. Otherwise, 769 // just simplify the existing layout. 770 AffineExpr expr = 771 makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext()); 772 auto m = affineMaps[0]; 773 auto simplifiedLayoutExpr = 774 simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols()); 775 if (expr != simplifiedLayoutExpr) 776 return MemRefType::Builder(t).setAffineMaps({AffineMap::get( 777 m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)}); 778 return MemRefType::Builder(t).setAffineMaps({}); 779 } 780 781 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes, 782 ArrayRef<AffineExpr> exprs, 783 MLIRContext *context) { 784 // Size 0 corner case is useful for canonicalizations. 785 if (llvm::is_contained(sizes, 0)) 786 return getAffineConstantExpr(0, context); 787 788 auto maps = AffineMap::inferFromExprList(exprs); 789 assert(!maps.empty() && "Expected one non-empty map"); 790 unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols(); 791 792 AffineExpr expr; 793 bool dynamicPoisonBit = false; 794 int64_t runningSize = 1; 795 for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) { 796 int64_t size = std::get<1>(en); 797 // Degenerate case, no size =-> no stride 798 if (size == 0) 799 continue; 800 AffineExpr dimExpr = std::get<0>(en); 801 AffineExpr stride = dynamicPoisonBit 802 ? getAffineSymbolExpr(nSymbols++, context) 803 : getAffineConstantExpr(runningSize, context); 804 expr = expr ? expr + dimExpr * stride : dimExpr * stride; 805 if (size > 0) 806 runningSize *= size; 807 else 808 dynamicPoisonBit = true; 809 } 810 return simplifyAffineExpr(expr, numDims, nSymbols); 811 } 812 813 /// Return a version of `t` with a layout that has all dynamic offset and 814 /// strides. This is used to erase the static layout. 815 MemRefType mlir::eraseStridedLayout(MemRefType t) { 816 auto val = ShapedType::kDynamicStrideOrOffset; 817 return MemRefType::Builder(t).setAffineMaps(makeStridedLinearLayoutMap( 818 SmallVector<int64_t, 4>(t.getRank(), val), val, t.getContext())); 819 } 820 821 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes, 822 MLIRContext *context) { 823 SmallVector<AffineExpr, 4> exprs; 824 exprs.reserve(sizes.size()); 825 for (auto dim : llvm::seq<unsigned>(0, sizes.size())) 826 exprs.push_back(getAffineDimExpr(dim, context)); 827 return makeCanonicalStridedLayoutExpr(sizes, exprs, context); 828 } 829 830 /// Return true if the layout for `t` is compatible with strided semantics. 831 bool mlir::isStrided(MemRefType t) { 832 int64_t offset; 833 SmallVector<int64_t, 4> stridesAndOffset; 834 auto res = getStridesAndOffset(t, stridesAndOffset, offset); 835 return succeeded(res); 836 } 837