//===- DropUnitDims.cpp - Pass to drop use of unit-extent for broadcasting ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns/passes to remove the use of unit-extent
// dimensions for specifying broadcasting, in favor of a more canonical
// representation of the computation.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "linalg-drop-unit-dims"

using namespace mlir;
using namespace mlir::linalg;

/// Implements a pass that canonicalizes the uses of unit-extent dimensions for
/// broadcasting. For example,
///
/// ```mlir
/// #accesses = [
///   affine_map<(d0, d1) -> (0, d1)>,
///   affine_map<(d0, d1) -> (d0, 0)>,
///   affine_map<(d0, d1) -> (d0, d1)>
/// ]
///
/// #trait = {
///   args_in = 2,
///   args_out = 1,
///   indexing_maps = #accesses,
///   iterator_types = ["parallel", "parallel"],
///   library_call = "some_external_fn"
/// }
///
/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
/// tensor<5x5xf32>
/// {
///   %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>] :
///        tensor<5xf32> into tensor<1x5xf32>
///   %1 = linalg.tensor_reshape %arg1 [affine_map<(d0, d1) -> (d0, d1)>] :
///        tensor<5xf32> into tensor<5x1xf32>
///   %2 = linalg.generic #trait %0, %1 {
///        ^bb0(%arg2: f32, %arg3: f32):
///          %3 = addf %arg2, %arg3 : f32
///          linalg.yield %3 : f32
///        } : tensor<1x5xf32>, tensor<5x1xf32> -> tensor<5x5xf32>
///   return %2 : tensor<5x5xf32>
/// }
/// ```
///
/// would canonicalize to
///
/// ```mlir
/// #accesses = [
///   affine_map<(d0, d1) -> (d1)>,
///   affine_map<(d0, d1) -> (d0)>,
///   affine_map<(d0, d1) -> (d0, d1)>
/// ]
///
/// #trait = {
///   args_in = 2,
///   args_out = 1,
///   indexing_maps = #accesses,
///   iterator_types = ["parallel", "parallel"],
///   library_call = "some_external_fn"
/// }
///
/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
/// tensor<5x5xf32>
/// {
///   %0 = linalg.generic #trait %arg0, %arg1 {
///        ^bb0(%arg2: f32, %arg3: f32):
///          %3 = addf %arg2, %arg3 : f32
///          linalg.yield %3 : f32
///        } : tensor<5xf32>, tensor<5xf32> -> tensor<5x5xf32>
///   return %0 : tensor<5x5xf32>
/// }
/// ```

/// Given the dims of the iteration space of a structured op that are known to
/// have a single trip count (`unitDims`), return the indexing maps to use in
/// the canonicalized op, with these dims removed, given the original
/// `indexingMaps`.
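///
/// As an illustrative sketch (the maps and unit dims here are hypothetical,
/// not taken from a test): with `unitDims = {1}`, the original maps
///
/// ```mlir
/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>
/// affine_map<(d0, d1, d2) -> (d0, d2)>
/// ```
///
/// become, after replacing d1 with the constant 0 and renumbering the
/// remaining dims,
///
/// ```mlir
/// affine_map<(d0, d1) -> (d0, 0, d1)>
/// affine_map<(d0, d1) -> (d0, d1)>
/// ```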
static ArrayAttr replaceUnitDims(DenseSet<unsigned> &unitDims,
                                 ArrayRef<AffineMap> indexingMaps,
                                 MLIRContext *context) {
  if (indexingMaps.empty())
    return nullptr;
  unsigned numIterationDims = indexingMaps.front().getNumDims();
  unsigned numSymbols = indexingMaps.front().getNumSymbols();

  // Compute the replacement for each dim expr.
  SmallVector<AffineExpr, 4> dimReplacements;
  dimReplacements.reserve(numIterationDims);
  unsigned numKeptDims = 0;
  for (unsigned dim : llvm::seq<unsigned>(0, numIterationDims)) {
    if (unitDims.count(dim))
      dimReplacements.push_back(getAffineConstantExpr(0, context));
    else
      dimReplacements.push_back(getAffineDimExpr(numKeptDims++, context));
  }

  // Symbols remain the same.
  SmallVector<AffineExpr, 4> symReplacements;
  symReplacements.reserve(numSymbols);
  for (unsigned symbol : llvm::seq<unsigned>(0, numSymbols))
    symReplacements.push_back(getAffineSymbolExpr(symbol, context));

  SmallVector<AffineMap, 4> newIndexingMaps;
  newIndexingMaps.reserve(indexingMaps.size());
  for (AffineMap operandMap : indexingMaps) {
    // Expected indexing maps to have no symbols.
    if (operandMap.getNumSymbols())
      return nullptr;
    newIndexingMaps.push_back(simplifyAffineMap(
        operandMap.replaceDimsAndSymbols(dimReplacements, symReplacements,
                                         numIterationDims - unitDims.size(),
                                         numSymbols)));
  }

  // Check that the new index maps are invertible. If not, something went
  // wrong, so abort.
  if (!inversePermutation(concatAffineMaps(newIndexingMaps)))
    return nullptr;
  return ArrayAttr::get(context,
                        llvm::to_vector<4>(llvm::map_range(
                            newIndexingMaps, [](AffineMap map) -> Attribute {
                              return AffineMapAttr::get(map);
                            })));
}

/// Update the index accesses of linalg operations having index semantics.
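///
/// For example (a hedged sketch with a hypothetical dimension set): if
/// `unitDims` contains 1, then inside the body of the generic op
///
/// ```mlir
/// %a = linalg.index 1 : index
/// %b = linalg.index 2 : index
/// ```
///
/// is rewritten to
///
/// ```mlir
/// %a = constant 0 : index
/// %b = linalg.index 1 : index
/// ```
///
/// i.e. indices of dropped unit dims become the constant 0 and indices of the
/// remaining dims are renumbered.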
static void replaceUnitDimIndexOps(GenericOp genericOp,
                                   const DenseSet<unsigned> &unitDims,
                                   PatternRewriter &rewriter) {
  assert(genericOp->getNumRegions() == 1 &&
         genericOp->getRegion(0).getBlocks().size() == 1 &&
         "expected generic operation to have one block.");
  Block &block = genericOp->getRegion(0).front();

  for (IndexOp indexOp : llvm::make_early_inc_range(block.getOps<IndexOp>())) {
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(indexOp);
    if (unitDims.count(indexOp.dim()) != 0) {
      rewriter.replaceOpWithNewOp<ConstantIndexOp>(indexOp, 0);
    } else {
      // Update the dimension of the index operation if needed.
      unsigned droppedDims = llvm::count_if(
          unitDims, [&](unsigned dim) { return dim < indexOp.dim(); });
      if (droppedDims != 0)
        rewriter.replaceOpWithNewOp<IndexOp>(indexOp,
                                             indexOp.dim() - droppedDims);
    }
  }
}

namespace {
/// Pattern to fold unit-trip count loops in GenericOps.
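///
/// A hedged sketch of the effect (the shapes and maps are hypothetical, not
/// from a test): a generic op of the form
///
/// ```mlir
/// linalg.generic {
///     indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
///                      affine_map<(d0, d1) -> (d0, d1)>],
///     iterator_types = ["parallel", "parallel"]}
///     ins(%arg0 : tensor<1x5xf32>) outs(%arg1 : tensor<1x5xf32>) {...}
/// ```
///
/// is updated in place to use
///
/// ```mlir
///     indexing_maps = [affine_map<(d0) -> (0, d0)>,
///                      affine_map<(d0) -> (0, d0)>],
///     iterator_types = ["parallel"]
/// ```
///
/// The operand types themselves are left untouched; dropping the actual unit
/// dimensions from the operands is handled by `ReplaceUnitExtents` below.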
struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap, 4> indexingMaps = genericOp.getIndexingMaps();
    if (indexingMaps.empty())
      return failure();

    // Check if any of the iteration dimensions have a unit trip count. They
    // will have a unit trip count if they are used to index into a unit-dim
    // tensor/memref.
    AffineMap invertedMap = inversePermutation(concatAffineMaps(indexingMaps));
    if (!invertedMap)
      return failure();
    SmallVector<int64_t> dims = genericOp.getStaticShape();

    // Find all the reduction iterators. Those need some special consideration
    // (see below).
    auto getLoopDimsOfType =
        [&](StringRef iteratorTypeName) -> SmallVector<unsigned, 4> {
      SmallVector<AffineExpr> dimExprs;
      getDimsOfType(genericOp, iteratorTypeName, dimExprs);
      return llvm::to_vector<4>(llvm::map_range(dimExprs, [](AffineExpr expr) {
        return expr.cast<AffineDimExpr>().getPosition();
      }));
    };
    auto reductionDims = getLoopDimsOfType(getReductionIteratorTypeName());

    DenseSet<unsigned> unitDims;
    SmallVector<unsigned, 4> unitDimsReductionLoops;
    ArrayAttr iteratorTypes = genericOp.iterator_types();
    for (auto expr : enumerate(invertedMap.getResults())) {
      if (AffineDimExpr dimExpr = expr.value().dyn_cast<AffineDimExpr>())
        if (dims[dimExpr.getPosition()] == 1) {
          if (isParallelIterator(iteratorTypes[expr.index()]))
            unitDims.insert(expr.index());
          else if (isReductionIterator(iteratorTypes[expr.index()]))
            unitDimsReductionLoops.push_back(expr.index());
        }
    }

    // Reduction loops can be dropped if there is at least one other reduction
    // loop that is not dropped. This accounts for the initial value read in the
    // reduction loop.
    if (!unitDimsReductionLoops.empty() && reductionDims.size() > 1) {
      if (unitDimsReductionLoops.size() == reductionDims.size())
        unitDims.insert(reductionDims.begin(), std::prev(reductionDims.end()));
      else
        unitDims.insert(unitDimsReductionLoops.begin(),
                        unitDimsReductionLoops.end());
    }

    if (unitDims.empty())
      return failure();

    // Compute the modified indexing maps.
    MLIRContext *context = rewriter.getContext();
    ArrayAttr newIndexingMapAttr =
        replaceUnitDims(unitDims, indexingMaps, context);
    if (!newIndexingMapAttr)
      return genericOp.emitError("unable to compute modified indexing_maps");

    // Compute the iterator types of the modified op by dropping the one-trip
    // count loops.
    SmallVector<Attribute, 4> newIteratorTypes;
    for (auto attr : llvm::enumerate(iteratorTypes)) {
      if (!unitDims.count(attr.index()))
        newIteratorTypes.push_back(attr.value());
    }

    rewriter.startRootUpdate(genericOp);
    genericOp.indexing_mapsAttr(newIndexingMapAttr);
    genericOp.iterator_typesAttr(ArrayAttr::get(context, newIteratorTypes));
    replaceUnitDimIndexOps(genericOp, unitDims, rewriter);
    rewriter.finalizeRootUpdate(genericOp);
    return success();
  }
};

struct UnitExtentReplacementInfo {
  Type type;
  AffineMap indexMap;
  ArrayAttr reassociation;
};
} // namespace

/// Utility function for replacing an operand/result of a linalg generic
/// operation that has unit-extent dimensions. Such an operand/result can be
/// replaced with one that has the unit-extent dimensions removed. This is only
/// done if the indexing-map result used to access that dimension is an
/// AffineConstantExpr of value 0. Given the `type` of a result/operand of a
/// Linalg op and its `indexMap`, the utility function returns:
/// - the new type with dimensions of size 1 removed;
/// - the modified index map that can be used to access the replaced
///   result/operand;
/// - the reassociation that converts from the original tensor type to the
///   modified tensor type.
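///
/// A hedged example (the operand and map are hypothetical, not from a test):
/// for an operand of type `tensor<1x5x1xf32>` accessed through
/// `affine_map<(d0, d1) -> (0, d1, 0)>`, the function returns
/// - the type `tensor<5xf32>`;
/// - the index map `affine_map<(d0, d1) -> (d1)>`;
/// - the reassociation `[affine_map<(d0, d1, d2) -> (d0, d1, d2)>]`, i.e. all
///   three original dimensions fold into the single remaining one.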
static UnitExtentReplacementInfo replaceUnitExtents(GenericOp genericOp,
                                                    OpOperand *opOperand,
                                                    MLIRContext *context) {
  AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
  ArrayRef<int64_t> shape = genericOp.getShape(opOperand);
  ArrayRef<AffineExpr> exprs = indexingMap.getResults();
  SmallVector<AffineExpr> reassociations;
  SmallVector<Attribute> reassociationMaps;
  SmallVector<AffineExpr> newIndexExprs;
  SmallVector<int64_t> newShape;

  int64_t origRank = genericOp.getRank(opOperand);
  AffineExpr zeroExpr = getAffineConstantExpr(0, context);
  auto isUnitExtent = [&](int64_t dim) -> bool {
    return shape[dim] == 1 && exprs[dim] == zeroExpr;
  };

  int64_t dim = 0;
  // Fold dimensions that are unit-extent at the beginning of the tensor.
  while (dim < origRank && isUnitExtent(dim))
    reassociations.push_back(getAffineDimExpr(dim++, context));
  while (dim < origRank) {
    reassociations.push_back(getAffineDimExpr(dim, context));
    newIndexExprs.push_back(exprs[dim]);
    newShape.push_back(shape[dim]);
    // Fold all following dimensions that are unit-extent.
    while (dim + 1 < origRank && isUnitExtent(dim + 1)) {
      ++dim;
      reassociations.push_back(getAffineDimExpr(dim, context));
    }
    reassociationMaps.push_back(AffineMapAttr::get(AffineMap::get(
        origRank, /*symbolCount = */ 0, reassociations, context)));
    reassociations.clear();
    ++dim;
  }
  // Compute the tensor or scalar replacement type.
  Type actualType = opOperand->get().getType();
  Type elementType = getElementTypeOrSelf(opOperand->get());
  Type replacementType;
  if (elementType == opOperand->get().getType()) {
    replacementType = elementType;
  } else if (actualType.isa<RankedTensorType>()) {
    replacementType = RankedTensorType::get(newShape, elementType);
  } else if (actualType.isa<MemRefType>()) {
    assert(actualType.cast<MemRefType>().getAffineMaps().empty() &&
           "unsupported strided memrefs");
    replacementType = MemRefType::get(newShape, elementType);
  }
  assert(replacementType && "unsupported shaped type");
  UnitExtentReplacementInfo info = {replacementType,
                                    AffineMap::get(indexingMap.getNumDims(),
                                                   indexingMap.getNumSymbols(),
                                                   newIndexExprs, context),
                                    ArrayAttr::get(context, reassociationMaps)};
  return info;
}

namespace {

/// Convert an array of AffineMapAttr reassociation maps into the equivalent
/// list of reassociation expressions expected by the reshape op builders.
SmallVector<ReassociationExprs, 2>
convertAffineMapArrayToExprs(ArrayAttr affineMapArrayAttr) {
  SmallVector<ReassociationExprs, 2> reassociationExprs;
  for (auto attr : affineMapArrayAttr)
    reassociationExprs.push_back(
        llvm::to_vector<4>(attr.cast<AffineMapAttr>().getValue().getResults()));
  return reassociationExprs;
}

/// Pattern to replace tensor/buffer operands/results that are unit extents.
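///
/// A hedged sketch of the rewrite (hypothetical shapes, not from a test): a
/// generic op with an operand of type `tensor<1x5xf32>` accessed through
/// `affine_map<(d0) -> (0, d0)>` is replaced by
///
/// ```mlir
/// %in = linalg.tensor_collapse_shape %arg0 [[0, 1]]
///     : tensor<1x5xf32> into tensor<5xf32>
/// %res = linalg.generic ... ins(%in : tensor<5xf32>) ...
/// %out = linalg.tensor_expand_shape %res [[0, 1]]
///     : tensor<5xf32> into tensor<1x5xf32>
/// ```
///
/// i.e. the unit-extent dimensions are dropped from the operands/results and
/// reshapes are inserted to convert to and from the original types.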
struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  // Return the original value if the type is unchanged, or reshape it. Return
  // nullptr if the type is unsupported.
  Value maybeExpand(Value result, Type origResultType,
                    ArrayAttr reassociationMap, Location loc,
                    PatternRewriter &rewriter) const {
    if (origResultType == result.getType())
      return result;
    if (origResultType.isa<RankedTensorType>()) {
      return rewriter.create<linalg::TensorExpandShapeOp>(
          loc, origResultType, result,
          convertAffineMapArrayToExprs(reassociationMap));
    }
    if (origResultType.isa<MemRefType>()) {
      return rewriter.create<linalg::ExpandShapeOp>(
          loc, origResultType, result,
          convertAffineMapArrayToExprs(reassociationMap));
    }
    return nullptr;
  };

  // Return the original value if the type is unchanged, or reshape it. Return
  // nullptr if the type is unsupported.
  Value maybeCollapse(Value operand, Type newInputOutputType,
                      ArrayAttr reassociationMap, Location loc,
                      PatternRewriter &rewriter) const {
    auto operandType = operand.getType();
    if (operandType == newInputOutputType)
      return operand;
    if (operandType.isa<MemRefType>()) {
      return rewriter.create<linalg::CollapseShapeOp>(
          loc, newInputOutputType, operand,
          convertAffineMapArrayToExprs(reassociationMap));
    }
    if (operandType.isa<RankedTensorType>()) {
      return rewriter.create<linalg::TensorCollapseShapeOp>(
          loc, newInputOutputType, operand,
          convertAffineMapArrayToExprs(reassociationMap));
    }
    return nullptr;
  };

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    MLIRContext *context = rewriter.getContext();
    Location loc = genericOp.getLoc();

    SmallVector<AffineMap> newIndexingMaps;
    SmallVector<ArrayAttr> reassociationMaps;
    SmallVector<Type> newInputOutputTypes;
    bool doCanonicalization = false;
    for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
      UnitExtentReplacementInfo replacementInfo =
          replaceUnitExtents(genericOp, opOperand, context);
      reassociationMaps.push_back(replacementInfo.reassociation);
      newIndexingMaps.push_back(replacementInfo.indexMap);
      newInputOutputTypes.push_back(replacementInfo.type);
      doCanonicalization |= replacementInfo.type != opOperand->get().getType();
    }

    // If the indexing maps of the result operation are not invertible (i.e. not
    // legal), abort.
    if (!doCanonicalization ||
        !inversePermutation(concatAffineMaps(newIndexingMaps)))
      return failure();

    // If any operand type changes, insert a reshape to convert from the
    // original type to the new type.
    // TODO: get rid of flattenedIdx which assumes operand order and contiguity.
    unsigned flattenedIdx = 0;
    auto insertReshapes = [&](ValueRange values) {
      SmallVector<Value, 4> res;
      res.reserve(values.size());
      for (auto operand : values) {
        auto reshapedValue =
            maybeCollapse(operand, newInputOutputTypes[flattenedIdx],
                          reassociationMaps[flattenedIdx], loc, rewriter);
        assert(reshapedValue &&
               "expected ranked MemRef or Tensor operand type");
        res.push_back(reshapedValue);
        ++flattenedIdx;
      }
      return res;
    };

    SmallVector<Value, 4> newInputs = insertReshapes(genericOp.inputs());
    SmallVector<Value, 4> newOutputs = insertReshapes(genericOp.outputs());

    // If any result type changes, insert a reshape to convert from the original
    // type to the new type.
    SmallVector<Type, 4> resultTypes;
    resultTypes.reserve(genericOp.getNumResults());
    for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
      resultTypes.push_back(newInputOutputTypes[i + genericOp.getNumInputs()]);
    GenericOp replacementOp = rewriter.create<GenericOp>(
        loc, resultTypes, newInputs, newOutputs, newIndexingMaps,
        llvm::to_vector<4>(
            genericOp.iterator_types().template getAsValueRange<StringAttr>()));
    rewriter.inlineRegionBefore(genericOp.region(), replacementOp.region(),
                                replacementOp.region().begin());

    // If any result tensor has a modified shape, then add a reshape to recover
    // the original shape.
    SmallVector<Value, 4> resultReplacements;
    for (auto result : llvm::enumerate(replacementOp.getResults())) {
      unsigned index = result.index() + replacementOp.getNumInputs();
      auto origResultType = genericOp.getResult(result.index()).getType();

      auto newResult = maybeExpand(result.value(), origResultType,
                                   reassociationMaps[index], loc, rewriter);
      assert(newResult &&
             "unexpected output type other than ranked MemRef or Tensor");
      resultReplacements.push_back(newResult);
    }
    rewriter.replaceOp(genericOp, resultReplacements);
    return success();
  }
};
} // namespace

/// Get the reassociation maps to fold the result of a subtensor (or the source
/// of a subtensor_insert) operation with the given offsets and sizes into its
/// rank-reduced version. This is only done for the cases where the size is 1
/// and the offset is 0. Strictly speaking, a zero offset is not required in
/// general, but non-zero offsets are not handled by the SPIR-V backend at this
/// point (and potentially cannot be handled).
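///
/// A hedged example with hypothetical sizes: `mixedSizes = [1, 256, 1, 64]`
/// yields the reassociation `[[0, 1], [2, 3]]`, and `mixedSizes = [4, 1]`
/// yields `[[0, 1]]` (a trailing unit dimension is folded into the last
/// group). All-unit sizes such as `[1, 1]` yield an empty reassociation, which
/// collapses the value to a rank-0 tensor.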
static Optional<SmallVector<ReassociationIndices>>
getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes) {
  SmallVector<ReassociationIndices> reassociation;
  ReassociationIndices curr;
  for (auto it : llvm::enumerate(mixedSizes)) {
    auto dim = it.index();
    auto size = it.value();
    curr.push_back(dim);
    auto attr = size.dyn_cast<Attribute>();
    if (attr && attr.cast<IntegerAttr>().getInt() == 1)
      continue;
    reassociation.emplace_back(ReassociationIndices{});
    std::swap(reassociation.back(), curr);
  }
  // When the reassociation is not empty, fold the remaining unit dimensions
  // into the last reassociation group. If the reassociation collected so far
  // is empty, leave it empty; this will fold everything to a rank-0 tensor.
  if (!curr.empty() && !reassociation.empty())
    reassociation.back().append(curr.begin(), curr.end());
  return reassociation;
}

namespace {
/// Convert `subtensor` operations to rank-reduced versions.
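///
/// A hedged sketch (hypothetical shapes, not from a test): a subtensor
/// producing a `tensor<1x4xf32>` from a `tensor<8x4xf32>` is rewritten into a
/// rank-reduced subtensor followed by an expansion back to the original
/// result type:
///
/// ```mlir
/// %0 = subtensor %arg0[0, 0] [1, 4] [1, 1]
///     : tensor<8x4xf32> to tensor<4xf32>
/// %1 = linalg.tensor_expand_shape %0 [[0, 1]]
///     : tensor<4xf32> into tensor<1x4xf32>
/// ```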
struct UseRankReducedSubTensorOp : public OpRewritePattern<SubTensorOp> {
  using OpRewritePattern<SubTensorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubTensorOp subTensorOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType resultType = subTensorOp.getType();
    SmallVector<OpFoldResult> offsets = subTensorOp.getMixedOffsets();
    SmallVector<OpFoldResult> sizes = subTensorOp.getMixedSizes();
    SmallVector<OpFoldResult> strides = subTensorOp.getMixedStrides();
    auto reassociation = getReassociationMapForFoldingUnitDims(sizes);
    if (!reassociation ||
        reassociation->size() == static_cast<size_t>(resultType.getRank()))
      return failure();
    auto rankReducedType =
        SubTensorOp::inferRankReducedResultType(reassociation->size(),
                                                subTensorOp.getSourceType(),
                                                offsets, sizes, strides)
            .cast<RankedTensorType>();

    Location loc = subTensorOp.getLoc();
    Value newSubTensor = rewriter.create<SubTensorOp>(
        loc, rankReducedType, subTensorOp.source(), offsets, sizes, strides);
    rewriter.replaceOpWithNewOp<TensorExpandShapeOp>(
        subTensorOp, resultType, newSubTensor, *reassociation);
    return success();
  }
};

/// Convert `subtensor_insert` operations to rank-reduced versions.
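///
/// A hedged sketch (hypothetical shapes, not from a test): inserting a
/// `tensor<1x4xf32>` source is rewritten to first collapse the source and then
/// insert the rank-reduced value:
///
/// ```mlir
/// %0 = linalg.tensor_collapse_shape %src [[0, 1]]
///     : tensor<1x4xf32> into tensor<4xf32>
/// %1 = subtensor_insert %0 into %dest[0, 0] [1, 4] [1, 1]
///     : tensor<4xf32> into tensor<8x4xf32>
/// ```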
struct UseRankReducedSubTensorInsertOp
    : public OpRewritePattern<SubTensorInsertOp> {
  using OpRewritePattern<SubTensorInsertOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubTensorInsertOp insertOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType sourceType = insertOp.getSourceType();
    SmallVector<OpFoldResult> offsets = insertOp.getMixedOffsets();
    SmallVector<OpFoldResult> sizes = insertOp.getMixedSizes();
    SmallVector<OpFoldResult> strides = insertOp.getMixedStrides();
    auto reassociation = getReassociationMapForFoldingUnitDims(sizes);
    if (!reassociation ||
        reassociation->size() == static_cast<size_t>(sourceType.getRank()))
      return failure();
    Location loc = insertOp.getLoc();
    auto reshapedSource = rewriter.create<TensorCollapseShapeOp>(
        loc, insertOp.source(), *reassociation);
    rewriter.replaceOpWithNewOp<SubTensorInsertOp>(
        insertOp, reshapedSource, insertOp.dest(), insertOp.getMixedOffsets(),
        insertOp.getMixedSizes(), insertOp.getMixedStrides());
    return success();
  }
};
} // namespace

/// Patterns that are used to canonicalize the use of unit-extent dims for
/// broadcasting.
void mlir::linalg::populateFoldUnitExtentDimsPatterns(
    RewritePatternSet &patterns) {
  auto *context = patterns.getContext();
  patterns.add<FoldUnitDimLoops, ReplaceUnitExtents, UseRankReducedSubTensorOp,
               UseRankReducedSubTensorInsertOp>(context);
  TensorCollapseShapeOp::getCanonicalizationPatterns(patterns, context);
  TensorExpandShapeOp::getCanonicalizationPatterns(patterns, context);
}

namespace {
/// Pass that removes unit-extent dims within generic ops.
struct LinalgFoldUnitExtentDimsPass
    : public LinalgFoldUnitExtentDimsBase<LinalgFoldUnitExtentDimsPass> {
  void runOnFunction() override {
    FuncOp funcOp = getFunction();
    MLIRContext *context = funcOp.getContext();
    RewritePatternSet patterns(context);
    if (foldOneTripLoopsOnly)
      patterns.add<FoldUnitDimLoops>(context);
    else
      populateFoldUnitExtentDimsPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
  }
};
} // namespace

std::unique_ptr<OperationPass<FuncOp>>
mlir::createLinalgFoldUnitExtentDimsPass() {
  return std::make_unique<LinalgFoldUnitExtentDimsPass>();
}