1 //===- ReshapeOpsUtils.cpp - Utilities used by structured ops -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" 10 11 #include "mlir/IR/AffineMap.h" 12 #include "mlir/IR/Builders.h" 13 14 #include <numeric> 15 16 using namespace mlir; 17 18 Optional<SmallVector<ReassociationIndices>> 19 mlir::getReassociationIndicesForReshape(ShapedType sourceType, 20 ShapedType targetType) { 21 // Make the sourceType greater rank than the targetType. If they are same 22 // rank, then its an unsupported reshape op. 23 if (sourceType.getRank() == targetType.getRank()) 24 return llvm::None; 25 if (sourceType.getRank() < targetType.getRank()) 26 std::swap(sourceType, targetType); 27 28 ArrayRef<int64_t> sourceShape = sourceType.getShape(); 29 ArrayRef<int64_t> targetShape = targetType.getShape(); 30 unsigned sourceDim = 0; 31 SmallVector<ReassociationIndices> reassociationMap; 32 reassociationMap.reserve(targetType.getRank()); 33 34 ReassociationIndices currIndices; 35 int64_t prodOfCollapsedDims = 1; 36 while (sourceDim < sourceShape.size()) { 37 unsigned targetDim = reassociationMap.size(); 38 // If we have mapped all the target dimensions stop and handle the remaining 39 // tail of size-1 dimensions explictly. 40 if (targetDim == targetType.getRank()) 41 break; 42 43 int64_t currTargetShape = targetShape[targetDim]; 44 while (sourceShape[sourceDim] != ShapedType::kDynamicSize && 45 prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape && 46 sourceDim < sourceShape.size()) { 47 prodOfCollapsedDims *= sourceShape[sourceDim]; 48 currIndices.push_back(sourceDim++); 49 } 50 51 // If the current expanded dimension is dynamic, then the collapsed 52 // dimensions should also be dynamic and product of all previous unprocessed 53 // dimensions of the expanded shape should be 1. 54 if (sourceShape[sourceDim] == ShapedType::kDynamicSize && 55 (currTargetShape != ShapedType::kDynamicSize || 56 prodOfCollapsedDims != 1)) 57 return llvm::None; 58 59 // If the collapsed dim is dynamic, the current expanded dim should also 60 // be dynamic. 61 if (currTargetShape == ShapedType::kDynamicSize && 62 sourceShape[sourceDim] != ShapedType::kDynamicSize) 63 return llvm::None; 64 65 // For static shapes, if the product of dimensions of the expanded shape 66 // should match the collapsed dimension shape. 67 if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape) 68 return llvm::None; 69 70 currIndices.push_back(sourceDim++); 71 reassociationMap.emplace_back(ReassociationIndices{}); 72 std::swap(reassociationMap.back(), currIndices); 73 prodOfCollapsedDims = 1; 74 } 75 // All the dimensions in the target must have been processed. 76 if (reassociationMap.size() != targetShape.size()) 77 return llvm::None; 78 // Process any remaining entries in the source shape. They all need to be 79 // 1 or dynamic. 80 for (; sourceDim < sourceShape.size(); sourceDim++) { 81 if (sourceShape[sourceDim] != ShapedType::kDynamicSize && 82 sourceShape[sourceDim] != 1) 83 return llvm::None; 84 // The map is empty when the target type is a scalar. 85 if (!reassociationMap.empty()) 86 reassociationMap.back().push_back(sourceDim); 87 } 88 return reassociationMap; 89 } 90 91 Optional<SmallVector<ReassociationIndices>> mlir::composeReassociationIndices( 92 ArrayRef<ReassociationIndices> producerReassociations, 93 ArrayRef<ReassociationIndices> consumerReassociations, 94 MLIRContext *context) { 95 SmallVector<ReassociationIndices> composedIndices; 96 // Make the producer the larger sized vector. If they are of same size, the 97 // resulting reshape is not a supported reshape op. 98 if (producerReassociations.size() == consumerReassociations.size()) 99 return llvm::None; 100 if (producerReassociations.size() < consumerReassociations.size()) 101 std::swap(producerReassociations, consumerReassociations); 102 103 // Handle the corner case of the result being a rank 0 shaped type. Return an 104 // empty reassociation. 105 if (consumerReassociations.empty()) 106 return composedIndices; 107 108 size_t consumerDims = std::accumulate( 109 consumerReassociations.begin(), consumerReassociations.end(), 0, 110 [](size_t all, ReassociationIndicesRef indices) { 111 return all + indices.size(); 112 }); 113 if (producerReassociations.size() != consumerDims) 114 return llvm::None; 115 116 for (ReassociationIndicesRef consumerIndices : consumerReassociations) { 117 ReassociationIndices reassociations; 118 for (int64_t consumerIndex : consumerIndices) { 119 for (int64_t producerIndex : producerReassociations[consumerIndex]) 120 reassociations.push_back(producerIndex); 121 } 122 composedIndices.push_back(std::move(reassociations)); 123 } 124 return composedIndices; 125 } 126 127 SmallVector<SmallVector<AffineExpr, 2>, 2> 128 mlir::convertReassociationIndicesToExprs( 129 MLIRContext *context, ArrayRef<ReassociationIndices> reassociationIndices) { 130 SmallVector<SmallVector<AffineExpr, 2>, 2> reassociationMaps; 131 for (const auto &indices : reassociationIndices) { 132 SmallVector<AffineExpr, 2> reassociationMap; 133 reassociationMap.reserve(indices.size()); 134 for (int64_t index : indices) 135 reassociationMap.push_back(mlir::getAffineDimExpr(index, context)); 136 reassociationMaps.push_back(std::move(reassociationMap)); 137 } 138 return reassociationMaps; 139 } 140 141 template <typename AffineExprTy> 142 unsigned getMaxPosOfType(ArrayRef<ReassociationExprs> exprArrays) { 143 unsigned pos = 0; 144 for (const auto &exprs : exprArrays) { 145 for (auto expr : exprs) { 146 expr.walk([&pos](AffineExpr e) { 147 if (auto d = e.dyn_cast<AffineExprTy>()) 148 pos = std::max(pos, d.getPosition()); 149 }); 150 } 151 } 152 return pos; 153 } 154 155 ArrayAttr mlir::getReassociationIndicesAttribute( 156 OpBuilder &b, ArrayRef<ReassociationIndices> reassociation) { 157 SmallVector<Attribute, 4> reassociationAttr = 158 llvm::to_vector<4>(llvm::map_range( 159 reassociation, [&](const ReassociationIndices &indices) -> Attribute { 160 return b.getI64ArrayAttr(indices).cast<Attribute>(); 161 })); 162 return b.getArrayAttr(reassociationAttr); 163 } 164 165 SmallVector<ReassociationIndices, 2> mlir::convertReassociationMapsToIndices( 166 OpBuilder &b, ArrayRef<ReassociationExprs> reassociationExprs) { 167 SmallVector<ReassociationIndices, 2> reassociationIndices; 168 for (const auto &exprs : reassociationExprs) { 169 ReassociationIndices indices; 170 indices.reserve(exprs.size()); 171 for (const auto &expr : exprs) 172 indices.push_back(expr.cast<AffineDimExpr>().getPosition()); 173 reassociationIndices.push_back(indices); 174 } 175 return reassociationIndices; 176 } 177 178 SmallVector<AffineMap, 4> 179 mlir::getSymbolLessAffineMaps(ArrayRef<ReassociationExprs> reassociation) { 180 unsigned maxDim = getMaxPosOfType<AffineDimExpr>(reassociation); 181 assert(getMaxPosOfType<AffineSymbolExpr>(reassociation) == 0 && 182 "Expected symbol-less expressions"); 183 SmallVector<AffineMap, 4> maps; 184 maps.reserve(reassociation.size()); 185 for (const auto &exprs : reassociation) { 186 assert(!exprs.empty()); 187 maps.push_back(AffineMap::get(maxDim + 1, 0, exprs, exprs[0].getContext())); 188 } 189 return maps; 190 } 191 bool mlir::isReassociationValid(ArrayRef<AffineMap> reassociation, 192 int *invalidIndex) { 193 if (reassociation.empty()) 194 return true; 195 unsigned nDims = reassociation[0].getNumDims(); 196 unsigned nextExpectedDim = 0; 197 for (const auto &it : llvm::enumerate(reassociation)) { 198 auto m = it.value(); 199 if (m.getNumDims() != nDims || m.getNumSymbols() != 0) { 200 if (invalidIndex) 201 *invalidIndex = it.index(); 202 return false; 203 } 204 for (auto e : m.getResults()) { 205 auto d = e.dyn_cast<AffineDimExpr>(); 206 if (!d || d.getPosition() != nextExpectedDim++) { 207 if (invalidIndex) 208 *invalidIndex = it.index(); 209 return false; 210 } 211 } 212 } 213 if (nextExpectedDim != nDims) { 214 if (invalidIndex) 215 *invalidIndex = reassociation.size() - 1; 216 return false; 217 } 218 return true; 219 } 220 221 LogicalResult mlir::reshapeLikeShapesAreCompatible( 222 function_ref<LogicalResult(const Twine &)> emitError, 223 ArrayRef<int64_t> collapsedShape, ArrayRef<int64_t> expandedShape, 224 ArrayRef<ReassociationIndices> reassociationMaps, bool isExpandingReshape) { 225 unsigned expandedDimStart = 0; 226 for (const auto &map : llvm::enumerate(reassociationMaps)) { 227 Optional<int64_t> dynamicShape; 228 int64_t linearizedStaticShape = 1; 229 for (const auto &dim : llvm::enumerate( 230 expandedShape.slice(expandedDimStart, map.value().size()))) { 231 if (ShapedType::isDynamic(dim.value())) { 232 if (isExpandingReshape && dynamicShape) { 233 return emitError("invalid to have a single dimension (" + 234 Twine(map.index()) + 235 ") expanded into multiple dynamic dims (" + 236 Twine(expandedDimStart + dynamicShape.getValue()) + 237 "," + Twine(expandedDimStart + dim.index()) + ")"); 238 } 239 dynamicShape = dim.index(); 240 } else { 241 linearizedStaticShape *= dim.value(); 242 } 243 } 244 if (dynamicShape) { 245 if (!ShapedType::isDynamic(collapsedShape[map.index()])) { 246 return emitError( 247 "expected dimension " + Twine(map.index()) + 248 " of collapsed type to be dynamic since one or more of the " 249 "corresponding dimensions in the expanded type is dynamic"); 250 } 251 } else { 252 if (collapsedShape[map.index()] != linearizedStaticShape) { 253 return emitError("expected dimension " + Twine(map.index()) + 254 " of collapsed type to be static value of " + 255 Twine(linearizedStaticShape)); 256 } 257 } 258 expandedDimStart += map.value().size(); 259 } 260 return success(); 261 } 262