1 //===- ReshapeOpsUtils.cpp - Utilities used by structured ops -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" 10 11 #include "mlir/IR/AffineMap.h" 12 #include "mlir/IR/Builders.h" 13 14 #include <numeric> 15 16 using namespace mlir; 17 18 Optional<SmallVector<ReassociationIndices>> 19 mlir::getReassociationIndicesForReshape(ShapedType sourceType, 20 ShapedType targetType) { 21 // Make the sourceType greater rank than the targetType. If they are same 22 // rank, then its an unsupported reshape op. 23 if (sourceType.getRank() == targetType.getRank()) 24 return llvm::None; 25 if (sourceType.getRank() < targetType.getRank()) 26 std::swap(sourceType, targetType); 27 28 ArrayRef<int64_t> sourceShape = sourceType.getShape(); 29 ArrayRef<int64_t> targetShape = targetType.getShape(); 30 unsigned sourceDim = 0; 31 SmallVector<ReassociationIndices> reassociationMap; 32 reassociationMap.reserve(targetType.getRank()); 33 34 ReassociationIndices currIndices; 35 int64_t prodOfCollapsedDims = 1; 36 while (sourceDim < sourceShape.size()) { 37 unsigned targetDim = reassociationMap.size(); 38 39 // If all the dimensions of the targetShape are exhausted, then the 40 // remaining dims in the source shape must be all 1s. So for such cases, set 41 // 1 as the target shape. The actual reassociation indices will be handled 42 // later. 43 int64_t currTargetShape = 44 (targetDim < targetType.getRank() ? targetShape[targetDim] : 1); 45 while (sourceShape[sourceDim] != ShapedType::kDynamicSize && 46 prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape && 47 sourceDim < sourceShape.size()) { 48 prodOfCollapsedDims *= sourceShape[sourceDim]; 49 currIndices.push_back(sourceDim++); 50 } 51 52 // If the current expanded dimension is dynamic, then the collapsed 53 // dimensions should also be dynamic and product of all previous unprocessed 54 // dimensions of the expanded shape should be 1. 55 if (sourceShape[sourceDim] == ShapedType::kDynamicSize && 56 (currTargetShape != ShapedType::kDynamicSize || 57 prodOfCollapsedDims != 1)) 58 return llvm::None; 59 60 // If the collapsed dim is dynamic, the current expanded dim should also 61 // be dynamic. 62 if (currTargetShape == ShapedType::kDynamicSize && 63 sourceShape[sourceDim] != ShapedType::kDynamicSize) 64 return llvm::None; 65 66 // For static shapes, if the product of dimensions of the expanded shape 67 // should match the collapsed dimension shape. 68 if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape) 69 return llvm::None; 70 71 currIndices.push_back(sourceDim++); 72 // If the reassociation is empty but the currIndices is not, this by 73 // definition is folding unit-dimensions with the result being scalar type. 74 // So only append the `currIndices` if reassociation map is not empty. 75 if (targetDim == targetShape.size()) { 76 if (!reassociationMap.empty() && !currIndices.empty()) 77 reassociationMap.back().append(currIndices.begin(), currIndices.end()); 78 // Break out of the loops. We should be done here. 79 break; 80 } 81 reassociationMap.emplace_back(ReassociationIndices{}); 82 std::swap(reassociationMap.back(), currIndices); 83 prodOfCollapsedDims = 1; 84 } 85 // All the dimensions in the two shapes must have been processed. 86 if (reassociationMap.size() != targetShape.size() || 87 sourceDim != sourceShape.size()) 88 return llvm::None; 89 return reassociationMap; 90 } 91 92 ParseResult mlir::parseReshapeLikeOp(OpAsmParser &parser, 93 OperationState &result) { 94 // Parse the operand. 95 OpAsmParser::OperandType src; 96 if (parser.parseOperand(src)) 97 return failure(); 98 99 // Parse reassociation indices. 100 Builder &b = parser.getBuilder(); 101 SmallVector<Attribute, 4> reassociation; 102 if (parser.parseLSquare()) 103 return failure(); 104 105 while (true) { 106 if (succeeded(parser.parseOptionalRSquare())) 107 break; 108 if (parser.parseLSquare()) 109 return failure(); 110 SmallVector<int64_t> indices; 111 while (true) { 112 int64_t index; 113 if (parser.parseInteger(index)) 114 return failure(); 115 indices.push_back(index); 116 117 if (succeeded(parser.parseOptionalComma())) 118 continue; 119 if (failed(parser.parseRSquare())) 120 return failure(); 121 break; 122 } 123 reassociation.push_back(b.getI64ArrayAttr(indices)); 124 if (succeeded(parser.parseOptionalComma())) 125 continue; 126 if (failed(parser.parseRSquare())) 127 return failure(); 128 break; 129 } 130 131 result.addAttribute(getReassociationAttrName(), 132 b.getArrayAttr(reassociation)); 133 134 // Parse optional attributes. 135 parser.parseOptionalAttrDict(result.attributes); 136 137 // Parse types. 138 Type srcType; 139 Type resultType; 140 if (parser.parseColon() || parser.parseType(srcType) || 141 parser.resolveOperand(src, srcType, result.operands) || 142 parser.parseKeyword("into") || parser.parseType(resultType)) 143 return failure(); 144 result.addTypes(resultType); 145 return success(); 146 } 147 148 Optional<SmallVector<ReassociationIndices>> mlir::composeReassociationIndices( 149 ArrayRef<ReassociationIndices> producerReassociations, 150 ArrayRef<ReassociationIndices> consumerReassociations, 151 MLIRContext *context) { 152 SmallVector<ReassociationIndices> composedIndices; 153 // Make the producer the larger sized vector. If they are of same size, the 154 // resulting reshape is not a supported reshape op. 155 if (producerReassociations.size() == consumerReassociations.size()) 156 return llvm::None; 157 if (producerReassociations.size() < consumerReassociations.size()) 158 std::swap(producerReassociations, consumerReassociations); 159 160 // Handle the corner case of the result being a rank 0 shaped type. Return an 161 // empty reassociation. 162 if (consumerReassociations.empty()) 163 return composedIndices; 164 165 size_t consumerDims = std::accumulate( 166 consumerReassociations.begin(), consumerReassociations.end(), 0, 167 [](size_t all, ReassociationIndicesRef indices) { 168 return all + indices.size(); 169 }); 170 if (producerReassociations.size() != consumerDims) 171 return llvm::None; 172 173 for (ReassociationIndicesRef consumerIndices : consumerReassociations) { 174 ReassociationIndices reassociations; 175 for (int64_t consumerIndex : consumerIndices) { 176 for (int64_t producerIndex : producerReassociations[consumerIndex]) 177 reassociations.push_back(producerIndex); 178 } 179 composedIndices.push_back(std::move(reassociations)); 180 } 181 return composedIndices; 182 } 183 184 SmallVector<SmallVector<AffineExpr, 2>, 2> 185 mlir::convertReassociationIndicesToExprs( 186 MLIRContext *context, ArrayRef<ReassociationIndices> reassociationIndices) { 187 SmallVector<SmallVector<AffineExpr, 2>, 2> reassociationMaps; 188 for (const auto &indices : reassociationIndices) { 189 SmallVector<AffineExpr, 2> reassociationMap; 190 reassociationMap.reserve(indices.size()); 191 for (int64_t index : indices) 192 reassociationMap.push_back(mlir::getAffineDimExpr(index, context)); 193 reassociationMaps.push_back(std::move(reassociationMap)); 194 } 195 return reassociationMaps; 196 } 197 198 template <typename AffineExprTy> 199 unsigned getMaxPosOfType(ArrayRef<ReassociationExprs> exprArrays) { 200 unsigned pos = 0; 201 for (const auto &exprs : exprArrays) { 202 for (auto expr : exprs) { 203 expr.walk([&pos](AffineExpr e) { 204 if (auto d = e.dyn_cast<AffineExprTy>()) 205 pos = std::max(pos, d.getPosition()); 206 }); 207 } 208 } 209 return pos; 210 } 211 212 ArrayAttr mlir::getReassociationIndicesAttribute( 213 OpBuilder &b, ArrayRef<ReassociationIndices> reassociation) { 214 SmallVector<Attribute, 4> reassociationAttr = 215 llvm::to_vector<4>(llvm::map_range( 216 reassociation, [&](ReassociationIndices indices) -> Attribute { 217 return b.getI64ArrayAttr(indices).cast<Attribute>(); 218 })); 219 return b.getArrayAttr(reassociationAttr); 220 } 221 222 SmallVector<ReassociationIndices, 2> mlir::convertReassociationMapsToIndices( 223 OpBuilder &b, ArrayRef<ReassociationExprs> reassociationExprs) { 224 SmallVector<ReassociationIndices, 2> reassociationIndices; 225 for (const auto &exprs : reassociationExprs) { 226 ReassociationIndices indices; 227 indices.reserve(exprs.size()); 228 for (const auto &expr : exprs) 229 indices.push_back(expr.cast<AffineDimExpr>().getPosition()); 230 reassociationIndices.push_back(indices); 231 } 232 return reassociationIndices; 233 } 234 235 SmallVector<AffineMap, 4> 236 mlir::getSymbolLessAffineMaps(ArrayRef<ReassociationExprs> reassociation) { 237 unsigned maxDim = getMaxPosOfType<AffineDimExpr>(reassociation); 238 assert(getMaxPosOfType<AffineSymbolExpr>(reassociation) == 0 && 239 "Expected symbol-less expressions"); 240 SmallVector<AffineMap, 4> maps; 241 maps.reserve(reassociation.size()); 242 for (const auto &exprs : reassociation) { 243 assert(!exprs.empty()); 244 maps.push_back(AffineMap::get(maxDim + 1, 0, exprs, exprs[0].getContext())); 245 } 246 return maps; 247 } 248 bool mlir::isReassociationValid(ArrayRef<AffineMap> reassociation, 249 int *invalidIndex) { 250 if (reassociation.empty()) 251 return true; 252 unsigned nDims = reassociation[0].getNumDims(); 253 unsigned nextExpectedDim = 0; 254 for (auto it : llvm::enumerate(reassociation)) { 255 auto m = it.value(); 256 if (m.getNumDims() != nDims || m.getNumSymbols() != 0) { 257 if (invalidIndex) 258 *invalidIndex = it.index(); 259 return false; 260 } 261 for (auto e : m.getResults()) { 262 auto d = e.dyn_cast<AffineDimExpr>(); 263 if (!d || d.getPosition() != nextExpectedDim++) { 264 if (invalidIndex) 265 *invalidIndex = it.index(); 266 return false; 267 } 268 } 269 } 270 if (nextExpectedDim != nDims) { 271 if (invalidIndex) 272 *invalidIndex = reassociation.size() - 1; 273 return false; 274 } 275 return true; 276 } 277