//===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/BuiltinTypes.h"
#include "TypeDetail.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TensorEncoding.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/Twine.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::detail;

//===----------------------------------------------------------------------===//
// Tablegen Type Definitions
//===----------------------------------------------------------------------===//

#define GET_TYPEDEF_CLASSES
#include "mlir/IR/BuiltinTypes.cpp.inc"

//===----------------------------------------------------------------------===//
// BuiltinDialect
//===----------------------------------------------------------------------===//

void BuiltinDialect::registerTypes() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/IR/BuiltinTypes.cpp.inc"
      >();
}

//===----------------------------------------------------------------------===//
// ComplexType
//===----------------------------------------------------------------------===//

/// Verify the construction of a complex type.
LogicalResult ComplexType::verify(function_ref<InFlightDiagnostic()> emitError,
                                  Type elementType) {
  if (!elementType.isIntOrFloat())
    return emitError() << "invalid element type for complex";
  return success();
}

//===----------------------------------------------------------------------===//
// Integer Type
//===----------------------------------------------------------------------===//

// A static constexpr data member must have an out-of-line definition until
// C++17, where inline variables make the definition redundant.
constexpr unsigned IntegerType::kMaxWidth;

/// Verify the construction of an integer type.
LogicalResult IntegerType::verify(function_ref<InFlightDiagnostic()> emitError,
                                  unsigned width,
                                  SignednessSemantics signedness) {
  if (width > IntegerType::kMaxWidth) {
    return emitError() << "integer bitwidth is limited to "
                       << IntegerType::kMaxWidth << " bits";
  }
  return success();
}

unsigned IntegerType::getWidth() const { return getImpl()->width; }

IntegerType::SignednessSemantics IntegerType::getSignedness() const {
  return getImpl()->signedness;
}

IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
  if (!scale)
    return IntegerType();
  return IntegerType::get(getContext(), scale * getWidth(), getSignedness());
}
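// Illustrative example (not part of the upstream file): scaling preserves
// signedness semantics, so an `si8` scaled by 4 yields `si32`, while a scale
// of 0 returns the null type to signal failure.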
//===----------------------------------------------------------------------===//
// Float Type
//===----------------------------------------------------------------------===//

unsigned FloatType::getWidth() {
  if (isa<Float16Type, BFloat16Type>())
    return 16;
  if (isa<Float32Type>())
    return 32;
  if (isa<Float64Type>())
    return 64;
  if (isa<Float80Type>())
    return 80;
  if (isa<Float128Type>())
    return 128;
  llvm_unreachable("unexpected float type");
}

/// Returns the floating semantics for the given type.
const llvm::fltSemantics &FloatType::getFloatSemantics() {
  if (isa<BFloat16Type>())
    return APFloat::BFloat();
  if (isa<Float16Type>())
    return APFloat::IEEEhalf();
  if (isa<Float32Type>())
    return APFloat::IEEEsingle();
  if (isa<Float64Type>())
    return APFloat::IEEEdouble();
  if (isa<Float80Type>())
    return APFloat::x87DoubleExtended();
  if (isa<Float128Type>())
    return APFloat::IEEEquad();
  llvm_unreachable("non-floating point type used");
}

FloatType FloatType::scaleElementBitwidth(unsigned scale) {
  if (!scale)
    return FloatType();
  MLIRContext *ctx = getContext();
  if (isF16() || isBF16()) {
    if (scale == 2)
      return FloatType::getF32(ctx);
    if (scale == 4)
      return FloatType::getF64(ctx);
  }
  if (isF32())
    if (scale == 2)
      return FloatType::getF64(ctx);
  return FloatType();
}

unsigned FloatType::getFPMantissaWidth() {
  return APFloat::semanticsPrecision(getFloatSemantics());
}

//===----------------------------------------------------------------------===//
// FunctionType
//===----------------------------------------------------------------------===//

unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; }

ArrayRef<Type> FunctionType::getInputs() const {
  return getImpl()->getInputs();
}

unsigned FunctionType::getNumResults() const { return getImpl()->numResults; }

ArrayRef<Type> FunctionType::getResults() const {
  return getImpl()->getResults();
}

FunctionType FunctionType::clone(TypeRange inputs, TypeRange results) const {
  return get(getContext(), inputs, results);
}

/// Returns a new function type with the specified arguments and results
/// inserted.
FunctionType FunctionType::getWithArgsAndResults(
    ArrayRef<unsigned> argIndices, TypeRange argTypes,
    ArrayRef<unsigned> resultIndices, TypeRange resultTypes) {
  SmallVector<Type> argStorage, resultStorage;
  TypeRange newArgTypes = function_interface_impl::insertTypesInto(
      getInputs(), argIndices, argTypes, argStorage);
  TypeRange newResultTypes = function_interface_impl::insertTypesInto(
      getResults(), resultIndices, resultTypes, resultStorage);
  return clone(newArgTypes, newResultTypes);
}

/// Returns a new function type without the specified arguments and results.
FunctionType
FunctionType::getWithoutArgsAndResults(const BitVector &argIndices,
                                       const BitVector &resultIndices) {
  SmallVector<Type> argStorage, resultStorage;
  TypeRange newArgTypes = function_interface_impl::filterTypesOut(
      getInputs(), argIndices, argStorage);
  TypeRange newResultTypes = function_interface_impl::filterTypesOut(
      getResults(), resultIndices, resultStorage);
  return clone(newArgTypes, newResultTypes);
}

void FunctionType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  for (Type type : llvm::concat<const Type>(getInputs(), getResults()))
    walkTypesFn(type);
}

Type FunctionType::replaceImmediateSubElements(
    ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
  unsigned numInputs = getNumInputs();
  return get(getContext(), replTypes.take_front(numInputs),
             replTypes.drop_front(numInputs));
}
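// Illustrative example: given `(i32, f32) -> (i1)`, erasing argument 1 via
// getWithoutArgsAndResults with argIndices = {1} should produce
// `(i32) -> (i1)`; the BitVector indices name positions in the original
// input/result lists.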
//===----------------------------------------------------------------------===//
// OpaqueType
//===----------------------------------------------------------------------===//

/// Verify the construction of an opaque type.
LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError,
                                 StringAttr dialect, StringRef typeData) {
  if (!Dialect::isValidNamespace(dialect.strref()))
    return emitError() << "invalid dialect namespace '" << dialect << "'";

  // Check that the dialect is actually registered.
  MLIRContext *context = dialect.getContext();
  if (!context->allowsUnregisteredDialects() &&
      !context->getLoadedDialect(dialect.strref())) {
    return emitError()
           << "`!" << dialect << "<\"" << typeData << "\">"
           << "` type created with unregistered dialect. If this is "
              "intended, please call allowUnregisteredDialects() on the "
              "MLIRContext, or use -allow-unregistered-dialect with "
              "the MLIR tool being used";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// VectorType
//===----------------------------------------------------------------------===//

LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
                                 ArrayRef<int64_t> shape, Type elementType,
                                 unsigned numScalableDims) {
  if (!isValidElementType(elementType))
    return emitError()
           << "vector elements must be int/index/float type but got "
           << elementType;

  if (any_of(shape, [](int64_t i) { return i <= 0; }))
    return emitError()
           << "vector types must have positive constant sizes but got "
           << shape;

  return success();
}

VectorType VectorType::scaleElementBitwidth(unsigned scale) {
  if (!scale)
    return VectorType();
  if (auto et = getElementType().dyn_cast<IntegerType>())
    if (auto scaledEt = et.scaleElementBitwidth(scale))
      return VectorType::get(getShape(), scaledEt, getNumScalableDims());
  if (auto et = getElementType().dyn_cast<FloatType>())
    if (auto scaledEt = et.scaleElementBitwidth(scale))
      return VectorType::get(getShape(), scaledEt, getNumScalableDims());
  return VectorType();
}

void VectorType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  walkTypesFn(getElementType());
}

Type VectorType::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
                                             ArrayRef<Type> replTypes) const {
  return get(getShape(), replTypes.front(), getNumScalableDims());
}

VectorType VectorType::cloneWith(Optional<ArrayRef<int64_t>> shape,
                                 Type elementType) const {
  return VectorType::get(shape.value_or(getShape()), elementType,
                         getNumScalableDims());
}

//===----------------------------------------------------------------------===//
// TensorType
//===----------------------------------------------------------------------===//

Type TensorType::getElementType() const {
  return llvm::TypeSwitch<TensorType, Type>(*this)
      .Case<RankedTensorType, UnrankedTensorType>(
          [](auto type) { return type.getElementType(); });
}

bool TensorType::hasRank() const { return !isa<UnrankedTensorType>(); }

ArrayRef<int64_t> TensorType::getShape() const {
  return cast<RankedTensorType>().getShape();
}

TensorType TensorType::cloneWith(Optional<ArrayRef<int64_t>> shape,
                                 Type elementType) const {
  if (auto unrankedTy = dyn_cast<UnrankedTensorType>()) {
    if (shape)
      return RankedTensorType::get(*shape, elementType);
    return UnrankedTensorType::get(elementType);
  }

  auto rankedTy = cast<RankedTensorType>();
  if (!shape)
    return RankedTensorType::get(rankedTy.getShape(), elementType,
                                 rankedTy.getEncoding());
  return RankedTensorType::get(shape.value_or(rankedTy.getShape()),
                               elementType, rankedTy.getEncoding());
}

// Check if "elementType" can be an element type of a tensor.
static LogicalResult
checkTensorElementType(function_ref<InFlightDiagnostic()> emitError,
                       Type elementType) {
  if (!TensorType::isValidElementType(elementType))
    return emitError() << "invalid tensor element type: " << elementType;
  return success();
}
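// Illustrative example: cloneWith on `tensor<*xf32>` with shape {2, 3}
// produces `tensor<2x3xf32>`, while cloneWith on `tensor<4x?xf32>` with no
// shape and element type i32 produces `tensor<4x?xi32>`, preserving any
// encoding attribute.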
/// Return true if the specified element type is ok in a tensor.
bool TensorType::isValidElementType(Type type) {
  // Note: Non standard/builtin types are allowed to exist within tensor
  // types. Dialects are expected to verify that tensor types have a valid
  // element type within that dialect.
  return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
                  IndexType>() ||
         !llvm::isa<BuiltinDialect>(type.getDialect());
}

//===----------------------------------------------------------------------===//
// RankedTensorType
//===----------------------------------------------------------------------===//

LogicalResult
RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
                         ArrayRef<int64_t> shape, Type elementType,
                         Attribute encoding) {
  for (int64_t s : shape)
    if (s < -1)
      return emitError() << "invalid tensor dimension size";
  if (auto v = encoding.dyn_cast_or_null<VerifiableTensorEncoding>())
    if (failed(v.verifyEncoding(shape, elementType, emitError)))
      return failure();
  return checkTensorElementType(emitError, elementType);
}

void RankedTensorType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  walkTypesFn(getElementType());
  if (Attribute encoding = getEncoding())
    walkAttrsFn(encoding);
}

Type RankedTensorType::replaceImmediateSubElements(
    ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
  return get(getShape(), replTypes.front(),
             replAttrs.empty() ? Attribute() : replAttrs.back());
}

//===----------------------------------------------------------------------===//
// UnrankedTensorType
//===----------------------------------------------------------------------===//

LogicalResult
UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
                           Type elementType) {
  return checkTensorElementType(emitError, elementType);
}

void UnrankedTensorType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  walkTypesFn(getElementType());
}

Type UnrankedTensorType::replaceImmediateSubElements(
    ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
  return get(replTypes.front());
}

//===----------------------------------------------------------------------===//
// BaseMemRefType
//===----------------------------------------------------------------------===//

Type BaseMemRefType::getElementType() const {
  return llvm::TypeSwitch<BaseMemRefType, Type>(*this)
      .Case<MemRefType, UnrankedMemRefType>(
          [](auto type) { return type.getElementType(); });
}

bool BaseMemRefType::hasRank() const { return !isa<UnrankedMemRefType>(); }

ArrayRef<int64_t> BaseMemRefType::getShape() const {
  return cast<MemRefType>().getShape();
}

BaseMemRefType BaseMemRefType::cloneWith(Optional<ArrayRef<int64_t>> shape,
                                         Type elementType) const {
  if (auto unrankedTy = dyn_cast<UnrankedMemRefType>()) {
    if (!shape)
      return UnrankedMemRefType::get(elementType, getMemorySpace());
    MemRefType::Builder builder(*shape, elementType);
    builder.setMemorySpace(getMemorySpace());
    return builder;
  }

  MemRefType::Builder builder(cast<MemRefType>());
  if (shape)
    builder.setShape(*shape);
  builder.setElementType(elementType);
  return builder;
}

Attribute BaseMemRefType::getMemorySpace() const {
  if (auto rankedMemRefTy = dyn_cast<MemRefType>())
    return rankedMemRefTy.getMemorySpace();
  return cast<UnrankedMemRefType>().getMemorySpace();
}

unsigned BaseMemRefType::getMemorySpaceAsInt() const {
  if (auto rankedMemRefTy = dyn_cast<MemRefType>())
    return rankedMemRefTy.getMemorySpaceAsInt();
  return cast<UnrankedMemRefType>().getMemorySpaceAsInt();
}
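// Illustrative example: cloneWith on `memref<*xf32, 1>` with shape {8}
// produces `memref<8xf32, 1>`; the memory space attribute carries over, and
// getMemorySpaceAsInt returns 0 when no memory space is attached.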
//===----------------------------------------------------------------------===//
// MemRefType
//===----------------------------------------------------------------------===//

/// Given an `originalShape` and a `reducedShape` assumed to be a subset of
/// `originalShape` with some `1` entries erased, return the set of indices
/// that specifies which of the entries of `originalShape` are dropped to
/// obtain `reducedShape`. The returned mask can be applied as a projection to
/// `originalShape` to obtain the `reducedShape`. This mask is useful to track
/// which dimensions must be kept when e.g. computing MemRef strides under
/// rank-reducing operations. Return None if `reducedShape` cannot be obtained
/// by dropping only `1` entries in `originalShape`.
llvm::Optional<llvm::SmallDenseSet<unsigned>>
mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
                               ArrayRef<int64_t> reducedShape) {
  size_t originalRank = originalShape.size(),
         reducedRank = reducedShape.size();
  llvm::SmallDenseSet<unsigned> unusedDims;
  unsigned reducedIdx = 0;
  for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
    // Greedily insert `originalIdx` if match.
    if (reducedIdx < reducedRank &&
        originalShape[originalIdx] == reducedShape[reducedIdx]) {
      reducedIdx++;
      continue;
    }

    unusedDims.insert(originalIdx);
    // If there is no match on `originalIdx`, the `originalShape` at this
    // dimension must be 1; otherwise we bail.
    if (originalShape[originalIdx] != 1)
      return llvm::None;
  }
  // The whole `reducedShape` must be scanned; otherwise we bail.
  if (reducedIdx != reducedRank)
    return llvm::None;
  return unusedDims;
}
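// Worked example (illustrative): computeRankReductionMask([1, 4, 1, 8],
// [4, 8]) greedily matches 4 and 8, drops the two unit dimensions, and
// returns the mask {0, 2}. In contrast, computeRankReductionMask([1, 4, 1, 8],
// [4, 4]) returns llvm::None: the trailing 8 cannot be matched against 4 and
// is not a droppable 1.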
SliceVerificationResult
mlir::isRankReducedType(ShapedType originalType,
                        ShapedType candidateReducedType) {
  if (originalType == candidateReducedType)
    return SliceVerificationResult::Success;

  ShapedType originalShapedType = originalType.cast<ShapedType>();
  ShapedType candidateReducedShapedType =
      candidateReducedType.cast<ShapedType>();

  // Rank and size logic is valid for all ShapedTypes.
  ArrayRef<int64_t> originalShape = originalShapedType.getShape();
  ArrayRef<int64_t> candidateReducedShape =
      candidateReducedShapedType.getShape();
  unsigned originalRank = originalShape.size(),
           candidateReducedRank = candidateReducedShape.size();
  if (candidateReducedRank > originalRank)
    return SliceVerificationResult::RankTooLarge;

  auto optionalUnusedDimsMask =
      computeRankReductionMask(originalShape, candidateReducedShape);

  // The sizes cannot be matched if no rank-reduction mask could be computed.
  if (!optionalUnusedDimsMask)
    return SliceVerificationResult::SizeMismatch;

  if (originalShapedType.getElementType() !=
      candidateReducedShapedType.getElementType())
    return SliceVerificationResult::ElemTypeMismatch;

  return SliceVerificationResult::Success;
}

bool mlir::detail::isSupportedMemorySpace(Attribute memorySpace) {
  // Empty attribute is allowed as default memory space.
  if (!memorySpace)
    return true;

  // Supported built-in attributes.
  if (memorySpace.isa<IntegerAttr, StringAttr, DictionaryAttr>())
    return true;

  // Allow custom dialect attributes.
  if (!isa<BuiltinDialect>(memorySpace.getDialect()))
    return true;

  return false;
}

Attribute mlir::detail::wrapIntegerMemorySpace(unsigned memorySpace,
                                               MLIRContext *ctx) {
  if (memorySpace == 0)
    return nullptr;

  return IntegerAttr::get(IntegerType::get(ctx, 64), memorySpace);
}

Attribute mlir::detail::skipDefaultMemorySpace(Attribute memorySpace) {
  IntegerAttr intMemorySpace = memorySpace.dyn_cast_or_null<IntegerAttr>();
  if (intMemorySpace && intMemorySpace.getValue() == 0)
    return nullptr;

  return memorySpace;
}

unsigned mlir::detail::getMemorySpaceAsInt(Attribute memorySpace) {
  if (!memorySpace)
    return 0;

  assert(memorySpace.isa<IntegerAttr>() &&
         "Using `getMemorySpaceAsInt` with a non-integer attribute");

  return static_cast<unsigned>(memorySpace.cast<IntegerAttr>().getInt());
}

MemRefType::Builder &
MemRefType::Builder::setMemorySpace(unsigned newMemorySpace) {
  memorySpace =
      wrapIntegerMemorySpace(newMemorySpace, elementType.getContext());
  return *this;
}

unsigned MemRefType::getMemorySpaceAsInt() const {
  return detail::getMemorySpaceAsInt(getMemorySpace());
}

MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
                           MemRefLayoutAttrInterface layout,
                           Attribute memorySpace) {
  // Use default layout for empty attribute.
  if (!layout)
    layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
        shape.size(), elementType.getContext()));

  // Drop default memory space value and replace it with empty attribute.
  memorySpace = skipDefaultMemorySpace(memorySpace);

  return Base::get(elementType.getContext(), shape, elementType, layout,
                   memorySpace);
}

MemRefType MemRefType::getChecked(
    function_ref<InFlightDiagnostic()> emitErrorFn, ArrayRef<int64_t> shape,
    Type elementType, MemRefLayoutAttrInterface layout,
    Attribute memorySpace) {
  // Use default layout for empty attribute.
  if (!layout)
    layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
        shape.size(), elementType.getContext()));

  // Drop default memory space value and replace it with empty attribute.
  memorySpace = skipDefaultMemorySpace(memorySpace);

  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
                          elementType, layout, memorySpace);
}

MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
                           AffineMap map, Attribute memorySpace) {
  // Use default layout for empty map.
  if (!map)
    map = AffineMap::getMultiDimIdentityMap(shape.size(),
                                            elementType.getContext());

  // Wrap AffineMap into Attribute.
  Attribute layout = AffineMapAttr::get(map);

  // Drop default memory space value and replace it with empty attribute.
  memorySpace = skipDefaultMemorySpace(memorySpace);

  return Base::get(elementType.getContext(), shape, elementType, layout,
                   memorySpace);
}

MemRefType
MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
                       ArrayRef<int64_t> shape, Type elementType,
                       AffineMap map, Attribute memorySpace) {
  // Use default layout for empty map.
  if (!map)
    map = AffineMap::getMultiDimIdentityMap(shape.size(),
                                            elementType.getContext());

  // Wrap AffineMap into Attribute.
  Attribute layout = AffineMapAttr::get(map);

  // Drop default memory space value and replace it with empty attribute.
  memorySpace = skipDefaultMemorySpace(memorySpace);

  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
                          elementType, layout, memorySpace);
}

MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
                           AffineMap map, unsigned memorySpaceInd) {
  // Use default layout for empty map.
  if (!map)
    map = AffineMap::getMultiDimIdentityMap(shape.size(),
                                            elementType.getContext());

  // Wrap AffineMap into Attribute.
  Attribute layout = AffineMapAttr::get(map);

  // Convert deprecated integer-like memory space to Attribute.
  Attribute memorySpace =
      wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());

  return Base::get(elementType.getContext(), shape, elementType, layout,
                   memorySpace);
}

MemRefType
MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
                       ArrayRef<int64_t> shape, Type elementType,
                       AffineMap map, unsigned memorySpaceInd) {
  // Use default layout for empty map.
  if (!map)
    map = AffineMap::getMultiDimIdentityMap(shape.size(),
                                            elementType.getContext());

  // Wrap AffineMap into Attribute.
  Attribute layout = AffineMapAttr::get(map);

  // Convert deprecated integer-like memory space to Attribute.
  Attribute memorySpace =
      wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());

  return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
                          elementType, layout, memorySpace);
}
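// Illustrative example: building a rank-2 memref with an empty layout uses
// the identity map (d0, d1) -> (d0, d1), and an explicit memory space of 0
// is normalized to the empty attribute, so `memref<4x8xf32>` and
// `memref<4x8xf32, 0>` unify to the same type.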
LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
                                 ArrayRef<int64_t> shape, Type elementType,
                                 MemRefLayoutAttrInterface layout,
                                 Attribute memorySpace) {
  if (!BaseMemRefType::isValidElementType(elementType))
    return emitError() << "invalid memref element type";

  // Negative sizes are not allowed, except for `-1`, which means a dynamic
  // size.
  for (int64_t s : shape)
    if (s < -1)
      return emitError() << "invalid memref size";

  assert(layout && "missing layout specification");
  if (failed(layout.verifyLayout(shape, emitError)))
    return failure();

  if (!isSupportedMemorySpace(memorySpace))
    return emitError() << "unsupported memory space Attribute";

  return success();
}

void MemRefType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  walkTypesFn(getElementType());
  if (!getLayout().isIdentity())
    walkAttrsFn(getLayout());
  walkAttrsFn(getMemorySpace());
}

Type MemRefType::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
                                             ArrayRef<Type> replTypes) const {
  bool hasLayout = replAttrs.size() > 1;
  return get(getShape(), replTypes[0],
             hasLayout ? replAttrs[0].dyn_cast<MemRefLayoutAttrInterface>()
                       : MemRefLayoutAttrInterface(),
             hasLayout ? replAttrs[1] : replAttrs[0]);
}

//===----------------------------------------------------------------------===//
// UnrankedMemRefType
//===----------------------------------------------------------------------===//

unsigned UnrankedMemRefType::getMemorySpaceAsInt() const {
  return detail::getMemorySpaceAsInt(getMemorySpace());
}

LogicalResult
UnrankedMemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
                           Type elementType, Attribute memorySpace) {
  if (!BaseMemRefType::isValidElementType(elementType))
    return emitError() << "invalid memref element type";

  if (!isSupportedMemorySpace(memorySpace))
    return emitError() << "unsupported memory space Attribute";

  return success();
}
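// Illustrative example: isSupportedMemorySpace accepts the empty attribute
// (the default space), builtin IntegerAttr/StringAttr/DictionaryAttr values,
// and any attribute from a non-builtin dialect; other builtin attributes,
// e.g. a UnitAttr, are rejected by the verifiers above.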
// Fallback cases for terminal dim/sym/cst that are not part of a binary op
// (i.e. a single term). Accumulate the AffineExpr into the existing one.
static void extractStridesFromTerm(AffineExpr e,
                                   AffineExpr multiplicativeFactor,
                                   MutableArrayRef<AffineExpr> strides,
                                   AffineExpr &offset) {
  if (auto dim = e.dyn_cast<AffineDimExpr>())
    strides[dim.getPosition()] =
        strides[dim.getPosition()] + multiplicativeFactor;
  else
    offset = offset + e * multiplicativeFactor;
}

/// Takes a single AffineExpr `e` and populates the `strides` array with the
/// stride expressions for each dim position.
/// The convention is that the strides for dimensions d0, .. dn appear in
/// order to make indexing intuitive into the result.
static LogicalResult extractStrides(AffineExpr e,
                                    AffineExpr multiplicativeFactor,
                                    MutableArrayRef<AffineExpr> strides,
                                    AffineExpr &offset) {
  auto bin = e.dyn_cast<AffineBinaryOpExpr>();
  if (!bin) {
    extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
    return success();
  }

  if (bin.getKind() == AffineExprKind::CeilDiv ||
      bin.getKind() == AffineExprKind::FloorDiv ||
      bin.getKind() == AffineExprKind::Mod)
    return failure();

  if (bin.getKind() == AffineExprKind::Mul) {
    auto dim = bin.getLHS().dyn_cast<AffineDimExpr>();
    if (dim) {
      strides[dim.getPosition()] =
          strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
      return success();
    }
    // LHS and RHS may both contain complex expressions of dims. Try one path
    // and if it fails try the other. This is guaranteed to succeed because
    // only one path may have a `dim`, otherwise this is not an AffineExpr in
    // the first place.
    if (bin.getLHS().isSymbolicOrConstant())
      return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
                            strides, offset);
    return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
                          strides, offset);
  }

  if (bin.getKind() == AffineExprKind::Add) {
    auto res1 =
        extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
    auto res2 =
        extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
    return success(succeeded(res1) && succeeded(res2));
  }

  llvm_unreachable("unexpected binary operation");
}

LogicalResult mlir::getStridesAndOffset(MemRefType t,
                                        SmallVectorImpl<AffineExpr> &strides,
                                        AffineExpr &offset) {
  AffineMap m = t.getLayout().getAffineMap();

  if (m.getNumResults() != 1 && !m.isIdentity())
    return failure();

  auto zero = getAffineConstantExpr(0, t.getContext());
  auto one = getAffineConstantExpr(1, t.getContext());
  offset = zero;
  strides.assign(t.getRank(), zero);

  // Canonical case for empty map.
  if (m.isIdentity()) {
    // 0-D corner case, offset is already 0.
    if (t.getRank() == 0)
      return success();
    auto stridedExpr =
        makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
    if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
      return success();
    assert(false && "unexpected failure: extract strides in canonical layout");
  }

  // Non-canonical case requires more work.
  auto stridedExpr =
      simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
  if (failed(extractStrides(stridedExpr, one, strides, offset))) {
    offset = AffineExpr();
    strides.clear();
    return failure();
  }

  // Simplify results to allow folding to constants and simple checks.
  unsigned numDims = m.getNumDims();
  unsigned numSymbols = m.getNumSymbols();
  offset = simplifyAffineExpr(offset, numDims, numSymbols);
  for (auto &stride : strides)
    stride = simplifyAffineExpr(stride, numDims, numSymbols);

  // In practice, a strided memref must be internally non-aliasing. Test
  // against 0 as a proxy.
  // TODO: static cases can have more advanced checks.
  // TODO: dynamic cases would require a way to compare symbolic expressions
  // and would probably need an affine set context propagated everywhere.
  if (llvm::any_of(strides, [](AffineExpr e) {
        return e == getAffineConstantExpr(0, e.getContext());
      })) {
    offset = AffineExpr();
    strides.clear();
    return failure();
  }

  return success();
}

LogicalResult mlir::getStridesAndOffset(MemRefType t,
                                        SmallVectorImpl<int64_t> &strides,
                                        int64_t &offset) {
  AffineExpr offsetExpr;
  SmallVector<AffineExpr, 4> strideExprs;
  if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr)))
    return failure();
  if (auto cst = offsetExpr.dyn_cast<AffineConstantExpr>())
    offset = cst.getValue();
  else
    offset = ShapedType::kDynamicStrideOrOffset;
  for (auto e : strideExprs) {
    if (auto c = e.dyn_cast<AffineConstantExpr>())
      strides.push_back(c.getValue());
    else
      strides.push_back(ShapedType::kDynamicStrideOrOffset);
  }
  return success();
}
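// Illustrative example: for `memref<3x4x5xf32>` with the identity layout,
// getStridesAndOffset yields strides [20, 5, 1] and offset 0; with the layout
// map (d0, d1)[s0] -> (d0 * 4 + d1 + s0) it yields strides [4, 1] and a
// symbolic offset, which the integer overload reports as
// ShapedType::kDynamicStrideOrOffset.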
void UnrankedMemRefType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  walkTypesFn(getElementType());
  walkAttrsFn(getMemorySpace());
}

Type UnrankedMemRefType::replaceImmediateSubElements(
    ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
  return get(replTypes.front(), replAttrs.front());
}

//===----------------------------------------------------------------------===//
// TupleType
//===----------------------------------------------------------------------===//

/// Return the element types for this tuple.
ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }

/// Accumulate the types contained in this tuple and tuples nested within it.
/// Note that this only flattens nested tuples, not any other container type,
/// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
/// (i32, tensor<i32>, f32, i64).
void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) {
  for (Type type : getTypes()) {
    if (auto nestedTuple = type.dyn_cast<TupleType>())
      nestedTuple.getFlattenedTypes(types);
    else
      types.push_back(type);
  }
}

/// Return the number of element types.
size_t TupleType::size() const { return getImpl()->size(); }

void TupleType::walkImmediateSubElements(
    function_ref<void(Attribute)> walkAttrsFn,
    function_ref<void(Type)> walkTypesFn) const {
  for (Type type : getTypes())
    walkTypesFn(type);
}

Type TupleType::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
                                            ArrayRef<Type> replTypes) const {
  return get(getContext(), replTypes);
}

//===----------------------------------------------------------------------===//
// Type Utilities
//===----------------------------------------------------------------------===//

AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,
                                           int64_t offset,
                                           MLIRContext *context) {
  AffineExpr expr;
  unsigned nSymbols = 0;

  // AffineExpr for offset.
  // Static case.
  if (offset != MemRefType::getDynamicStrideOrOffset()) {
    auto cst = getAffineConstantExpr(offset, context);
    expr = cst;
  } else {
    // Dynamic case, new symbol for the offset.
    auto sym = getAffineSymbolExpr(nSymbols++, context);
    expr = sym;
  }

  // AffineExpr for strides.
  for (const auto &en : llvm::enumerate(strides)) {
    auto dim = en.index();
    auto stride = en.value();
    assert(stride != 0 && "Invalid stride specification");
    auto d = getAffineDimExpr(dim, context);
    AffineExpr mult;
    // Static case.
    if (stride != MemRefType::getDynamicStrideOrOffset())
      mult = getAffineConstantExpr(stride, context);
    else
      // Dynamic case, new symbol for each new stride.
      mult = getAffineSymbolExpr(nSymbols++, context);
    expr = expr + d * mult;
  }

  return AffineMap::get(strides.size(), nSymbols, expr);
}
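// Illustrative example: strides [20, 1] with static offset 4 produce the map
// (d0, d1) -> (d0 * 20 + d1 + 4); with a dynamic offset the same strides
// produce (d0, d1)[s0] -> (d0 * 20 + d1 + s0), one symbol per dynamic value
// (the exact printed form may vary with expression simplification).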
/// Return a version of `t` with identity layout if it can be determined
/// statically that the layout is the canonical contiguous strided layout.
/// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
/// `t` with the simplified layout.
/// If `t` has multiple layout maps or a multi-result layout, just return `t`.
MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
  AffineMap m = t.getLayout().getAffineMap();

  // Already in canonical form.
  if (m.isIdentity())
    return t;

  // Can't reduce to canonical identity form, return in canonical form.
  if (m.getNumResults() > 1)
    return t;

  // Corner-case for 0-D affine maps.
  if (m.getNumDims() == 0 && m.getNumSymbols() == 0) {
    if (auto cst = m.getResult(0).dyn_cast<AffineConstantExpr>())
      if (cst.getValue() == 0)
        return MemRefType::Builder(t).setLayout({});
    return t;
  }

  // 0-D corner case for an empty shape that still has an affine map. Example:
  // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a 1-element memref
  // whose offset needs to remain; just return t.
  if (t.getShape().empty())
    return t;

  // If the canonical strided layout for the sizes of `t` is equal to the
  // simplified layout of `t` we can just return an empty layout. Otherwise,
  // just simplify the existing layout.
  AffineExpr expr =
      makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
  auto simplifiedLayoutExpr =
      simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
  if (expr != simplifiedLayoutExpr)
    return MemRefType::Builder(t).setLayout(AffineMapAttr::get(AffineMap::get(
        m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)));
  return MemRefType::Builder(t).setLayout({});
}

AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                                ArrayRef<AffineExpr> exprs,
                                                MLIRContext *context) {
  // The size-0 corner case is useful for canonicalizations.
  if (sizes.empty() || llvm::is_contained(sizes, 0))
    return getAffineConstantExpr(0, context);

  assert(!exprs.empty() && "expected exprs");
  auto maps = AffineMap::inferFromExprList(exprs);
  assert(!maps.empty() && "Expected one non-empty map");
  unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();

  AffineExpr expr;
  bool dynamicPoisonBit = false;
  int64_t runningSize = 1;
  for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
    int64_t size = std::get<1>(en);
    // Degenerate case, no size -> no stride.
    if (size == 0)
      continue;
    AffineExpr dimExpr = std::get<0>(en);
    AffineExpr stride = dynamicPoisonBit
                            ? getAffineSymbolExpr(nSymbols++, context)
                            : getAffineConstantExpr(runningSize, context);
    expr = expr ? expr + dimExpr * stride : dimExpr * stride;
    if (size > 0) {
      runningSize *= size;
      assert(runningSize > 0 && "integer overflow in size computation");
    } else {
      dynamicPoisonBit = true;
    }
  }
  return simplifyAffineExpr(expr, numDims, nSymbols);
}

/// Return a version of `t` with a layout that has all dynamic offset and
/// strides. This is used to erase the static layout.
MemRefType mlir::eraseStridedLayout(MemRefType t) {
  auto val = ShapedType::kDynamicStrideOrOffset;
  return MemRefType::Builder(t).setLayout(
      AffineMapAttr::get(makeStridedLinearLayoutMap(
          SmallVector<int64_t, 4>(t.getRank(), val), val, t.getContext())));
}

AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                                MLIRContext *context) {
  SmallVector<AffineExpr, 4> exprs;
  exprs.reserve(sizes.size());
  for (auto dim : llvm::seq<unsigned>(0, sizes.size()))
    exprs.push_back(getAffineDimExpr(dim, context));
  return makeCanonicalStridedLayoutExpr(sizes, exprs, context);
}

/// Return true if the layout for `t` is compatible with strided semantics.
bool mlir::isStrided(MemRefType t) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto res = getStridesAndOffset(t, strides, offset);
  return succeeded(res);
}

/// Return the layout map in strided linear layout AffineMap form.
/// Return null if the layout is not compatible with a strided layout.
AffineMap mlir::getStridedLinearLayoutMap(MemRefType t) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  if (failed(getStridesAndOffset(t, strides, offset)))
    return AffineMap();
  return makeStridedLinearLayoutMap(strides, offset, t.getContext());
}

/// Return the AffineExpr representation of the offset, assuming `memrefType`
/// is a strided memref.
static AffineExpr getOffsetExpr(MemRefType memrefType) {
  SmallVector<AffineExpr> strides;
  AffineExpr offset;
  if (failed(getStridesAndOffset(memrefType, strides, offset)))
    assert(false && "expected strided memref");
  return offset;
}
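// Illustrative example: with sizes [3, ?, 5] (? dynamic) and dim exprs
// (d0, d1, d2), makeCanonicalStridedLayoutExpr walks the shape right to
// left: d2 gets stride 1, d1 gets stride 5, and every stride above the
// dynamic size becomes a fresh symbol, yielding d0 * s0 + d1 * 5 + d2.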
/// Helper to construct a contiguous MemRefType from `shape`, `elementType`
/// and an `offset` AffineExpr.
static MemRefType makeContiguousRowMajorMemRefType(MLIRContext *context,
                                                   ArrayRef<int64_t> shape,
                                                   Type elementType,
                                                   AffineExpr offset) {
  AffineExpr canonical = makeCanonicalStridedLayoutExpr(shape, context);
  AffineExpr contiguousRowMajor = canonical + offset;
  AffineMap contiguousRowMajorMap =
      AffineMap::inferFromExprList({contiguousRowMajor})[0];
  return MemRefType::get(shape, elementType, contiguousRowMajorMap);
}

/// Helper determining if a memref has a static shape and a contiguous
/// row-major layout, while still allowing for an arbitrary offset (any
/// static or dynamic value).
bool mlir::isStaticShapeAndContiguousRowMajor(MemRefType memrefType) {
  if (!memrefType.hasStaticShape())
    return false;

  AffineExpr offset = getOffsetExpr(memrefType);
  MemRefType contiguousRowMajorMemRefType = makeContiguousRowMajorMemRefType(
      memrefType.getContext(), memrefType.getShape(),
      memrefType.getElementType(), offset);
  return canonicalizeStridedLayout(memrefType) ==
         canonicalizeStridedLayout(contiguousRowMajorMemRefType);
}
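// Illustrative example: `memref<2x3xf32>` (identity layout) passes this
// check, and so should a row-major layout with a nonzero or dynamic offset,
// e.g. (d0, d1)[s0] -> (d0 * 3 + d1 + s0); a transposed layout such as
// (d0, d1) -> (d0 + d1 * 2) fails because its strides are not row-major.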