1 //===- DropUnitDims.cpp - Pass to drop use of unit-extent for broadcasting ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements a pass and patterns to remove the use of unit-extent
// dimensions that specify broadcasting, in favor of a more canonical
// representation of the computation.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "PassDetail.h"
16 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
17 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
18 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
19 #include "mlir/Dialect/Linalg/Passes.h"
20 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
21 #include "mlir/Dialect/Linalg/Utils/Utils.h"
22 #include "mlir/Dialect/Tensor/IR/Tensor.h"
23 #include "mlir/IR/AffineExpr.h"
24 #include "mlir/IR/AffineMap.h"
25 #include "mlir/IR/BuiltinTypes.h"
26 #include "mlir/Transforms/FoldUtils.h"
27 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
28 #include "llvm/Support/CommandLine.h"
29 #include "llvm/Support/Debug.h"
30 
31 #define DEBUG_TYPE "linalg-drop-unit-dims"
32 
33 using namespace mlir;
34 using namespace mlir::linalg;
35 
36 /// Implements a pass that canonicalizes the uses of unit-extent dimensions for
37 /// broadcasting. For example,
38 ///
39 /// ```mlir
40 /// #accesses = [
41 ///   affine_map<(d0, d1) -> (0, d1)>,
42 ///   affine_map<(d0, d1) -> (d0, 0)>,
43 ///   affine_map<(d0, d1) -> (d0, d1)>
44 /// ]
45 ///
46 /// #trait = {
47 ///   args_in = 2,
48 ///   args_out = 1,
49 ///   indexing_maps = #accesses,
50 ///   iterator_types = ["parallel", "parallel"],
51 ///   library_call = "some_external_fn"
52 /// }
53 ///
54 /// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
55 /// tensor<5x5xf32>
56 /// {
57 ///   %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>] :
58 ///        tensor<5xf32> into tensor<1x5xf32>
59 ///   %1 = linalg.tensor_reshape %arg1 [affine_map<(d0, d1) -> (d0, d1)>] :
60 ///        tensor<5xf32> into tensor<5x1xf32>
61 ///   %2 = linalg.generic #trait %0, %1 {
62 ///        ^bb0(%arg2: f32, %arg3: f32):
63 ///          %3 = arith.addf %arg2, %arg3 : f32
64 ///          linalg.yield %3 : f32
65 ///        } : tensor<1x5xf32>, tensor<5x1xf32> -> tensor<5x5xf32>
66 ///   return %2 : tensor<5x5xf32>
67 /// }
68 ///
69 /// would canonicalize to
70 ///
71 /// ```mlir
72 /// #accesses = [
73 ///   affine_map<(d0, d1) -> (d1)>,
74 ///   affine_map<(d0, d1) -> (d0)>,
75 ///   affine_map<(d0, d1) -> (d0, d1)>
76 /// ]
77 ///
78 /// #trait = {
79 ///   args_in = 2,
80 ///   args_out = 1,
81 ///   indexing_maps = #accesses,
82 ///   iterator_types = ["parallel", "parallel"],
83 ///   library_call = "some_external_fn"
84 /// }
85 ///
86 /// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
87 /// tensor<5x5xf32>
88 /// {
89 ///   %0 = linalg.generic #trait %arg0, %arg1 {
90 ///        ^bb0(%arg2: f32, %arg3: f32):
91 ///          %3 = arith.addf %arg2, %arg3 : f32
92 ///          linalg.yield %3 : f32
93 ///        } : tensor<5xf32>, tensor<5xf32> -> tensor<5x5xf32>
94 ///   return %0 : tensor<5x5xf32>
95 /// }
96 
/// Given the dims of the iteration space of a structured op that are known to
/// have a single trip count (`unitDims`), return the indexing maps to use in
/// the canonicalized op, with these dims removed, given the original
/// `indexingMaps`.
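///
/// For example (an illustrative sketch; the maps are hypothetical), with
/// `unitDims = {1}` in a 3-d iteration space, the maps
///   affine_map<(d0, d1, d2) -> (d0, d1)>
///   affine_map<(d0, d1, d2) -> (d1, d2)>
/// are rewritten to the 2-d maps
///   affine_map<(d0, d1) -> (d0, 0)>
///   affine_map<(d0, d1) -> (0, d1)>
/// i.e. each dropped dimension becomes the constant 0 and the remaining
/// dimensions are renumbered.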
100 static ArrayAttr replaceUnitDims(DenseSet<unsigned> &unitDims,
101                                  ArrayRef<AffineMap> indexingMaps,
102                                  MLIRContext *context) {
103   if (indexingMaps.empty())
104     return nullptr;
105   unsigned numIterationDims = indexingMaps.front().getNumDims();
106   unsigned numSymbols = indexingMaps.front().getNumSymbols();
107 
108   // Compute the replacement for each dim expr.
109   SmallVector<AffineExpr, 4> dimReplacements;
110   dimReplacements.reserve(numIterationDims);
111   unsigned numKeptDims = 0;
112   for (unsigned dim : llvm::seq<unsigned>(0, numIterationDims)) {
113     if (unitDims.count(dim))
114       dimReplacements.push_back(getAffineConstantExpr(0, context));
115     else
116       dimReplacements.push_back(getAffineDimExpr(numKeptDims++, context));
117   }
118 
119   // Symbols remain the same.
120   SmallVector<AffineExpr, 4> symReplacements;
121   symReplacements.reserve(numSymbols);
122   for (unsigned symbol : llvm::seq<unsigned>(0, numSymbols))
123     symReplacements.push_back(getAffineSymbolExpr(symbol, context));
124 
125   SmallVector<AffineMap, 4> newIndexingMaps;
126   newIndexingMaps.reserve(indexingMaps.size());
127   for (AffineMap operandMap : indexingMaps) {
128     // Expected indexing maps to have no symbols.
129     if (operandMap.getNumSymbols())
130       return nullptr;
131     newIndexingMaps.push_back(simplifyAffineMap(
132         operandMap.replaceDimsAndSymbols(dimReplacements, symReplacements,
133                                          numIterationDims - unitDims.size(),
134                                          numSymbols)));
135   }
136 
137   // Check that the new index maps are invertible. If not, something went
138   // wrong, so abort.
139   if (!inversePermutation(concatAffineMaps(newIndexingMaps)))
140     return nullptr;
141   return ArrayAttr::get(context,
142                         llvm::to_vector<4>(llvm::map_range(
143                             newIndexingMaps, [](AffineMap map) -> Attribute {
144                               return AffineMapAttr::get(map);
145                             })));
146 }
147 
148 /// Update the index accesses of linalg operations having index semantics.
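///
/// For example (a sketch, assuming `unitDims = {1}`):
///   %0 = linalg.index 1 : index  // replaced by `arith.constant 0 : index`
///   %1 = linalg.index 2 : index  // replaced by `linalg.index 1 : index`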
149 static void replaceUnitDimIndexOps(GenericOp genericOp,
150                                    const DenseSet<unsigned> &unitDims,
151                                    PatternRewriter &rewriter) {
152   assert(genericOp->getNumRegions() == 1 &&
153          genericOp->getRegion(0).getBlocks().size() == 1 &&
154          "expected generic operation to have one block.");
155   Block &block = genericOp->getRegion(0).front();
156 
157   for (IndexOp indexOp : llvm::make_early_inc_range(block.getOps<IndexOp>())) {
158     OpBuilder::InsertionGuard guard(rewriter);
159     rewriter.setInsertionPoint(indexOp);
160     if (unitDims.count(indexOp.dim()) != 0) {
161       rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(indexOp, 0);
162     } else {
163       // Update the dimension of the index operation if needed.
164       unsigned droppedDims = llvm::count_if(
165           unitDims, [&](unsigned dim) { return dim < indexOp.dim(); });
166       if (droppedDims != 0)
167         rewriter.replaceOpWithNewOp<IndexOp>(indexOp,
168                                              indexOp.dim() - droppedDims);
169     }
170   }
171 }
172 
173 namespace {
174 /// Pattern to fold unit-trip count loops in GenericOps.
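///
/// For example (an illustrative sketch), a generic op whose operands are all
/// tensor<1x5xf32> accessed through the identity map
///   affine_map<(d0, d1) -> (d0, d1)>
/// with iterator_types = ["parallel", "parallel"] is updated in place to use
///   affine_map<(d0) -> (0, d0)>
/// and iterator_types = ["parallel"]. The operand types themselves are left
/// unchanged; ReplaceUnitExtents below drops the unit dimensions from them.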
175 struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
176   using OpRewritePattern<GenericOp>::OpRewritePattern;
177   LogicalResult matchAndRewrite(GenericOp genericOp,
178                                 PatternRewriter &rewriter) const override {
179     SmallVector<AffineMap, 4> indexingMaps = genericOp.getIndexingMaps();
180     if (indexingMaps.empty())
181       return failure();
182 
    // Check if any of the iteration dimensions have a unit trip count. A
    // dimension ends up with a unit trip count if it is used to index into a
    // unit-extent dimension of a tensor/memref.
186     AffineMap invertedMap = inversePermutation(concatAffineMaps(indexingMaps));
187     if (!invertedMap)
188       return failure();
189     SmallVector<int64_t> dims = genericOp.getStaticShape();
190 
191     DenseSet<unsigned> unitDims;
192     SmallVector<unsigned, 4> unitDimsReductionLoops;
193     ArrayAttr iteratorTypes = genericOp.iterator_types();
194     for (auto expr : enumerate(invertedMap.getResults())) {
195       if (AffineDimExpr dimExpr = expr.value().dyn_cast<AffineDimExpr>())
196         if (dims[dimExpr.getPosition()] == 1)
197           unitDims.insert(expr.index());
198     }
199 
200     if (unitDims.empty())
201       return failure();
202 
203     // Compute the modified indexing maps.
204     MLIRContext *context = rewriter.getContext();
205     ArrayAttr newIndexingMapAttr =
206         replaceUnitDims(unitDims, indexingMaps, context);
207     if (!newIndexingMapAttr)
208       return genericOp.emitError("unable to compute modified indexing_maps");
209 
210     // Compute the iterator types of the modified op by dropping the one-trip
211     // count loops.
212     SmallVector<Attribute, 4> newIteratorTypes;
213     for (auto attr : llvm::enumerate(iteratorTypes)) {
214       if (!unitDims.count(attr.index()))
215         newIteratorTypes.push_back(attr.value());
216     }
217 
218     rewriter.startRootUpdate(genericOp);
219     genericOp.indexing_mapsAttr(newIndexingMapAttr);
220     genericOp.iterator_typesAttr(ArrayAttr::get(context, newIteratorTypes));
221     replaceUnitDimIndexOps(genericOp, unitDims, rewriter);
222     rewriter.finalizeRootUpdate(genericOp);
223     return success();
224   }
225 };
226 
227 struct UnitExtentReplacementInfo {
228   Type type;
229   AffineMap indexMap;
230   ArrayAttr reassociation;
231 };
232 } // namespace
233 
/// Utility function for replacing operands/results of a linalg generic
/// operation that have unit-extent dimensions. These can be replaced with
/// an operand/result with the unit-extent dimension removed. This is only done
/// if the indexing map result used to access that dimension is an
/// AffineConstantExpr of value 0. Given the `type` of a result/operand of a
/// Linalg op and its `indexMap`, the utility function returns:
/// - the new type with the dimensions of size 1 removed,
/// - the modified index map that can be used to access the replaced
///   result/operand, and
/// - the reassociation that converts from the original tensor type to the
///   modified tensor type.
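///
/// For example (an illustrative sketch), for an operand of type
/// tensor<1x5x1xf32> accessed through affine_map<(d0, d1) -> (0, d1, 0)>,
/// the utility returns the type tensor<5xf32>, the index map
/// affine_map<(d0, d1) -> (d1)>, and the single reassociation map
/// affine_map<(d0, d1, d2) -> (d0, d1, d2)> that collapses all three original
/// dimensions into one.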
244 static llvm::Optional<UnitExtentReplacementInfo>
245 replaceUnitExtents(GenericOp genericOp, OpOperand *opOperand,
246                    MLIRContext *context) {
247   AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
248   ArrayRef<int64_t> shape = genericOp.getShape(opOperand);
249   ArrayRef<AffineExpr> exprs = indexingMap.getResults();
250   SmallVector<AffineExpr> reassociations;
251   SmallVector<Attribute> reassociationMaps;
252   SmallVector<AffineExpr> newIndexExprs;
253   SmallVector<int64_t> newShape;
254 
255   int64_t origRank = genericOp.getRank(opOperand);
256   AffineExpr zeroExpr = getAffineConstantExpr(0, context);
257   auto isUnitExtent = [&](int64_t dim) -> bool {
258     return shape[dim] == 1 && exprs[dim] == zeroExpr;
259   };
260 
  // Early return for memrefs with affine layout maps: they are always left
  // unchanged.
263   Type actualType = opOperand->get().getType();
264   if (auto memref = actualType.dyn_cast<MemRefType>()) {
265     if (!memref.getAffineMaps().empty())
266       return llvm::None;
267   }
268 
269   int64_t dim = 0;
270   // Fold dimensions that are unit-extent at the beginning of the tensor.
271   while (dim < origRank && isUnitExtent(dim))
272     reassociations.push_back(getAffineDimExpr(dim++, context));
273   while (dim < origRank) {
274     reassociations.push_back(getAffineDimExpr(dim, context));
275     newIndexExprs.push_back(exprs[dim]);
276     newShape.push_back(shape[dim]);
277     // Fold all following dimensions that are unit-extent.
278     while (dim + 1 < origRank && isUnitExtent(dim + 1)) {
279       ++dim;
280       reassociations.push_back(getAffineDimExpr(dim, context));
281     }
282     reassociationMaps.push_back(AffineMapAttr::get(AffineMap::get(
283         origRank, /*symbolCount = */ 0, reassociations, context)));
284     reassociations.clear();
285     ++dim;
286   }
287 
288   // Compute the tensor or scalar replacement type.
289   Type elementType = getElementTypeOrSelf(opOperand->get());
290   Type replacementType;
291   if (elementType == opOperand->get().getType()) {
292     replacementType = elementType;
293   } else if (actualType.isa<RankedTensorType>()) {
294     replacementType = RankedTensorType::get(newShape, elementType);
295   } else if (actualType.isa<MemRefType>()) {
296     replacementType = MemRefType::get(newShape, elementType);
297   }
298   assert(replacementType && "unsupported shaped type");
299   UnitExtentReplacementInfo info = {replacementType,
300                                     AffineMap::get(indexingMap.getNumDims(),
301                                                    indexingMap.getNumSymbols(),
302                                                    newIndexExprs, context),
303                                     ArrayAttr::get(context, reassociationMaps)};
304   return info;
305 }
306 
307 namespace {
308 
309 SmallVector<ReassociationExprs, 2>
310 convertAffineMapArrayToExprs(ArrayAttr affineMapArrayAttr) {
311   SmallVector<ReassociationExprs, 2> reassociationExprs;
312   for (auto attr : affineMapArrayAttr)
313     reassociationExprs.push_back(
314         llvm::to_vector<4>(attr.cast<AffineMapAttr>().getValue().getResults()));
315   return reassociationExprs;
316 }
317 
318 /// Pattern to replace tensor/buffer operands/results that are unit extents.
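///
/// For example (an illustrative sketch), after FoldUnitDimLoops has run, a
/// generic op with a tensor<1x5xf32> operand accessed through
/// affine_map<(d0) -> (0, d0)> is rewritten to take a tensor<5xf32> operand
/// accessed through affine_map<(d0) -> (d0)>. A linalg.tensor_collapse_shape
/// is inserted on the operand and, for results, a linalg.tensor_expand_shape
/// is inserted after the new op to recover the original type.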
319 struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> {
320   using OpRewritePattern<GenericOp>::OpRewritePattern;
321 
  // Return the original value if the type is unchanged, or expand it back to
  // the original type via a reshape. Return nullptr for unsupported types.
324   Value maybeExpand(Value result, Type origResultType,
325                     ArrayAttr reassociationMap, Location loc,
326                     PatternRewriter &rewriter) const {
327     if (origResultType == result.getType())
328       return result;
329     if (origResultType.isa<RankedTensorType>()) {
330       return rewriter.create<linalg::TensorExpandShapeOp>(
331           loc, origResultType, result,
332           convertAffineMapArrayToExprs(reassociationMap));
333     }
334     if (origResultType.isa<MemRefType>()) {
335       return rewriter.create<memref::ExpandShapeOp>(
336           loc, origResultType, result,
337           convertAffineMapArrayToExprs(reassociationMap));
338     }
339     return nullptr;
340   };
341 
  // Return the original value if the type is unchanged, or collapse it to the
  // new type via a reshape. Return nullptr for unsupported types.
344   Value maybeCollapse(Value operand, Type newInputOutputType,
345                       ArrayAttr reassociationMap, Location loc,
346                       PatternRewriter &rewriter) const {
347     auto operandType = operand.getType();
348     if (operandType == newInputOutputType)
349       return operand;
350     if (operandType.isa<MemRefType>()) {
351       return rewriter.create<memref::CollapseShapeOp>(
352           loc, newInputOutputType, operand,
353           convertAffineMapArrayToExprs(reassociationMap));
354     }
355     if (operandType.isa<RankedTensorType>()) {
356       return rewriter.create<linalg::TensorCollapseShapeOp>(
357           loc, newInputOutputType, operand,
358           convertAffineMapArrayToExprs(reassociationMap));
359     }
360     return nullptr;
361   };
362 
363   LogicalResult matchAndRewrite(GenericOp genericOp,
364                                 PatternRewriter &rewriter) const override {
365     // Skip the pattern if the op has any tensor with special encoding.
366     if (llvm::any_of(genericOp->getOperandTypes(), [](Type type) {
367           auto tensorType = type.dyn_cast<RankedTensorType>();
368           return tensorType && tensorType.getEncoding() != nullptr;
369         }))
370       return failure();
371     MLIRContext *context = rewriter.getContext();
372     Location loc = genericOp.getLoc();
373 
374     SmallVector<AffineMap> newIndexingMaps;
375     SmallVector<ArrayAttr> reassociationMaps;
376     SmallVector<Type> newInputOutputTypes;
377     bool doCanonicalization = false;
378     for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
379       auto replacementInfo = replaceUnitExtents(genericOp, opOperand, context);
380       if (replacementInfo) {
381         reassociationMaps.push_back(replacementInfo->reassociation);
382         newIndexingMaps.push_back(replacementInfo->indexMap);
383         newInputOutputTypes.push_back(replacementInfo->type);
384         doCanonicalization |=
385             replacementInfo->type != opOperand->get().getType();
386       } else {
        // If replaceUnitExtents cannot handle this case, keep the same type
        // and indexing map, and create an identity reassociation (one map per
        // dimension).
390         newInputOutputTypes.push_back(opOperand->get().getType());
391         newIndexingMaps.push_back(genericOp.getTiedIndexingMap(opOperand));
392         int64_t origRank = genericOp.getRank(opOperand);
393         auto maps = llvm::to_vector<8>(llvm::map_range(
394             llvm::seq<int64_t>(0, origRank), [&](int64_t dim) -> Attribute {
395               return AffineMapAttr::get(
396                   AffineMap::get(origRank, /*symbolCount = */ 0,
397                                  getAffineDimExpr(dim, context), context));
398             }));
399         reassociationMaps.push_back(ArrayAttr::get(context, maps));
400       }
401     }
402 
    // Abort if nothing would change, or if the indexing maps of the resulting
    // operation are not invertible (i.e. not legal).
405     if (!doCanonicalization ||
406         !inversePermutation(concatAffineMaps(newIndexingMaps)))
407       return failure();
408 
    // If any operand type changed, insert a reshape to convert from the
    // original type to the new type.
411     // TODO: get rid of flattenedIdx which assumes operand order and contiguity.
412     unsigned flattenedIdx = 0;
413     auto insertReshapes = [&](ValueRange values) {
414       SmallVector<Value, 4> res;
415       res.reserve(values.size());
416       for (auto operand : values) {
417         auto reshapedValue =
418             maybeCollapse(operand, newInputOutputTypes[flattenedIdx],
419                           reassociationMaps[flattenedIdx], loc, rewriter);
420         assert(reshapedValue &&
421                "expected ranked MemRef or Tensor operand type");
422         res.push_back(reshapedValue);
423         ++flattenedIdx;
424       }
425       return res;
426     };
427 
428     SmallVector<Value, 4> newInputs = insertReshapes(genericOp.inputs());
429     SmallVector<Value, 4> newOutputs = insertReshapes(genericOp.outputs());
430 
    // Create the replacement op with the new (possibly rank-reduced) operand
    // and result types.
433     SmallVector<Type, 4> resultTypes;
434     resultTypes.reserve(genericOp.getNumResults());
435     for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
436       resultTypes.push_back(newInputOutputTypes[i + genericOp.getNumInputs()]);
437     GenericOp replacementOp = rewriter.create<GenericOp>(
438         loc, resultTypes, newInputs, newOutputs, newIndexingMaps,
439         llvm::to_vector<4>(
440             genericOp.iterator_types().template getAsValueRange<StringAttr>()));
441     rewriter.inlineRegionBefore(genericOp.region(), replacementOp.region(),
442                                 replacementOp.region().begin());
443 
    // If any result tensor has a modified shape, insert a reshape to recover
    // the original shape.
446     SmallVector<Value, 4> resultReplacements;
447     for (auto result : llvm::enumerate(replacementOp.getResults())) {
448       unsigned index = result.index() + replacementOp.getNumInputs();
449       auto origResultType = genericOp.getResult(result.index()).getType();
450 
451       auto newResult = maybeExpand(result.value(), origResultType,
452                                    reassociationMaps[index], loc, rewriter);
453       assert(newResult &&
454              "unexpected output type other than ranked MemRef or Tensor");
455       resultReplacements.push_back(newResult);
456     }
457     rewriter.replaceOp(genericOp, resultReplacements);
458     return success();
459   }
460 };
461 } // namespace
462 
/// Get the reassociation maps to fold the result of an extract_slice (or the
/// source of an insert_slice) operation with the given offsets and sizes into
/// its rank-reduced version. This is only done for the cases where the size is
/// 1 and the offset is 0. Strictly speaking, a zero offset is not required in
/// general, but non-zero offsets are not handled by the SPIR-V backend at this
/// point (and potentially cannot be handled).
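///
/// For example (an illustrative sketch), sizes [1, %d, 1, 4] yield the
/// reassociation [[0, 1], [2, 3]] (each leading unit dimension is folded into
/// the group of the next non-unit dimension), and sizes [1, 1] yield the
/// empty reassociation, i.e. a collapse to a rank-0 tensor.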
469 static Optional<SmallVector<ReassociationIndices>>
470 getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes) {
471   SmallVector<ReassociationIndices> reassociation;
472   ReassociationIndices curr;
473   for (auto it : llvm::enumerate(mixedSizes)) {
474     auto dim = it.index();
475     auto size = it.value();
476     curr.push_back(dim);
477     auto attr = size.dyn_cast<Attribute>();
478     if (attr && attr.cast<IntegerAttr>().getInt() == 1)
479       continue;
480     reassociation.emplace_back(ReassociationIndices{});
481     std::swap(reassociation.back(), curr);
482   }
  // If there are trailing unit dimensions, fold them into the last
  // reassociation group. If no group has been created so far (all dimensions
  // are unit), leave the reassociation empty; this folds everything to a
  // rank-0 tensor.
486   if (!curr.empty() && !reassociation.empty())
487     reassociation.back().append(curr.begin(), curr.end());
488   return reassociation;
489 }
490 
491 namespace {
492 /// Convert `extract_slice` operations to rank-reduced versions.
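///
/// For example (an illustrative sketch):
///
/// ```mlir
/// %0 = tensor.extract_slice %t[0, %o, 0] [1, 4, 1] [1, 1, 1]
///      : tensor<1x8x1xf32> to tensor<1x4x1xf32>
/// ```
///
/// is rewritten into a rank-reduced extract_slice producing tensor<4xf32>,
/// followed by a linalg.tensor_expand_shape back to tensor<1x4x1xf32>.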
493 struct UseRankReducedExtractSliceOp
494     : public OpRewritePattern<tensor::ExtractSliceOp> {
495   using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
496 
497   LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
498                                 PatternRewriter &rewriter) const override {
499     RankedTensorType resultType = sliceOp.getType();
500     SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
501     SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
502     SmallVector<OpFoldResult> strides = sliceOp.getMixedStrides();
503     auto reassociation = getReassociationMapForFoldingUnitDims(sizes);
504     if (!reassociation ||
505         reassociation->size() == static_cast<size_t>(resultType.getRank()))
506       return failure();
507     auto rankReducedType = tensor::ExtractSliceOp::inferRankReducedResultType(
508                                reassociation->size(), sliceOp.getSourceType(),
509                                offsets, sizes, strides)
510                                .cast<RankedTensorType>();
511 
512     Location loc = sliceOp.getLoc();
513     Value newSlice = rewriter.create<tensor::ExtractSliceOp>(
514         loc, rankReducedType, sliceOp.source(), offsets, sizes, strides);
515     rewriter.replaceOpWithNewOp<TensorExpandShapeOp>(sliceOp, resultType,
516                                                      newSlice, *reassociation);
517     return success();
518   }
519 };
520 
521 /// Convert `insert_slice` operations to rank-reduced versions.
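///
/// For example (an illustrative sketch):
///
/// ```mlir
/// %0 = tensor.insert_slice %src into %dst[0, %o, 0] [1, 4, 1] [1, 1, 1]
///      : tensor<1x4x1xf32> into tensor<1x8x1xf32>
/// ```
///
/// is rewritten so that %src is first collapsed to tensor<4xf32> with a
/// linalg.tensor_collapse_shape and then inserted with the same offsets,
/// sizes, and strides.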
522 struct UseRankReducedInsertSliceOp
523     : public OpRewritePattern<tensor::InsertSliceOp> {
524   using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
525 
526   LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
527                                 PatternRewriter &rewriter) const override {
528     RankedTensorType sourceType = insertOp.getSourceType();
529     SmallVector<OpFoldResult> offsets = insertOp.getMixedOffsets();
530     SmallVector<OpFoldResult> sizes = insertOp.getMixedSizes();
531     SmallVector<OpFoldResult> strides = insertOp.getMixedStrides();
532     auto reassociation = getReassociationMapForFoldingUnitDims(sizes);
533     if (!reassociation ||
534         reassociation->size() == static_cast<size_t>(sourceType.getRank()))
535       return failure();
536     Location loc = insertOp.getLoc();
537     auto reshapedSource = rewriter.create<TensorCollapseShapeOp>(
538         loc, insertOp.source(), *reassociation);
539     rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
540         insertOp, reshapedSource, insertOp.dest(), insertOp.getMixedOffsets(),
541         insertOp.getMixedSizes(), insertOp.getMixedStrides());
542     return success();
543   }
544 };
545 } // namespace
546 
547 /// Patterns that are used to canonicalize the use of unit-extent dims for
548 /// broadcasting.
549 void mlir::linalg::populateFoldUnitExtentDimsPatterns(
550     RewritePatternSet &patterns) {
551   auto *context = patterns.getContext();
552   patterns.add<FoldUnitDimLoops, ReplaceUnitExtents,
553                UseRankReducedExtractSliceOp, UseRankReducedInsertSliceOp>(
554       context);
555   TensorCollapseShapeOp::getCanonicalizationPatterns(patterns, context);
556   TensorExpandShapeOp::getCanonicalizationPatterns(patterns, context);
557 }
558 
559 namespace {
560 /// Pass that removes unit-extent dims within generic ops.
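///
/// A typical invocation (a sketch; assumes the pass is registered in Passes.td
/// under the `linalg-fold-unit-extent-dims` argument):
///   mlir-opt -linalg-fold-unit-extent-dims input.mlir
/// The `foldOneTripLoopsOnly` option restricts the pass to FoldUnitDimLoops.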
561 struct LinalgFoldUnitExtentDimsPass
562     : public LinalgFoldUnitExtentDimsBase<LinalgFoldUnitExtentDimsPass> {
563   void runOnFunction() override {
564     FuncOp funcOp = getFunction();
565     MLIRContext *context = funcOp.getContext();
566     RewritePatternSet patterns(context);
567     if (foldOneTripLoopsOnly)
568       patterns.add<FoldUnitDimLoops>(context);
569     else
570       populateFoldUnitExtentDimsPatterns(patterns);
571     (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
572   }
573 };
574 } // namespace
575 
576 std::unique_ptr<OperationPass<FuncOp>>
577 mlir::createLinalgFoldUnitExtentDimsPass() {
578   return std::make_unique<LinalgFoldUnitExtentDimsPass>();
579 }
580