//===- DropUnitDims.cpp - Pass to drop use of unit-extent for broadcasting ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns/pass to remove usage of unit-extent dimensions
// to specify broadcasting in favor of a more canonical representation of the
// computation.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "linalg-drop-unit-dims"

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;

/// Implements a pass that canonicalizes the uses of unit-extent dimensions for
/// broadcasting. For example,
///
/// ```mlir
/// #accesses = [
///   affine_map<(d0, d1) -> (0, d1)>,
///   affine_map<(d0, d1) -> (d0, 0)>,
///   affine_map<(d0, d1) -> (d0, d1)>
/// ]
///
/// #trait = {
///   args_in = 2,
///   args_out = 1,
///   indexing_maps = #accesses,
///   iterator_types = ["parallel", "parallel"],
///   library_call = "some_external_fn"
/// }
///
/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
///   tensor<5x5xf32>
/// {
///   %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>] :
///        tensor<5xf32> into tensor<1x5xf32>
///   %1 = linalg.tensor_reshape %arg1 [affine_map<(d0, d1) -> (d0, d1)>] :
///        tensor<5xf32> into tensor<5x1xf32>
///   %2 = linalg.generic #trait %0, %1 {
///        ^bb0(%arg2: f32, %arg3: f32):
///          %3 = addf %arg2, %arg3 : f32
///          linalg.yield %3 : f32
///        } : tensor<1x5xf32>, tensor<5x1xf32> -> tensor<5x5xf32>
///   return %2 : tensor<5x5xf32>
/// }
/// ```
///
/// would canonicalize to
///
/// ```mlir
/// #accesses = [
///   affine_map<(d0, d1) -> (d1)>,
///   affine_map<(d0, d1) -> (d0)>,
///   affine_map<(d0, d1) -> (d0, d1)>
/// ]
///
/// #trait = {
///   args_in = 2,
///   args_out = 1,
///   indexing_maps = #accesses,
///   iterator_types = ["parallel", "parallel"],
///   library_call = "some_external_fn"
/// }
///
/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
///   tensor<5x5xf32>
/// {
///   %0 = linalg.generic #trait %arg0, %arg1 {
///        ^bb0(%arg2: f32, %arg3: f32):
///          %3 = addf %arg2, %arg3 : f32
///          linalg.yield %3 : f32
///        } : tensor<5xf32>, tensor<5xf32> -> tensor<5x5xf32>
///   return %0 : tensor<5x5xf32>
/// }
/// ```

/// Given dims of the iteration space of a structured op that are known to be
/// single trip count (`unitDims`), return the indexing maps to use in the
/// canonicalized op with these dims removed, given the original `indexingMaps`.
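///
/// As an illustration (the maps and unit dims here are hypothetical, chosen
/// only to show the rewrite): with `unitDims = {1}` and indexing maps
///
///   affine_map<(d0, d1, d2) -> (d0, d1, d2)>
///   affine_map<(d0, d1, d2) -> (d2, d1, d0)>
///
/// loop `d1` is dropped, its uses are replaced by the constant 0, and the
/// remaining loops are renumbered, yielding
///
///   affine_map<(d0, d1) -> (d0, 0, d1)>
///   affine_map<(d0, d1) -> (d1, 0, d0)>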
static ArrayAttr replaceUnitDims(DenseSet<unsigned> &unitDims,
                                 ArrayRef<AffineMap> indexingMaps,
                                 MLIRContext *context) {
  if (indexingMaps.empty())
    return nullptr;
  unsigned numIterationDims = indexingMaps.front().getNumDims();
  unsigned numSymbols = indexingMaps.front().getNumSymbols();

  // Compute the replacement for each dim expr.
  SmallVector<AffineExpr, 4> dimReplacements;
  dimReplacements.reserve(numIterationDims);
  unsigned numKeptDims = 0;
  for (unsigned dim : llvm::seq<unsigned>(0, numIterationDims)) {
    if (unitDims.count(dim))
      dimReplacements.push_back(getAffineConstantExpr(0, context));
    else
      dimReplacements.push_back(getAffineDimExpr(numKeptDims++, context));
  }

  // Symbols remain the same.
  SmallVector<AffineExpr, 4> symReplacements;
  symReplacements.reserve(numSymbols);
  for (unsigned symbol : llvm::seq<unsigned>(0, numSymbols))
    symReplacements.push_back(getAffineSymbolExpr(symbol, context));

  SmallVector<AffineMap, 4> newIndexingMaps;
  newIndexingMaps.reserve(indexingMaps.size());
  for (AffineMap operandMap : indexingMaps) {
    // Expected indexing maps to have no symbols.
    if (operandMap.getNumSymbols())
      return nullptr;
    newIndexingMaps.push_back(simplifyAffineMap(
        operandMap.replaceDimsAndSymbols(dimReplacements, symReplacements,
                                         numIterationDims - unitDims.size(),
                                         numSymbols)));
  }

  // Check that the new index maps are invertible. If not, something went
  // wrong, so abort.
  if (!inversePermutation(concatAffineMaps(newIndexingMaps)))
    return nullptr;
  return ArrayAttr::get(context,
                        llvm::to_vector<4>(llvm::map_range(
                            newIndexingMaps, [](AffineMap map) -> Attribute {
                              return AffineMapAttr::get(map);
                            })));
}

/// Modify the region of an indexed generic op to drop arguments corresponding
/// to loops that are unit trip count.
template <typename OpTy>
static LogicalResult
replaceBlockArgForUnitDimLoops(OpTy op, const DenseSet<unsigned> &unitDims,
                               PatternRewriter &rewriter) {
  return success();
}

template <>
LogicalResult replaceBlockArgForUnitDimLoops<IndexedGenericOp>(
    IndexedGenericOp op, const DenseSet<unsigned> &unitDims,
    PatternRewriter &rewriter) {
  OpBuilder::InsertionGuard guard(rewriter);
  Block *entryBlock = &op->getRegion(0).front();
  rewriter.setInsertionPointToStart(entryBlock);
  Value zero = rewriter.create<ConstantIndexOp>(op.getLoc(), 0);
  for (unsigned unitDimLoop : unitDims) {
    entryBlock->getArgument(unitDimLoop).replaceAllUsesWith(zero);
  }
  SmallVector<unsigned, 8> unitDimsToErase(unitDims.begin(), unitDims.end());
  entryBlock->eraseArguments(unitDimsToErase);
  return success();
}

namespace {
/// Pattern to fold unit-trip count loops in GenericOps.
template <typename GenericOpTy>
struct FoldUnitDimLoops : public OpRewritePattern<GenericOpTy> {
  using OpRewritePattern<GenericOpTy>::OpRewritePattern;
  LogicalResult matchAndRewrite(GenericOpTy op,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap, 4> indexingMaps = op.getIndexingMaps();
    if (indexingMaps.empty())
      return failure();

    // Check if any of the iteration dimensions are unit-trip count. They will
    // end up being unit-trip count if they are used to index into a unit-dim
    // tensor/memref.
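    // For instance (hypothetical shapes/maps, chosen only to illustrate the
    // check): with operands tensor<1x5xf32> and tensor<5xf32> indexed through
    // affine_map<(d0, d1) -> (d0, d1)> and affine_map<(d0, d1) -> (d1)>, the
    // inverted map ties loop d0 to the unit dimension of the first operand,
    // making d0 a candidate for removal (provided it is a parallel loop).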
    AffineMap invertedMap = inversePermutation(concatAffineMaps(indexingMaps));
    if (!invertedMap)
      return failure();
    SmallVector<int64_t, 4> dims;
    for (ShapedType shapedType : op.getShapedOperandTypes())
      dims.append(shapedType.getShape().begin(), shapedType.getShape().end());
    DenseSet<unsigned> unitDims;
    ArrayAttr iteratorTypes = op.iterator_types();
    for (auto expr : enumerate(invertedMap.getResults())) {
      if (AffineDimExpr dimExpr = expr.value().dyn_cast<AffineDimExpr>())
        if (dims[dimExpr.getPosition()] == 1 &&
            iteratorTypes[expr.index()].dyn_cast<StringAttr>().getValue() ==
                getParallelIteratorTypeName())
          unitDims.insert(expr.index());
    }
    if (unitDims.empty())
      return failure();

    // Compute the modified indexing maps.
    MLIRContext *context = rewriter.getContext();
    ArrayAttr newIndexingMapAttr =
        replaceUnitDims(unitDims, indexingMaps, context);
    if (!newIndexingMapAttr)
      return op.emitError("unable to compute modified indexing_maps");

    // Compute the iterator types of the modified op by dropping the one-trip
    // count loops.
    SmallVector<Attribute, 4> newIteratorTypes;
    for (auto attr : llvm::enumerate(iteratorTypes)) {
      if (!unitDims.count(attr.index()))
        newIteratorTypes.push_back(attr.value());
    }

    rewriter.startRootUpdate(op);
    op.indexing_mapsAttr(newIndexingMapAttr);
    op.iterator_typesAttr(ArrayAttr::get(context, newIteratorTypes));
    (void)replaceBlockArgForUnitDimLoops(op, unitDims, rewriter);
    rewriter.finalizeRootUpdate(op);
    return success();
  }
};

struct UnitExtentReplacementInfo {
  RankedTensorType type;
  AffineMap indexMap;
  ArrayAttr reassociation;
};
} // namespace

/// Utility function for replacing operands/results to a linalg generic
/// operation on tensors with unit-extent dimensions. These can be replaced with
/// an operand/result with the unit-extent dimension removed. This is only done
/// if the indexing map used to access that dimension has an AffineConstantExpr
/// of value 0. Given the `type` of a result/operand of a Linalg op, and its
/// `indexMap`, the utility function returns:
/// - the new type with dimensions of size 1 removed.
/// - modified index map that can be used to access the replaced result/operand
/// - the reassociation that converts from the original tensor type to the
///   modified tensor type.
static UnitExtentReplacementInfo replaceUnitExtents(AffineMap indexMap,
                                                    RankedTensorType type,
                                                    MLIRContext *context) {
  ArrayRef<int64_t> shape = type.getShape();
  ArrayRef<AffineExpr> exprs = indexMap.getResults();
  SmallVector<AffineExpr, 2> reassociations;
  SmallVector<Attribute, 4> reassociationMaps;
  SmallVector<AffineExpr, 4> newIndexExprs;
  SmallVector<int64_t, 4> newShape;

  int64_t origRank = type.getRank();
  AffineExpr zeroExpr = getAffineConstantExpr(0, context);
  auto isUnitExtent = [&](int64_t dim) -> bool {
    return shape[dim] == 1 && exprs[dim] == zeroExpr;
  };

  unsigned dim = 0;
  // Fold dimensions that are unit-extent at the beginning of the tensor.
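  // For instance (a hypothetical operand, shown only to illustrate the
  // overall folding done by the two loops below): tensor<1x5x1xf32> accessed
  // through affine_map<(d0, d1) -> (0, d1, 0)> becomes tensor<5xf32>, with
  // index map affine_map<(d0, d1) -> (d1)> and a single reassociation map
  // affine_map<(d0, d1, d2) -> (d0, d1, d2)> collapsing all three original
  // dimensions into one.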
  while (dim < origRank && isUnitExtent(dim))
    reassociations.push_back(getAffineDimExpr(dim++, context));
  while (dim < origRank) {
    reassociations.push_back(getAffineDimExpr(dim, context));
    newIndexExprs.push_back(exprs[dim]);
    newShape.push_back(shape[dim]);
    // Fold all following dimensions that are unit-extent.
    while (dim + 1 < origRank && isUnitExtent(dim + 1)) {
      ++dim;
      reassociations.push_back(getAffineDimExpr(dim, context));
    }
    reassociationMaps.push_back(AffineMapAttr::get(AffineMap::get(
        origRank, /*numSymbols = */ 0, reassociations, context)));
    reassociations.clear();
    ++dim;
  }
  UnitExtentReplacementInfo info = {
      RankedTensorType::get(newShape, type.getElementType()),
      AffineMap::get(indexMap.getNumDims(), indexMap.getNumSymbols(),
                     newIndexExprs, context),
      ArrayAttr::get(context, reassociationMaps)};
  return info;
}

namespace {

/// Pattern to replace tensor operands/results that are unit extents.
template <typename GenericOpTy>
struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOpTy> {
  using OpRewritePattern<GenericOpTy>::OpRewritePattern;
  LogicalResult matchAndRewrite(GenericOpTy op,
                                PatternRewriter &rewriter) const override {
    // TODO: support reductions.
    if (!op.hasTensorSemantics())
      return failure();

    MLIRContext *context = rewriter.getContext();
    Location loc = op.getLoc();

    SmallVector<AffineMap, 4> newIndexingMaps;
    SmallVector<ArrayAttr, 4> reassociationMaps;
    SmallVector<ShapedType, 4> newInputOutputTypes;
    bool doCanonicalization = false;
    for (auto it :
         llvm::zip(op.getIndexingMaps(), op.getShapedOperandTypes())) {
      auto replacementInfo = replaceUnitExtents(
          std::get<0>(it), std::get<1>(it).template cast<RankedTensorType>(),
          context);
      reassociationMaps.push_back(replacementInfo.reassociation);
      newIndexingMaps.push_back(replacementInfo.indexMap);
      newInputOutputTypes.push_back(replacementInfo.type);
      doCanonicalization |= replacementInfo.type != std::get<1>(it);
    }

    // If the indexing maps of the result operation are not invertible (i.e.
    // not legal), abort.
    if (!doCanonicalization ||
        !inversePermutation(concatAffineMaps(newIndexingMaps)))
      return failure();

    // If any operand type changes, insert a reshape to convert from the
    // original type to the new type.
    // TODO: get rid of flattenedIdx which assumes operand order and contiguity.
    unsigned flattenedIdx = 0;
    auto insertReshapes = [&](ValueRange values) {
      SmallVector<Value, 4> res;
      res.reserve(values.size());
      for (auto operand : llvm::enumerate(values)) {
        if (operand.value().getType() == newInputOutputTypes[flattenedIdx])
          res.push_back(operand.value());
        else
          res.push_back(rewriter.create<linalg::TensorReshapeOp>(
              loc, newInputOutputTypes[flattenedIdx], operand.value(),
              reassociationMaps[flattenedIdx]));
        ++flattenedIdx;
      }
      return res;
    };

    SmallVector<Value, 4> newInputs = insertReshapes(op.inputs());
    SmallVector<Value, 4> newOutputs = insertReshapes(op.outputs());

    // If any result type changes, insert a reshape to convert from the
    // original type to the new type.
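    // For example (hypothetical result type, illustrative only): a result of
    // type tensor<1x5xf32> produced through affine_map<(d0, d1) -> (0, d1)>
    // becomes tensor<5xf32> in the replacement op and is reshaped back to
    // tensor<1x5xf32> below so that uses of the original result still type
    // check.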
    SmallVector<Type, 4> resultTypes;
    resultTypes.reserve(op.getNumResults());
    for (unsigned i : llvm::seq<unsigned>(0, op.getNumResults()))
      resultTypes.push_back(newInputOutputTypes[i + op.getNumInputs()]);
    GenericOpTy replacementOp = rewriter.create<GenericOpTy>(
        loc, resultTypes, newInputs, newOutputs, newIndexingMaps,
        llvm::to_vector<4>(
            op.iterator_types().template getAsValueRange<StringAttr>()));
    rewriter.inlineRegionBefore(op.region(), replacementOp.region(),
                                replacementOp.region().begin());

    // If any result tensor has a modified shape, then add reshape to recover
    // the original shape.
    SmallVector<Value, 4> resultReplacements;
    for (auto result : llvm::enumerate(replacementOp.getResults())) {
      unsigned index = result.index() + replacementOp.getNumInputs();
      RankedTensorType origResultType = op.getResult(result.index())
                                            .getType()
                                            .template cast<RankedTensorType>();
      if (origResultType != result.value().getType())
        resultReplacements.push_back(rewriter.create<linalg::TensorReshapeOp>(
            loc, origResultType, result.value(), reassociationMaps[index]));
      else
        resultReplacements.push_back(result.value());
    }
    rewriter.replaceOp(op, resultReplacements);
    return success();
  }
};

/// Pattern to fold a pair of reshape ops where the intermediate has unit-dims,
/// for example:
///
///   %0 = linalg.tensor_reshape %arg0
///     [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>]
///     : tensor<2048xf32> into tensor<1x4x1x512xf32>
///   %1 = linalg.tensor_reshape %0
///     [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>,
///      affine_map<(d0, d1, d2, d3) -> (d3)>]
///     : tensor<1x4x1x512xf32> into tensor<4x512xf32>
///
/// can be replaced with
///
///   %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>]
///     : tensor<2048xf32> into tensor<4x512xf32>
///
/// Similarly,
///
///   %0 = linalg.tensor_reshape %arg0
///     [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>,
///      affine_map<(d0, d1, d2, d3) -> (d3)>]
///     : tensor<4x512xf32> into tensor<1x4x1x512xf32>
///   %1 = linalg.tensor_reshape %0
///     [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>]
///     : tensor<1x4x1x512xf32> into tensor<2048xf32>
///
/// can be replaced with
///
///   %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>]
///     : tensor<4x512xf32> into tensor<2048xf32>
struct FoldReshapeOpWithUnitExtent : OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    // Check that the source operand is created from a reshape as well.
    TensorReshapeOp parentReshapeOp =
        reshapeOp.src().getDefiningOp<TensorReshapeOp>();
    if (!parentReshapeOp)
      return failure();

    RankedTensorType srcType = reshapeOp.getSrcType(),
                     dstType = reshapeOp.getResultType(),
                     parentSrcType = parentReshapeOp.getSrcType();
    if (!srcType.hasStaticShape() || !dstType.hasStaticShape() ||
        !parentSrcType.hasStaticShape() ||
        srcType.getRank() < dstType.getRank() ||
        parentSrcType.getRank() == dstType.getRank())
      return failure();

    // Check if the resulting tensor_reshape is folding or expanding after the
    // reshapeOp and parentReshapeOp are combined.
    // If the final tensor_reshape is folding, the parentReshapeOp is
    // introducing unit-dims, and the reshapeOp does an actual reshape. If the
    // final tensor_reshape op is expanding, the reshapeOp is introducing
    // unit-dims, and the parentReshapeOp does an actual reshape.
    bool isFoldingPattern = parentSrcType.getRank() > dstType.getRank();
    ArrayRef<int64_t> expandedShape =
        isFoldingPattern ? parentSrcType.getShape() : dstType.getShape();
    ArrayRef<int64_t> foldedShape =
        isFoldingPattern ? dstType.getShape() : parentSrcType.getShape();

    unsigned expandedDim = 0, foldedDim = 0;
    SmallVector<SmallVector<AffineExpr, 4>, 4> reassociationExprs(
        foldedShape.size());
    while (expandedDim < expandedShape.size() &&
           foldedDim < foldedShape.size()) {
      int64_t dstSize = foldedShape[foldedDim];
      int64_t srcSize = expandedShape[expandedDim];
      while (srcSize < dstSize && expandedDim < expandedShape.size()) {
        reassociationExprs[foldedDim].push_back(
            rewriter.getAffineDimExpr(expandedDim++));
        srcSize *= expandedShape[expandedDim];
      }
      if (srcSize == dstSize) {
        reassociationExprs[foldedDim].push_back(
            rewriter.getAffineDimExpr(expandedDim++));
        // If the next dim in foldedShape is not 1, treat subsequent dims in
        // expandedShape which are 1 to be collapsed.
        if (foldedDim == foldedShape.size() - 1 ||
            foldedShape[foldedDim + 1] != 1) {
          while (expandedDim < expandedShape.size() &&
                 expandedShape[expandedDim] == 1) {
            reassociationExprs[foldedDim].push_back(
                rewriter.getAffineDimExpr(expandedDim++));
          }
        }
      } else {
        return failure();
      }
      foldedDim++;
    }
    if (expandedDim != expandedShape.size())
      return failure();

    SmallVector<AffineMap, 4> reassociationMaps =
        llvm::to_vector<4>(llvm::map_range(
            reassociationExprs, [&](ArrayRef<AffineExpr> exprs) -> AffineMap {
              return AffineMap::get(expandedShape.size(), 0, exprs,
                                    rewriter.getContext());
            }));
    rewriter.replaceOpWithNewOp<TensorReshapeOp>(
        reshapeOp, dstType, parentReshapeOp.src(),
        rewriter.getAffineMapArrayAttr(reassociationMaps));
    return success();
  }
};

/// Pattern to fold subtensors that are just taking a slice of a unit-dimension
/// tensor. For example
///
///   %1 = subtensor %0[0, %o1, 0] [1, %s1, 1] [1, 1, 1]
///       : tensor<1x?x1xf32> to tensor<1x?x1xf32>
///
/// can be replaced with
///
///   %1 = linalg.tensor_reshape %0 [affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
///       : tensor<1x?x1xf32> into tensor<?xf32>
///   %2 = subtensor %1[%o1] [%s1] [1] : tensor<?xf32> to tensor<?xf32>
///   %3 = linalg.tensor_reshape %2 [affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
///       : tensor<?xf32> into tensor<1x?x1xf32>
///
/// The additional tensor_reshapes will hopefully get canonicalized away with
/// other reshapes that drop unit dimensions. Three conditions to fold a
/// dimension:
/// - The offset must be 0.
/// - The size must be 1.
/// - The dimension of the source type must be 1.
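///
/// In the example above, dimensions 0 and 2 satisfy all three conditions and
/// are folded away, while dimension 1 (offset %o1, size %s1, dynamic source
/// extent) is kept.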
struct FoldUnitDimSubTensorOp : public OpRewritePattern<SubTensorOp> {
  using OpRewritePattern<SubTensorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubTensorOp subTensorOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<OpFoldResult> mixedOffsets = subTensorOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = subTensorOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = subTensorOp.getMixedStrides();
    auto hasValue = [](OpFoldResult valueOrAttr, int64_t val) {
      auto attr = valueOrAttr.dyn_cast<Attribute>();
      return attr && attr.cast<IntegerAttr>().getInt() == val;
    };

    if (llvm::any_of(mixedStrides, [&](OpFoldResult valueOrAttr) {
          return !hasValue(valueOrAttr, 1);
        }))
      return failure();

    // Find the expanded unit dimensions.
    SmallVector<ReassociationIndices> reassociation;
    SmallVector<OpFoldResult> newOffsets, newSizes;
    ArrayRef<int64_t> sourceShape = subTensorOp.getSourceType().getShape();
    ReassociationIndices curr;
    for (int64_t dim : llvm::seq<int64_t>(0, mixedOffsets.size())) {
      curr.push_back(dim);
      if (sourceShape[dim] == 1 && hasValue(mixedOffsets[dim], 0) &&
          hasValue(mixedSizes[dim], 1)) {
        continue;
      }
      newOffsets.push_back(mixedOffsets[dim]);
      newSizes.push_back(mixedSizes[dim]);
      reassociation.emplace_back(ReassociationIndices{});
      std::swap(reassociation.back(), curr);
    }
    if (newOffsets.size() == mixedOffsets.size())
      return failure();
    reassociation.back().append(curr.begin(), curr.end());
    SmallVector<OpFoldResult> newStrides(newOffsets.size(),
                                         rewriter.getI64IntegerAttr(1));
    Location loc = subTensorOp->getLoc();
    auto srcReshape = rewriter.create<TensorReshapeOp>(
        loc, subTensorOp.source(), reassociation);
    auto newSubTensorOp = rewriter.create<SubTensorOp>(
        loc, srcReshape, newOffsets, newSizes, newStrides);
    rewriter.replaceOpWithNewOp<TensorReshapeOp>(
        subTensorOp, subTensorOp.getType(), newSubTensorOp, reassociation);
    return success();
  }
};

} // namespace

/// Patterns that are used to canonicalize the use of unit-extent dims for
/// broadcasting.
void mlir::populateLinalgFoldUnitExtentDimsPatterns(
    RewritePatternSet &patterns) {
  auto *context = patterns.getContext();
  patterns.add<FoldUnitDimLoops<GenericOp>, FoldUnitDimLoops<IndexedGenericOp>,
               FoldUnitDimSubTensorOp, ReplaceUnitExtentTensors<GenericOp>,
               ReplaceUnitExtentTensors<IndexedGenericOp>>(context);
  TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
  patterns.add<FoldReshapeOpWithUnitExtent>(context);
  populateFoldUnitDimsReshapeOpsByLinearizationPatterns(patterns);
}

namespace {
/// Pass that removes unit-extent dims within generic ops.
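/// The `foldOneTripLoopsOnly` option (inherited from the generated
/// LinalgFoldUnitExtentDimsBase) restricts the pass to the FoldUnitDimLoops
/// patterns; otherwise the full set of patterns from
/// populateLinalgFoldUnitExtentDimsPatterns is applied.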
struct LinalgFoldUnitExtentDimsPass
    : public LinalgFoldUnitExtentDimsBase<LinalgFoldUnitExtentDimsPass> {
  void runOnFunction() override {
    FuncOp funcOp = getFunction();
    MLIRContext *context = funcOp.getContext();
    RewritePatternSet patterns(context);
    if (foldOneTripLoopsOnly)
      patterns
          .add<FoldUnitDimLoops<GenericOp>, FoldUnitDimLoops<IndexedGenericOp>>(
              context);
    else
      populateLinalgFoldUnitExtentDimsPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
  }
};
} // namespace

std::unique_ptr<OperationPass<FuncOp>>
mlir::createLinalgFoldUnitExtentDimsPass() {
  return std::make_unique<LinalgFoldUnitExtentDimsPass>();
}