1 //===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/BuiltinTypes.h" 10 #include "TypeDetail.h" 11 #include "mlir/IR/AffineExpr.h" 12 #include "mlir/IR/AffineMap.h" 13 #include "mlir/IR/Diagnostics.h" 14 #include "mlir/IR/Dialect.h" 15 #include "llvm/ADT/APFloat.h" 16 #include "llvm/ADT/BitVector.h" 17 #include "llvm/ADT/Sequence.h" 18 #include "llvm/ADT/Twine.h" 19 20 using namespace mlir; 21 using namespace mlir::detail; 22 23 //===----------------------------------------------------------------------===// 24 /// ComplexType 25 //===----------------------------------------------------------------------===// 26 27 ComplexType ComplexType::get(Type elementType) { 28 return Base::get(elementType.getContext(), elementType); 29 } 30 31 ComplexType ComplexType::getChecked(Type elementType, Location location) { 32 return Base::getChecked(location, elementType); 33 } 34 35 /// Verify the construction of an integer type. 36 LogicalResult ComplexType::verifyConstructionInvariants(Location loc, 37 Type elementType) { 38 if (!elementType.isIntOrFloat()) 39 return emitError(loc, "invalid element type for complex"); 40 return success(); 41 } 42 43 Type ComplexType::getElementType() { return getImpl()->elementType; } 44 45 //===----------------------------------------------------------------------===// 46 // Integer Type 47 //===----------------------------------------------------------------------===// 48 49 // static constexpr must have a definition (until in C++17 and inline variable). 50 constexpr unsigned IntegerType::kMaxWidth; 51 52 /// Verify the construction of an integer type. 53 LogicalResult 54 IntegerType::verifyConstructionInvariants(Location loc, unsigned width, 55 SignednessSemantics signedness) { 56 if (width > IntegerType::kMaxWidth) { 57 return emitError(loc) << "integer bitwidth is limited to " 58 << IntegerType::kMaxWidth << " bits"; 59 } 60 return success(); 61 } 62 63 unsigned IntegerType::getWidth() const { return getImpl()->width; } 64 65 IntegerType::SignednessSemantics IntegerType::getSignedness() const { 66 return getImpl()->signedness; 67 } 68 69 //===----------------------------------------------------------------------===// 70 // Float Type 71 //===----------------------------------------------------------------------===// 72 73 unsigned FloatType::getWidth() { 74 if (isa<Float16Type, BFloat16Type>()) 75 return 16; 76 if (isa<Float32Type>()) 77 return 32; 78 if (isa<Float64Type>()) 79 return 64; 80 llvm_unreachable("unexpected float type"); 81 } 82 83 /// Returns the floating semantics for the given type. 84 const llvm::fltSemantics &FloatType::getFloatSemantics() { 85 if (isa<BFloat16Type>()) 86 return APFloat::BFloat(); 87 if (isa<Float16Type>()) 88 return APFloat::IEEEhalf(); 89 if (isa<Float32Type>()) 90 return APFloat::IEEEsingle(); 91 if (isa<Float64Type>()) 92 return APFloat::IEEEdouble(); 93 llvm_unreachable("non-floating point type used"); 94 } 95 96 //===----------------------------------------------------------------------===// 97 // FunctionType 98 //===----------------------------------------------------------------------===// 99 100 FunctionType FunctionType::get(TypeRange inputs, TypeRange results, 101 MLIRContext *context) { 102 return Base::get(context, inputs, results); 103 } 104 105 unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; } 106 107 ArrayRef<Type> FunctionType::getInputs() const { 108 return getImpl()->getInputs(); 109 } 110 111 unsigned FunctionType::getNumResults() const { return getImpl()->numResults; } 112 113 ArrayRef<Type> FunctionType::getResults() const { 114 return getImpl()->getResults(); 115 } 116 117 /// Helper to call a callback once on each index in the range 118 /// [0, `totalIndices`), *except* for the indices given in `indices`. 119 /// `indices` is allowed to have duplicates and can be in any order. 120 inline void iterateIndicesExcept(unsigned totalIndices, 121 ArrayRef<unsigned> indices, 122 function_ref<void(unsigned)> callback) { 123 llvm::BitVector skipIndices(totalIndices); 124 for (unsigned i : indices) 125 skipIndices.set(i); 126 127 for (unsigned i = 0; i < totalIndices; ++i) 128 if (!skipIndices.test(i)) 129 callback(i); 130 } 131 132 /// Returns a new function type without the specified arguments and results. 133 FunctionType 134 FunctionType::getWithoutArgsAndResults(ArrayRef<unsigned> argIndices, 135 ArrayRef<unsigned> resultIndices) { 136 ArrayRef<Type> newInputTypes = getInputs(); 137 SmallVector<Type, 4> newInputTypesBuffer; 138 if (!argIndices.empty()) { 139 unsigned originalNumArgs = getNumInputs(); 140 iterateIndicesExcept(originalNumArgs, argIndices, [&](unsigned i) { 141 newInputTypesBuffer.emplace_back(getInput(i)); 142 }); 143 newInputTypes = newInputTypesBuffer; 144 } 145 146 ArrayRef<Type> newResultTypes = getResults(); 147 SmallVector<Type, 4> newResultTypesBuffer; 148 if (!resultIndices.empty()) { 149 unsigned originalNumResults = getNumResults(); 150 iterateIndicesExcept(originalNumResults, resultIndices, [&](unsigned i) { 151 newResultTypesBuffer.emplace_back(getResult(i)); 152 }); 153 newResultTypes = newResultTypesBuffer; 154 } 155 156 return get(newInputTypes, newResultTypes, getContext()); 157 } 158 159 //===----------------------------------------------------------------------===// 160 // OpaqueType 161 //===----------------------------------------------------------------------===// 162 163 OpaqueType OpaqueType::get(Identifier dialect, StringRef typeData, 164 MLIRContext *context) { 165 return Base::get(context, dialect, typeData); 166 } 167 168 OpaqueType OpaqueType::getChecked(Identifier dialect, StringRef typeData, 169 MLIRContext *context, Location location) { 170 return Base::getChecked(location, dialect, typeData); 171 } 172 173 /// Returns the dialect namespace of the opaque type. 174 Identifier OpaqueType::getDialectNamespace() const { 175 return getImpl()->dialectNamespace; 176 } 177 178 /// Returns the raw type data of the opaque type. 179 StringRef OpaqueType::getTypeData() const { return getImpl()->typeData; } 180 181 /// Verify the construction of an opaque type. 182 LogicalResult OpaqueType::verifyConstructionInvariants(Location loc, 183 Identifier dialect, 184 StringRef typeData) { 185 if (!Dialect::isValidNamespace(dialect.strref())) 186 return emitError(loc, "invalid dialect namespace '") << dialect << "'"; 187 return success(); 188 } 189 190 //===----------------------------------------------------------------------===// 191 // ShapedType 192 //===----------------------------------------------------------------------===// 193 constexpr int64_t ShapedType::kDynamicSize; 194 constexpr int64_t ShapedType::kDynamicStrideOrOffset; 195 196 Type ShapedType::getElementType() const { 197 return static_cast<ImplType *>(impl)->elementType; 198 } 199 200 unsigned ShapedType::getElementTypeBitWidth() const { 201 return getElementType().getIntOrFloatBitWidth(); 202 } 203 204 int64_t ShapedType::getNumElements() const { 205 assert(hasStaticShape() && "cannot get element count of dynamic shaped type"); 206 auto shape = getShape(); 207 int64_t num = 1; 208 for (auto dim : shape) 209 num *= dim; 210 return num; 211 } 212 213 int64_t ShapedType::getRank() const { return getShape().size(); } 214 215 bool ShapedType::hasRank() const { 216 return !isa<UnrankedMemRefType, UnrankedTensorType>(); 217 } 218 219 int64_t ShapedType::getDimSize(unsigned idx) const { 220 assert(idx < getRank() && "invalid index for shaped type"); 221 return getShape()[idx]; 222 } 223 224 bool ShapedType::isDynamicDim(unsigned idx) const { 225 assert(idx < getRank() && "invalid index for shaped type"); 226 return isDynamic(getShape()[idx]); 227 } 228 229 unsigned ShapedType::getDynamicDimIndex(unsigned index) const { 230 assert(index < getRank() && "invalid index"); 231 assert(ShapedType::isDynamic(getDimSize(index)) && "invalid index"); 232 return llvm::count_if(getShape().take_front(index), ShapedType::isDynamic); 233 } 234 235 /// Get the number of bits require to store a value of the given shaped type. 236 /// Compute the value recursively since tensors are allowed to have vectors as 237 /// elements. 238 int64_t ShapedType::getSizeInBits() const { 239 assert(hasStaticShape() && 240 "cannot get the bit size of an aggregate with a dynamic shape"); 241 242 auto elementType = getElementType(); 243 if (elementType.isIntOrFloat()) 244 return elementType.getIntOrFloatBitWidth() * getNumElements(); 245 246 if (auto complexType = elementType.dyn_cast<ComplexType>()) { 247 elementType = complexType.getElementType(); 248 return elementType.getIntOrFloatBitWidth() * getNumElements() * 2; 249 } 250 251 // Tensors can have vectors and other tensors as elements, other shaped types 252 // cannot. 253 assert(isa<TensorType>() && "unsupported element type"); 254 assert((elementType.isa<VectorType, TensorType>()) && 255 "unsupported tensor element type"); 256 return getNumElements() * elementType.cast<ShapedType>().getSizeInBits(); 257 } 258 259 ArrayRef<int64_t> ShapedType::getShape() const { 260 if (auto vectorType = dyn_cast<VectorType>()) 261 return vectorType.getShape(); 262 if (auto tensorType = dyn_cast<RankedTensorType>()) 263 return tensorType.getShape(); 264 return cast<MemRefType>().getShape(); 265 } 266 267 int64_t ShapedType::getNumDynamicDims() const { 268 return llvm::count_if(getShape(), isDynamic); 269 } 270 271 bool ShapedType::hasStaticShape() const { 272 return hasRank() && llvm::none_of(getShape(), isDynamic); 273 } 274 275 bool ShapedType::hasStaticShape(ArrayRef<int64_t> shape) const { 276 return hasStaticShape() && getShape() == shape; 277 } 278 279 //===----------------------------------------------------------------------===// 280 // VectorType 281 //===----------------------------------------------------------------------===// 282 283 VectorType VectorType::get(ArrayRef<int64_t> shape, Type elementType) { 284 return Base::get(elementType.getContext(), shape, elementType); 285 } 286 287 VectorType VectorType::getChecked(ArrayRef<int64_t> shape, Type elementType, 288 Location location) { 289 return Base::getChecked(location, shape, elementType); 290 } 291 292 LogicalResult VectorType::verifyConstructionInvariants(Location loc, 293 ArrayRef<int64_t> shape, 294 Type elementType) { 295 if (shape.empty()) 296 return emitError(loc, "vector types must have at least one dimension"); 297 298 if (!isValidElementType(elementType)) 299 return emitError(loc, "vector elements must be int or float type"); 300 301 if (any_of(shape, [](int64_t i) { return i <= 0; })) 302 return emitError(loc, "vector types must have positive constant sizes"); 303 304 return success(); 305 } 306 307 ArrayRef<int64_t> VectorType::getShape() const { return getImpl()->getShape(); } 308 309 //===----------------------------------------------------------------------===// 310 // TensorType 311 //===----------------------------------------------------------------------===// 312 313 // Check if "elementType" can be an element type of a tensor. Emit errors if 314 // location is not nullptr. Returns failure if check failed. 315 static LogicalResult checkTensorElementType(Location location, 316 Type elementType) { 317 if (!TensorType::isValidElementType(elementType)) 318 return emitError(location, "invalid tensor element type: ") << elementType; 319 return success(); 320 } 321 322 /// Return true if the specified element type is ok in a tensor. 323 bool TensorType::isValidElementType(Type type) { 324 // Note: Non standard/builtin types are allowed to exist within tensor 325 // types. Dialects are expected to verify that tensor types have a valid 326 // element type within that dialect. 327 return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType, 328 IndexType>() || 329 !type.getDialect().getNamespace().empty(); 330 } 331 332 //===----------------------------------------------------------------------===// 333 // RankedTensorType 334 //===----------------------------------------------------------------------===// 335 336 RankedTensorType RankedTensorType::get(ArrayRef<int64_t> shape, 337 Type elementType) { 338 return Base::get(elementType.getContext(), shape, elementType); 339 } 340 341 RankedTensorType RankedTensorType::getChecked(ArrayRef<int64_t> shape, 342 Type elementType, 343 Location location) { 344 return Base::getChecked(location, shape, elementType); 345 } 346 347 LogicalResult RankedTensorType::verifyConstructionInvariants( 348 Location loc, ArrayRef<int64_t> shape, Type elementType) { 349 for (int64_t s : shape) { 350 if (s < -1) 351 return emitError(loc, "invalid tensor dimension size"); 352 } 353 return checkTensorElementType(loc, elementType); 354 } 355 356 ArrayRef<int64_t> RankedTensorType::getShape() const { 357 return getImpl()->getShape(); 358 } 359 360 //===----------------------------------------------------------------------===// 361 // UnrankedTensorType 362 //===----------------------------------------------------------------------===// 363 364 UnrankedTensorType UnrankedTensorType::get(Type elementType) { 365 return Base::get(elementType.getContext(), elementType); 366 } 367 368 UnrankedTensorType UnrankedTensorType::getChecked(Type elementType, 369 Location location) { 370 return Base::getChecked(location, elementType); 371 } 372 373 LogicalResult 374 UnrankedTensorType::verifyConstructionInvariants(Location loc, 375 Type elementType) { 376 return checkTensorElementType(loc, elementType); 377 } 378 379 //===----------------------------------------------------------------------===// 380 // BaseMemRefType 381 //===----------------------------------------------------------------------===// 382 383 unsigned BaseMemRefType::getMemorySpace() const { 384 return static_cast<ImplType *>(impl)->memorySpace; 385 } 386 387 //===----------------------------------------------------------------------===// 388 // MemRefType 389 //===----------------------------------------------------------------------===// 390 391 /// Get or create a new MemRefType based on shape, element type, affine 392 /// map composition, and memory space. Assumes the arguments define a 393 /// well-formed MemRef type. Use getChecked to gracefully handle MemRefType 394 /// construction failures. 395 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType, 396 ArrayRef<AffineMap> affineMapComposition, 397 unsigned memorySpace) { 398 auto result = getImpl(shape, elementType, affineMapComposition, memorySpace, 399 /*location=*/llvm::None); 400 assert(result && "Failed to construct instance of MemRefType."); 401 return result; 402 } 403 404 /// Get or create a new MemRefType based on shape, element type, affine 405 /// map composition, and memory space declared at the given location. 406 /// If the location is unknown, the last argument should be an instance of 407 /// UnknownLoc. If the MemRefType defined by the arguments would be 408 /// ill-formed, emits errors (to the handler registered with the context or to 409 /// the error stream) and returns nullptr. 410 MemRefType MemRefType::getChecked(ArrayRef<int64_t> shape, Type elementType, 411 ArrayRef<AffineMap> affineMapComposition, 412 unsigned memorySpace, Location location) { 413 return getImpl(shape, elementType, affineMapComposition, memorySpace, 414 location); 415 } 416 417 /// Get or create a new MemRefType defined by the arguments. If the resulting 418 /// type would be ill-formed, return nullptr. If the location is provided, 419 /// emit detailed error messages. To emit errors when the location is unknown, 420 /// pass in an instance of UnknownLoc. 421 MemRefType MemRefType::getImpl(ArrayRef<int64_t> shape, Type elementType, 422 ArrayRef<AffineMap> affineMapComposition, 423 unsigned memorySpace, 424 Optional<Location> location) { 425 auto *context = elementType.getContext(); 426 427 if (!BaseMemRefType::isValidElementType(elementType)) 428 return emitOptionalError(location, "invalid memref element type"), 429 MemRefType(); 430 431 for (int64_t s : shape) { 432 // Negative sizes are not allowed except for `-1` that means dynamic size. 433 if (s < -1) 434 return emitOptionalError(location, "invalid memref size"), MemRefType(); 435 } 436 437 // Check that the structure of the composition is valid, i.e. that each 438 // subsequent affine map has as many inputs as the previous map has results. 439 // Take the dimensionality of the MemRef for the first map. 440 auto dim = shape.size(); 441 unsigned i = 0; 442 for (const auto &affineMap : affineMapComposition) { 443 if (affineMap.getNumDims() != dim) { 444 if (location) 445 emitError(*location) 446 << "memref affine map dimension mismatch between " 447 << (i == 0 ? Twine("memref rank") : "affine map " + Twine(i)) 448 << " and affine map" << i + 1 << ": " << dim 449 << " != " << affineMap.getNumDims(); 450 return nullptr; 451 } 452 453 dim = affineMap.getNumResults(); 454 ++i; 455 } 456 457 // Drop identity maps from the composition. 458 // This may lead to the composition becoming empty, which is interpreted as an 459 // implicit identity. 460 SmallVector<AffineMap, 2> cleanedAffineMapComposition; 461 for (const auto &map : affineMapComposition) { 462 if (map.isIdentity()) 463 continue; 464 cleanedAffineMapComposition.push_back(map); 465 } 466 467 return Base::get(context, shape, elementType, cleanedAffineMapComposition, 468 memorySpace); 469 } 470 471 ArrayRef<int64_t> MemRefType::getShape() const { return getImpl()->getShape(); } 472 473 ArrayRef<AffineMap> MemRefType::getAffineMaps() const { 474 return getImpl()->getAffineMaps(); 475 } 476 477 //===----------------------------------------------------------------------===// 478 // UnrankedMemRefType 479 //===----------------------------------------------------------------------===// 480 481 UnrankedMemRefType UnrankedMemRefType::get(Type elementType, 482 unsigned memorySpace) { 483 return Base::get(elementType.getContext(), elementType, memorySpace); 484 } 485 486 UnrankedMemRefType UnrankedMemRefType::getChecked(Type elementType, 487 unsigned memorySpace, 488 Location location) { 489 return Base::getChecked(location, elementType, memorySpace); 490 } 491 492 LogicalResult 493 UnrankedMemRefType::verifyConstructionInvariants(Location loc, Type elementType, 494 unsigned memorySpace) { 495 if (!BaseMemRefType::isValidElementType(elementType)) 496 return emitError(loc, "invalid memref element type"); 497 return success(); 498 } 499 500 // Fallback cases for terminal dim/sym/cst that are not part of a binary op ( 501 // i.e. single term). Accumulate the AffineExpr into the existing one. 502 static void extractStridesFromTerm(AffineExpr e, 503 AffineExpr multiplicativeFactor, 504 MutableArrayRef<AffineExpr> strides, 505 AffineExpr &offset) { 506 if (auto dim = e.dyn_cast<AffineDimExpr>()) 507 strides[dim.getPosition()] = 508 strides[dim.getPosition()] + multiplicativeFactor; 509 else 510 offset = offset + e * multiplicativeFactor; 511 } 512 513 /// Takes a single AffineExpr `e` and populates the `strides` array with the 514 /// strides expressions for each dim position. 515 /// The convention is that the strides for dimensions d0, .. dn appear in 516 /// order to make indexing intuitive into the result. 517 static LogicalResult extractStrides(AffineExpr e, 518 AffineExpr multiplicativeFactor, 519 MutableArrayRef<AffineExpr> strides, 520 AffineExpr &offset) { 521 auto bin = e.dyn_cast<AffineBinaryOpExpr>(); 522 if (!bin) { 523 extractStridesFromTerm(e, multiplicativeFactor, strides, offset); 524 return success(); 525 } 526 527 if (bin.getKind() == AffineExprKind::CeilDiv || 528 bin.getKind() == AffineExprKind::FloorDiv || 529 bin.getKind() == AffineExprKind::Mod) 530 return failure(); 531 532 if (bin.getKind() == AffineExprKind::Mul) { 533 auto dim = bin.getLHS().dyn_cast<AffineDimExpr>(); 534 if (dim) { 535 strides[dim.getPosition()] = 536 strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor; 537 return success(); 538 } 539 // LHS and RHS may both contain complex expressions of dims. Try one path 540 // and if it fails try the other. This is guaranteed to succeed because 541 // only one path may have a `dim`, otherwise this is not an AffineExpr in 542 // the first place. 543 if (bin.getLHS().isSymbolicOrConstant()) 544 return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(), 545 strides, offset); 546 return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(), 547 strides, offset); 548 } 549 550 if (bin.getKind() == AffineExprKind::Add) { 551 auto res1 = 552 extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset); 553 auto res2 = 554 extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset); 555 return success(succeeded(res1) && succeeded(res2)); 556 } 557 558 llvm_unreachable("unexpected binary operation"); 559 } 560 561 LogicalResult mlir::getStridesAndOffset(MemRefType t, 562 SmallVectorImpl<AffineExpr> &strides, 563 AffineExpr &offset) { 564 auto affineMaps = t.getAffineMaps(); 565 // For now strides are only computed on a single affine map with a single 566 // result (i.e. the closed subset of linearization maps that are compatible 567 // with striding semantics). 568 // TODO: support more forms on a per-need basis. 569 if (affineMaps.size() > 1) 570 return failure(); 571 if (affineMaps.size() == 1 && affineMaps[0].getNumResults() != 1) 572 return failure(); 573 574 auto zero = getAffineConstantExpr(0, t.getContext()); 575 auto one = getAffineConstantExpr(1, t.getContext()); 576 offset = zero; 577 strides.assign(t.getRank(), zero); 578 579 AffineMap m; 580 if (!affineMaps.empty()) { 581 m = affineMaps.front(); 582 assert(!m.isIdentity() && "unexpected identity map"); 583 } 584 585 // Canonical case for empty map. 586 if (!m) { 587 // 0-D corner case, offset is already 0. 588 if (t.getRank() == 0) 589 return success(); 590 auto stridedExpr = 591 makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext()); 592 if (succeeded(extractStrides(stridedExpr, one, strides, offset))) 593 return success(); 594 assert(false && "unexpected failure: extract strides in canonical layout"); 595 } 596 597 // Non-canonical case requires more work. 598 auto stridedExpr = 599 simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols()); 600 if (failed(extractStrides(stridedExpr, one, strides, offset))) { 601 offset = AffineExpr(); 602 strides.clear(); 603 return failure(); 604 } 605 606 // Simplify results to allow folding to constants and simple checks. 607 unsigned numDims = m.getNumDims(); 608 unsigned numSymbols = m.getNumSymbols(); 609 offset = simplifyAffineExpr(offset, numDims, numSymbols); 610 for (auto &stride : strides) 611 stride = simplifyAffineExpr(stride, numDims, numSymbols); 612 613 /// In practice, a strided memref must be internally non-aliasing. Test 614 /// against 0 as a proxy. 615 /// TODO: static cases can have more advanced checks. 616 /// TODO: dynamic cases would require a way to compare symbolic 617 /// expressions and would probably need an affine set context propagated 618 /// everywhere. 619 if (llvm::any_of(strides, [](AffineExpr e) { 620 return e == getAffineConstantExpr(0, e.getContext()); 621 })) { 622 offset = AffineExpr(); 623 strides.clear(); 624 return failure(); 625 } 626 627 return success(); 628 } 629 630 LogicalResult mlir::getStridesAndOffset(MemRefType t, 631 SmallVectorImpl<int64_t> &strides, 632 int64_t &offset) { 633 AffineExpr offsetExpr; 634 SmallVector<AffineExpr, 4> strideExprs; 635 if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr))) 636 return failure(); 637 if (auto cst = offsetExpr.dyn_cast<AffineConstantExpr>()) 638 offset = cst.getValue(); 639 else 640 offset = ShapedType::kDynamicStrideOrOffset; 641 for (auto e : strideExprs) { 642 if (auto c = e.dyn_cast<AffineConstantExpr>()) 643 strides.push_back(c.getValue()); 644 else 645 strides.push_back(ShapedType::kDynamicStrideOrOffset); 646 } 647 return success(); 648 } 649 650 //===----------------------------------------------------------------------===// 651 /// TupleType 652 //===----------------------------------------------------------------------===// 653 654 /// Get or create a new TupleType with the provided element types. Assumes the 655 /// arguments define a well-formed type. 656 TupleType TupleType::get(TypeRange elementTypes, MLIRContext *context) { 657 return Base::get(context, elementTypes); 658 } 659 660 /// Get or create an empty tuple type. 661 TupleType TupleType::get(MLIRContext *context) { return get({}, context); } 662 663 /// Return the elements types for this tuple. 664 ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); } 665 666 /// Accumulate the types contained in this tuple and tuples nested within it. 667 /// Note that this only flattens nested tuples, not any other container type, 668 /// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to 669 /// (i32, tensor<i32>, f32, i64) 670 void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) { 671 for (Type type : getTypes()) { 672 if (auto nestedTuple = type.dyn_cast<TupleType>()) 673 nestedTuple.getFlattenedTypes(types); 674 else 675 types.push_back(type); 676 } 677 } 678 679 /// Return the number of element types. 680 size_t TupleType::size() const { return getImpl()->size(); } 681 682 AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides, 683 int64_t offset, 684 MLIRContext *context) { 685 AffineExpr expr; 686 unsigned nSymbols = 0; 687 688 // AffineExpr for offset. 689 // Static case. 690 if (offset != MemRefType::getDynamicStrideOrOffset()) { 691 auto cst = getAffineConstantExpr(offset, context); 692 expr = cst; 693 } else { 694 // Dynamic case, new symbol for the offset. 695 auto sym = getAffineSymbolExpr(nSymbols++, context); 696 expr = sym; 697 } 698 699 // AffineExpr for strides. 700 for (auto en : llvm::enumerate(strides)) { 701 auto dim = en.index(); 702 auto stride = en.value(); 703 assert(stride != 0 && "Invalid stride specification"); 704 auto d = getAffineDimExpr(dim, context); 705 AffineExpr mult; 706 // Static case. 707 if (stride != MemRefType::getDynamicStrideOrOffset()) 708 mult = getAffineConstantExpr(stride, context); 709 else 710 // Dynamic case, new symbol for each new stride. 711 mult = getAffineSymbolExpr(nSymbols++, context); 712 expr = expr + d * mult; 713 } 714 715 return AffineMap::get(strides.size(), nSymbols, expr); 716 } 717 718 /// Return a version of `t` with identity layout if it can be determined 719 /// statically that the layout is the canonical contiguous strided layout. 720 /// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of 721 /// `t` with simplified layout. 722 /// If `t` has multiple layout maps or a multi-result layout, just return `t`. 723 MemRefType mlir::canonicalizeStridedLayout(MemRefType t) { 724 auto affineMaps = t.getAffineMaps(); 725 // Already in canonical form. 726 if (affineMaps.empty()) 727 return t; 728 729 // Can't reduce to canonical identity form, return in canonical form. 730 if (affineMaps.size() > 1 || affineMaps[0].getNumResults() > 1) 731 return t; 732 733 // If the canonical strided layout for the sizes of `t` is equal to the 734 // simplified layout of `t` we can just return an empty layout. Otherwise, 735 // just simplify the existing layout. 736 AffineExpr expr = 737 makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext()); 738 auto m = affineMaps[0]; 739 auto simplifiedLayoutExpr = 740 simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols()); 741 if (expr != simplifiedLayoutExpr) 742 return MemRefType::Builder(t).setAffineMaps({AffineMap::get( 743 m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)}); 744 return MemRefType::Builder(t).setAffineMaps({}); 745 } 746 747 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes, 748 ArrayRef<AffineExpr> exprs, 749 MLIRContext *context) { 750 // Size 0 corner case is useful for canonicalizations. 751 if (llvm::is_contained(sizes, 0)) 752 return getAffineConstantExpr(0, context); 753 754 auto maps = AffineMap::inferFromExprList(exprs); 755 assert(!maps.empty() && "Expected one non-empty map"); 756 unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols(); 757 758 AffineExpr expr; 759 bool dynamicPoisonBit = false; 760 int64_t runningSize = 1; 761 for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) { 762 int64_t size = std::get<1>(en); 763 // Degenerate case, no size =-> no stride 764 if (size == 0) 765 continue; 766 AffineExpr dimExpr = std::get<0>(en); 767 AffineExpr stride = dynamicPoisonBit 768 ? getAffineSymbolExpr(nSymbols++, context) 769 : getAffineConstantExpr(runningSize, context); 770 expr = expr ? expr + dimExpr * stride : dimExpr * stride; 771 if (size > 0) 772 runningSize *= size; 773 else 774 dynamicPoisonBit = true; 775 } 776 return simplifyAffineExpr(expr, numDims, nSymbols); 777 } 778 779 /// Return a version of `t` with a layout that has all dynamic offset and 780 /// strides. This is used to erase the static layout. 781 MemRefType mlir::eraseStridedLayout(MemRefType t) { 782 auto val = ShapedType::kDynamicStrideOrOffset; 783 return MemRefType::Builder(t).setAffineMaps(makeStridedLinearLayoutMap( 784 SmallVector<int64_t, 4>(t.getRank(), val), val, t.getContext())); 785 } 786 787 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes, 788 MLIRContext *context) { 789 SmallVector<AffineExpr, 4> exprs; 790 exprs.reserve(sizes.size()); 791 for (auto dim : llvm::seq<unsigned>(0, sizes.size())) 792 exprs.push_back(getAffineDimExpr(dim, context)); 793 return makeCanonicalStridedLayoutExpr(sizes, exprs, context); 794 } 795 796 /// Return true if the layout for `t` is compatible with strided semantics. 797 bool mlir::isStrided(MemRefType t) { 798 int64_t offset; 799 SmallVector<int64_t, 4> stridesAndOffset; 800 auto res = getStridesAndOffset(t, stridesAndOffset, offset); 801 return succeeded(res); 802 } 803