1 //===- ReshapeOpsUtils.cpp - Utilities used by structured ops -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" 10 11 #include "mlir/IR/AffineMap.h" 12 #include "mlir/IR/Builders.h" 13 14 #include <numeric> 15 16 using namespace mlir; 17 18 Optional<SmallVector<ReassociationIndices>> 19 mlir::getReassociationIndicesForReshape(ShapedType sourceType, 20 ShapedType targetType) { 21 // Make the sourceType greater rank than the targetType. If they are same 22 // rank, then its an unsupported reshape op. 23 if (sourceType.getRank() == targetType.getRank()) 24 return llvm::None; 25 if (sourceType.getRank() < targetType.getRank()) 26 std::swap(sourceType, targetType); 27 28 ArrayRef<int64_t> sourceShape = sourceType.getShape(); 29 ArrayRef<int64_t> targetShape = targetType.getShape(); 30 unsigned sourceDim = 0; 31 SmallVector<ReassociationIndices> reassociationMap; 32 reassociationMap.reserve(targetType.getRank()); 33 34 ReassociationIndices currIndices; 35 int64_t prodOfCollapsedDims = 1; 36 while (sourceDim < sourceShape.size()) { 37 unsigned targetDim = reassociationMap.size(); 38 39 // If all the dimensions of the targetShape are exhausted, then the 40 // remaining dims in the source shape must be all 1s. So for such cases, set 41 // 1 as the target shape. The actual reassociation indices will be handled 42 // later. 43 int64_t currTargetShape = 44 (targetDim < targetType.getRank() ? targetShape[targetDim] : 1); 45 while (sourceShape[sourceDim] != ShapedType::kDynamicSize && 46 prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape && 47 sourceDim < sourceShape.size()) { 48 prodOfCollapsedDims *= sourceShape[sourceDim]; 49 currIndices.push_back(sourceDim++); 50 } 51 52 // If the current expanded dimension is dynamic, then the collapsed 53 // dimensions should also be dynamic and product of all previous unprocessed 54 // dimensions of the expanded shape should be 1. 55 if (sourceShape[sourceDim] == ShapedType::kDynamicSize && 56 (currTargetShape != ShapedType::kDynamicSize || 57 prodOfCollapsedDims != 1)) 58 return llvm::None; 59 60 // If the collapsed dim is dynamic, the current expanded dim should also 61 // be dynamic. 62 if (currTargetShape == ShapedType::kDynamicSize && 63 sourceShape[sourceDim] != ShapedType::kDynamicSize) 64 return llvm::None; 65 66 // For static shapes, if the product of dimensions of the expanded shape 67 // should match the collapsed dimension shape. 68 if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape) 69 return llvm::None; 70 71 currIndices.push_back(sourceDim++); 72 // If the reassociation is empty but the currIndices is not, this by 73 // definition is folding unit-dimensions with the result being scalar type. 74 // So only append the `currIndices` if reassociation map is not empty. 75 if (targetDim == targetShape.size()) { 76 while (sourceDim < sourceShape.size()) 77 currIndices.push_back(sourceDim++); 78 if (!reassociationMap.empty() && !currIndices.empty()) 79 reassociationMap.back().append(currIndices.begin(), currIndices.end()); 80 // Break out of the loops. We should be done here. 81 break; 82 } 83 reassociationMap.emplace_back(ReassociationIndices{}); 84 std::swap(reassociationMap.back(), currIndices); 85 prodOfCollapsedDims = 1; 86 } 87 // All the dimensions in the two shapes must have been processed. 88 if (reassociationMap.size() != targetShape.size() || 89 sourceDim != sourceShape.size()) 90 return llvm::None; 91 return reassociationMap; 92 } 93 94 ParseResult mlir::parseReshapeLikeOp(OpAsmParser &parser, 95 OperationState &result) { 96 // Parse the operand. 97 OpAsmParser::OperandType src; 98 if (parser.parseOperand(src)) 99 return failure(); 100 101 // Parse reassociation indices. 102 Builder &b = parser.getBuilder(); 103 SmallVector<Attribute, 4> reassociation; 104 if (parser.parseLSquare()) 105 return failure(); 106 107 while (true) { 108 if (succeeded(parser.parseOptionalRSquare())) 109 break; 110 if (parser.parseLSquare()) 111 return failure(); 112 SmallVector<int64_t> indices; 113 while (true) { 114 int64_t index; 115 if (parser.parseInteger(index)) 116 return failure(); 117 indices.push_back(index); 118 119 if (succeeded(parser.parseOptionalComma())) 120 continue; 121 if (failed(parser.parseRSquare())) 122 return failure(); 123 break; 124 } 125 reassociation.push_back(b.getI64ArrayAttr(indices)); 126 if (succeeded(parser.parseOptionalComma())) 127 continue; 128 if (failed(parser.parseRSquare())) 129 return failure(); 130 break; 131 } 132 133 result.addAttribute(getReassociationAttrName(), 134 b.getArrayAttr(reassociation)); 135 136 // Parse optional attributes. 137 parser.parseOptionalAttrDict(result.attributes); 138 139 // Parse types. 140 Type srcType; 141 Type resultType; 142 if (parser.parseColon() || parser.parseType(srcType) || 143 parser.resolveOperand(src, srcType, result.operands) || 144 parser.parseKeyword("into") || parser.parseType(resultType)) 145 return failure(); 146 result.addTypes(resultType); 147 return success(); 148 } 149 150 Optional<SmallVector<ReassociationIndices>> mlir::composeReassociationIndices( 151 ArrayRef<ReassociationIndices> producerReassociations, 152 ArrayRef<ReassociationIndices> consumerReassociations, 153 MLIRContext *context) { 154 SmallVector<ReassociationIndices> composedIndices; 155 // Make the producer the larger sized vector. If they are of same size, the 156 // resulting reshape is not a supported reshape op. 157 if (producerReassociations.size() == consumerReassociations.size()) 158 return llvm::None; 159 if (producerReassociations.size() < consumerReassociations.size()) 160 std::swap(producerReassociations, consumerReassociations); 161 162 // Handle the corner case of the result being a rank 0 shaped type. Return an 163 // empty reassociation. 164 if (consumerReassociations.empty()) 165 return composedIndices; 166 167 size_t consumerDims = std::accumulate( 168 consumerReassociations.begin(), consumerReassociations.end(), 0, 169 [](size_t all, ReassociationIndicesRef indices) { 170 return all + indices.size(); 171 }); 172 if (producerReassociations.size() != consumerDims) 173 return llvm::None; 174 175 for (ReassociationIndicesRef consumerIndices : consumerReassociations) { 176 ReassociationIndices reassociations; 177 for (int64_t consumerIndex : consumerIndices) { 178 for (int64_t producerIndex : producerReassociations[consumerIndex]) 179 reassociations.push_back(producerIndex); 180 } 181 composedIndices.push_back(std::move(reassociations)); 182 } 183 return composedIndices; 184 } 185 186 SmallVector<SmallVector<AffineExpr, 2>, 2> 187 mlir::convertReassociationIndicesToExprs( 188 MLIRContext *context, ArrayRef<ReassociationIndices> reassociationIndices) { 189 SmallVector<SmallVector<AffineExpr, 2>, 2> reassociationMaps; 190 for (const auto &indices : reassociationIndices) { 191 SmallVector<AffineExpr, 2> reassociationMap; 192 reassociationMap.reserve(indices.size()); 193 for (int64_t index : indices) 194 reassociationMap.push_back(mlir::getAffineDimExpr(index, context)); 195 reassociationMaps.push_back(std::move(reassociationMap)); 196 } 197 return reassociationMaps; 198 } 199 200 template <typename AffineExprTy> 201 unsigned getMaxPosOfType(ArrayRef<ReassociationExprs> exprArrays) { 202 unsigned pos = 0; 203 for (const auto &exprs : exprArrays) { 204 for (auto expr : exprs) { 205 expr.walk([&pos](AffineExpr e) { 206 if (auto d = e.dyn_cast<AffineExprTy>()) 207 pos = std::max(pos, d.getPosition()); 208 }); 209 } 210 } 211 return pos; 212 } 213 214 ArrayAttr mlir::getReassociationIndicesAttribute( 215 OpBuilder &b, ArrayRef<ReassociationIndices> reassociation) { 216 SmallVector<Attribute, 4> reassociationAttr = 217 llvm::to_vector<4>(llvm::map_range( 218 reassociation, [&](ReassociationIndices indices) -> Attribute { 219 return b.getI64ArrayAttr(indices).cast<Attribute>(); 220 })); 221 return b.getArrayAttr(reassociationAttr); 222 } 223 224 SmallVector<ReassociationIndices, 2> mlir::convertReassociationMapsToIndices( 225 OpBuilder &b, ArrayRef<ReassociationExprs> reassociationExprs) { 226 SmallVector<ReassociationIndices, 2> reassociationIndices; 227 for (const auto &exprs : reassociationExprs) { 228 ReassociationIndices indices; 229 indices.reserve(exprs.size()); 230 for (const auto &expr : exprs) 231 indices.push_back(expr.cast<AffineDimExpr>().getPosition()); 232 reassociationIndices.push_back(indices); 233 } 234 return reassociationIndices; 235 } 236 237 SmallVector<AffineMap, 4> 238 mlir::getSymbolLessAffineMaps(ArrayRef<ReassociationExprs> reassociation) { 239 unsigned maxDim = getMaxPosOfType<AffineDimExpr>(reassociation); 240 assert(getMaxPosOfType<AffineSymbolExpr>(reassociation) == 0 && 241 "Expected symbol-less expressions"); 242 SmallVector<AffineMap, 4> maps; 243 maps.reserve(reassociation.size()); 244 for (const auto &exprs : reassociation) { 245 assert(!exprs.empty()); 246 maps.push_back(AffineMap::get(maxDim + 1, 0, exprs, exprs[0].getContext())); 247 } 248 return maps; 249 } 250 bool mlir::isReassociationValid(ArrayRef<AffineMap> reassociation, 251 int *invalidIndex) { 252 if (reassociation.empty()) 253 return true; 254 unsigned nDims = reassociation[0].getNumDims(); 255 unsigned nextExpectedDim = 0; 256 for (auto it : llvm::enumerate(reassociation)) { 257 auto m = it.value(); 258 if (m.getNumDims() != nDims || m.getNumSymbols() != 0) { 259 if (invalidIndex) 260 *invalidIndex = it.index(); 261 return false; 262 } 263 for (auto e : m.getResults()) { 264 auto d = e.dyn_cast<AffineDimExpr>(); 265 if (!d || d.getPosition() != nextExpectedDim++) { 266 if (invalidIndex) 267 *invalidIndex = it.index(); 268 return false; 269 } 270 } 271 } 272 if (nextExpectedDim != nDims) { 273 if (invalidIndex) 274 *invalidIndex = reassociation.size() - 1; 275 return false; 276 } 277 return true; 278 } 279