1 //===- ReshapeOpsUtils.cpp - Utilities used by structured ops -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" 10 11 #include "mlir/IR/AffineMap.h" 12 #include "mlir/IR/Builders.h" 13 14 #include <numeric> 15 16 using namespace mlir; 17 18 Optional<SmallVector<ReassociationIndices>> 19 mlir::getReassociationIndicesForReshape(ShapedType sourceType, 20 ShapedType targetType) { 21 if (sourceType.getRank() > targetType.getRank()) 22 return getReassociationIndicesForCollapse(sourceType.getShape(), 23 targetType.getShape()); 24 if (sourceType.getRank() < targetType.getRank()) 25 return getReassociationIndicesForCollapse(targetType.getShape(), 26 sourceType.getShape()); 27 return llvm::None; 28 } 29 30 Optional<SmallVector<ReassociationIndices>> 31 mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape, 32 ArrayRef<int64_t> targetShape) { 33 if (sourceShape.size() <= targetShape.size()) 34 return llvm::None; 35 unsigned sourceDim = 0; 36 SmallVector<ReassociationIndices> reassociationMap; 37 reassociationMap.reserve(targetShape.size()); 38 39 ReassociationIndices currIndices; 40 int64_t prodOfCollapsedDims = 1; 41 while (sourceDim < sourceShape.size()) { 42 unsigned targetDim = reassociationMap.size(); 43 // If we have mapped all the target dimensions stop and handle the remaining 44 // tail of size-1 dimensions explictly. 45 if (targetDim == targetShape.size()) 46 break; 47 48 int64_t currTargetShape = targetShape[targetDim]; 49 while (sourceShape[sourceDim] != ShapedType::kDynamicSize && 50 prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape && 51 sourceDim < sourceShape.size()) { 52 prodOfCollapsedDims *= sourceShape[sourceDim]; 53 currIndices.push_back(sourceDim++); 54 } 55 56 // If the current expanded dimension is dynamic, then the collapsed 57 // dimensions should also be dynamic and product of all previous unprocessed 58 // dimensions of the expanded shape should be 1. 59 if (sourceShape[sourceDim] == ShapedType::kDynamicSize && 60 (currTargetShape != ShapedType::kDynamicSize || 61 prodOfCollapsedDims != 1)) 62 return llvm::None; 63 64 // If the collapsed dim is dynamic, the current expanded dim should also 65 // be dynamic. 66 if (currTargetShape == ShapedType::kDynamicSize && 67 sourceShape[sourceDim] != ShapedType::kDynamicSize) 68 return llvm::None; 69 70 // For static shapes, if the product of dimensions of the expanded shape 71 // should match the collapsed dimension shape. 72 if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape) 73 return llvm::None; 74 75 currIndices.push_back(sourceDim++); 76 reassociationMap.emplace_back(ReassociationIndices{}); 77 std::swap(reassociationMap.back(), currIndices); 78 prodOfCollapsedDims = 1; 79 } 80 // All the dimensions in the target must have been processed. 81 if (reassociationMap.size() != targetShape.size()) 82 return llvm::None; 83 // Process any remaining entries in the source shape. They all need to be 84 // 1 or dynamic. 85 for (; sourceDim < sourceShape.size(); sourceDim++) { 86 if (sourceShape[sourceDim] != ShapedType::kDynamicSize && 87 sourceShape[sourceDim] != 1) 88 return llvm::None; 89 // The map is empty when the target type is a scalar. 90 if (!reassociationMap.empty()) 91 reassociationMap.back().push_back(sourceDim); 92 } 93 return reassociationMap; 94 } 95 96 Optional<SmallVector<ReassociationIndices>> mlir::composeReassociationIndices( 97 ArrayRef<ReassociationIndices> producerReassociations, 98 ArrayRef<ReassociationIndices> consumerReassociations, 99 MLIRContext *context) { 100 SmallVector<ReassociationIndices> composedIndices; 101 // Make the producer the larger sized vector. If they are of same size, the 102 // resulting reshape is not a supported reshape op. 103 if (producerReassociations.size() == consumerReassociations.size()) 104 return llvm::None; 105 if (producerReassociations.size() < consumerReassociations.size()) 106 std::swap(producerReassociations, consumerReassociations); 107 108 // Handle the corner case of the result being a rank 0 shaped type. Return an 109 // empty reassociation. 110 if (consumerReassociations.empty()) 111 return composedIndices; 112 113 size_t consumerDims = std::accumulate( 114 consumerReassociations.begin(), consumerReassociations.end(), 0, 115 [](size_t all, ReassociationIndicesRef indices) { 116 return all + indices.size(); 117 }); 118 if (producerReassociations.size() != consumerDims) 119 return llvm::None; 120 121 for (ReassociationIndicesRef consumerIndices : consumerReassociations) { 122 ReassociationIndices reassociations; 123 for (int64_t consumerIndex : consumerIndices) { 124 llvm::append_range(reassociations, producerReassociations[consumerIndex]); 125 } 126 composedIndices.push_back(std::move(reassociations)); 127 } 128 return composedIndices; 129 } 130 131 SmallVector<SmallVector<AffineExpr, 2>, 2> 132 mlir::convertReassociationIndicesToExprs( 133 MLIRContext *context, ArrayRef<ReassociationIndices> reassociationIndices) { 134 SmallVector<SmallVector<AffineExpr, 2>, 2> reassociationMaps; 135 for (const auto &indices : reassociationIndices) { 136 SmallVector<AffineExpr, 2> reassociationMap; 137 reassociationMap.reserve(indices.size()); 138 for (int64_t index : indices) 139 reassociationMap.push_back(mlir::getAffineDimExpr(index, context)); 140 reassociationMaps.push_back(std::move(reassociationMap)); 141 } 142 return reassociationMaps; 143 } 144 145 template <typename AffineExprTy> 146 unsigned getMaxPosOfType(ArrayRef<ReassociationExprs> exprArrays) { 147 unsigned pos = 0; 148 for (const auto &exprs : exprArrays) { 149 for (auto expr : exprs) { 150 expr.walk([&pos](AffineExpr e) { 151 if (auto d = e.dyn_cast<AffineExprTy>()) 152 pos = std::max(pos, d.getPosition()); 153 }); 154 } 155 } 156 return pos; 157 } 158 159 ArrayAttr mlir::getReassociationIndicesAttribute( 160 OpBuilder &b, ArrayRef<ReassociationIndices> reassociation) { 161 SmallVector<Attribute, 4> reassociationAttr = 162 llvm::to_vector<4>(llvm::map_range( 163 reassociation, [&](const ReassociationIndices &indices) -> Attribute { 164 return b.getI64ArrayAttr(indices).cast<Attribute>(); 165 })); 166 return b.getArrayAttr(reassociationAttr); 167 } 168 169 SmallVector<ReassociationIndices, 2> mlir::convertReassociationMapsToIndices( 170 OpBuilder &b, ArrayRef<ReassociationExprs> reassociationExprs) { 171 SmallVector<ReassociationIndices, 2> reassociationIndices; 172 for (const auto &exprs : reassociationExprs) { 173 ReassociationIndices indices; 174 indices.reserve(exprs.size()); 175 for (const auto &expr : exprs) 176 indices.push_back(expr.cast<AffineDimExpr>().getPosition()); 177 reassociationIndices.push_back(indices); 178 } 179 return reassociationIndices; 180 } 181 182 SmallVector<AffineMap, 4> 183 mlir::getSymbolLessAffineMaps(ArrayRef<ReassociationExprs> reassociation) { 184 unsigned maxDim = getMaxPosOfType<AffineDimExpr>(reassociation); 185 assert(getMaxPosOfType<AffineSymbolExpr>(reassociation) == 0 && 186 "Expected symbol-less expressions"); 187 SmallVector<AffineMap, 4> maps; 188 maps.reserve(reassociation.size()); 189 for (const auto &exprs : reassociation) { 190 assert(!exprs.empty()); 191 maps.push_back(AffineMap::get(maxDim + 1, 0, exprs, exprs[0].getContext())); 192 } 193 return maps; 194 } 195 196 bool mlir::isReassociationValid(ArrayRef<AffineMap> reassociation, 197 int *invalidIndex) { 198 if (reassociation.empty()) 199 return true; 200 unsigned nDims = reassociation[0].getNumDims(); 201 unsigned nextExpectedDim = 0; 202 for (const auto &it : llvm::enumerate(reassociation)) { 203 auto m = it.value(); 204 if (m.getNumDims() != nDims || m.getNumSymbols() != 0) { 205 if (invalidIndex) 206 *invalidIndex = it.index(); 207 return false; 208 } 209 for (auto e : m.getResults()) { 210 auto d = e.dyn_cast<AffineDimExpr>(); 211 if (!d || d.getPosition() != nextExpectedDim++) { 212 if (invalidIndex) 213 *invalidIndex = it.index(); 214 return false; 215 } 216 } 217 } 218 if (nextExpectedDim != nDims) { 219 if (invalidIndex) 220 *invalidIndex = reassociation.size() - 1; 221 return false; 222 } 223 return true; 224 } 225 226 LogicalResult mlir::reshapeLikeShapesAreCompatible( 227 function_ref<LogicalResult(const Twine &)> emitError, 228 ArrayRef<int64_t> collapsedShape, ArrayRef<int64_t> expandedShape, 229 ArrayRef<ReassociationIndices> reassociationMaps, bool isExpandingReshape) { 230 unsigned expandedDimStart = 0; 231 for (const auto &map : llvm::enumerate(reassociationMaps)) { 232 Optional<int64_t> dynamicShape; 233 int64_t linearizedStaticShape = 1; 234 for (const auto &dim : llvm::enumerate( 235 expandedShape.slice(expandedDimStart, map.value().size()))) { 236 if (ShapedType::isDynamic(dim.value())) { 237 if (isExpandingReshape && dynamicShape) { 238 return emitError("invalid to have a single dimension (" + 239 Twine(map.index()) + 240 ") expanded into multiple dynamic dims (" + 241 Twine(expandedDimStart + dynamicShape.getValue()) + 242 "," + Twine(expandedDimStart + dim.index()) + ")"); 243 } 244 dynamicShape = dim.index(); 245 } else { 246 linearizedStaticShape *= dim.value(); 247 } 248 } 249 if (dynamicShape) { 250 if (!ShapedType::isDynamic(collapsedShape[map.index()])) { 251 return emitError( 252 "expected dimension " + Twine(map.index()) + 253 " of collapsed type to be dynamic since one or more of the " 254 "corresponding dimensions in the expanded type is dynamic"); 255 } 256 } else { 257 if (collapsedShape[map.index()] != linearizedStaticShape) { 258 return emitError("expected dimension " + Twine(map.index()) + 259 " of collapsed type to be static value of " + 260 Twine(linearizedStaticShape)); 261 } 262 } 263 expandedDimStart += map.value().size(); 264 } 265 return success(); 266 } 267 268 bool mlir::hasNonIdentityLayout(Type type) { 269 if (auto memrefType = type.dyn_cast<MemRefType>()) 270 return !memrefType.getLayout().isIdentity(); 271 return false; 272 } 273