//===- DropUnitDims.cpp - Pass to drop use of unit-extent for broadcasting ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns/pass to remove the use of unit-extent
// dimensions to specify broadcasting, in favor of a more canonical
// representation of the computation.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "linalg-drop-unit-dims"

using namespace mlir;
using namespace mlir::linalg;

/// Implements a pass that canonicalizes the uses of unit-extent dimensions for
/// broadcasting. For example,
///
/// ```mlir
/// #accesses = [
///   affine_map<(d0, d1) -> (0, d1)>,
///   affine_map<(d0, d1) -> (d0, 0)>,
///   affine_map<(d0, d1) -> (d0, d1)>
/// ]
///
/// #trait = {
///   args_in = 2,
///   args_out = 1,
///   indexing_maps = #accesses,
///   iterator_types = ["parallel", "parallel"],
///   library_call = "some_external_fn"
/// }
///
/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
/// tensor<5x5xf32>
/// {
///   %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>] :
///        tensor<5xf32> into tensor<1x5xf32>
///   %1 = linalg.tensor_reshape %arg1 [affine_map<(d0, d1) -> (d0, d1)>] :
///        tensor<5xf32> into tensor<5x1xf32>
///   %2 = linalg.generic #trait %0, %1 {
///        ^bb0(%arg2: f32, %arg3: f32):
///          %3 = arith.addf %arg2, %arg3 : f32
///          linalg.yield %3 : f32
///        } : tensor<1x5xf32>, tensor<5x1xf32> -> tensor<5x5xf32>
///   return %2 : tensor<5x5xf32>
/// }
/// ```
///
/// would canonicalize to
///
/// ```mlir
/// #accesses = [
///   affine_map<(d0, d1) -> (d1)>,
///   affine_map<(d0, d1) -> (d0)>,
///   affine_map<(d0, d1) -> (d0, d1)>
/// ]
///
/// #trait = {
///   args_in = 2,
///   args_out = 1,
///   indexing_maps = #accesses,
///   iterator_types = ["parallel", "parallel"],
///   library_call = "some_external_fn"
/// }
///
/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
/// tensor<5x5xf32>
/// {
///   %0 = linalg.generic #trait %arg0, %arg1 {
///        ^bb0(%arg2: f32, %arg3: f32):
///          %3 = arith.addf %arg2, %arg3 : f32
///          linalg.yield %3 : f32
///        } : tensor<5xf32>, tensor<5xf32> -> tensor<5x5xf32>
///   return %0 : tensor<5x5xf32>
/// }
/// ```

/// Given the dims of the iteration space of a structured op that are known to
/// have a single trip count (`unitDims`), return the indexing maps to use in
/// the canonicalized op, with these dims removed, given the original
/// `indexingMaps`.
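///
/// For example (an illustrative case, not taken from the code below): with
/// three iteration dims and `unitDims = {1}`, the unit dim is replaced by the
/// constant 0 and the remaining dims are renumbered:
///   affine_map<(d0, d1, d2) -> (d1, d2)> becomes
///   affine_map<(d0, d1) -> (0, d1)>, and
///   affine_map<(d0, d1, d2) -> (d0, d2)> becomes
///   affine_map<(d0, d1) -> (d0, d1)>.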
static ArrayAttr replaceUnitDims(DenseSet<unsigned> &unitDims,
                                 ArrayRef<AffineMap> indexingMaps,
                                 MLIRContext *context) {
  if (indexingMaps.empty())
    return nullptr;
  unsigned numIterationDims = indexingMaps.front().getNumDims();
  unsigned numSymbols = indexingMaps.front().getNumSymbols();

  // Compute the replacement for each dim expr.
  SmallVector<AffineExpr, 4> dimReplacements;
  dimReplacements.reserve(numIterationDims);
  unsigned numKeptDims = 0;
  for (unsigned dim : llvm::seq<unsigned>(0, numIterationDims)) {
    if (unitDims.count(dim))
      dimReplacements.push_back(getAffineConstantExpr(0, context));
    else
      dimReplacements.push_back(getAffineDimExpr(numKeptDims++, context));
  }

  // Symbols remain the same.
  SmallVector<AffineExpr, 4> symReplacements;
  symReplacements.reserve(numSymbols);
  for (unsigned symbol : llvm::seq<unsigned>(0, numSymbols))
    symReplacements.push_back(getAffineSymbolExpr(symbol, context));

  SmallVector<AffineMap, 4> newIndexingMaps;
  newIndexingMaps.reserve(indexingMaps.size());
  for (AffineMap operandMap : indexingMaps) {
    // Expected indexing maps to have no symbols.
    if (operandMap.getNumSymbols())
      return nullptr;
    newIndexingMaps.push_back(simplifyAffineMap(
        operandMap.replaceDimsAndSymbols(dimReplacements, symReplacements,
                                         numIterationDims - unitDims.size(),
                                         numSymbols)));
  }

  // Check that the new index maps are invertible. If not, something went
  // wrong, so abort.
  if (!inversePermutation(concatAffineMaps(newIndexingMaps)))
    return nullptr;
  return ArrayAttr::get(context,
                        llvm::to_vector<4>(llvm::map_range(
                            newIndexingMaps, [](AffineMap map) -> Attribute {
                              return AffineMapAttr::get(map);
                            })));
}

/// Update the index accesses of linalg operations having index semantics.
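///
/// For example (illustrative): with `unitDims = {0}`, a `linalg.index 0` in
/// the body is replaced by `arith.constant 0 : index`, while a
/// `linalg.index 2` is renumbered to `linalg.index 1` because one loop below
/// it was dropped.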
static void replaceUnitDimIndexOps(GenericOp genericOp,
                                   const DenseSet<unsigned> &unitDims,
                                   PatternRewriter &rewriter) {
  for (IndexOp indexOp :
       llvm::make_early_inc_range(genericOp.getBody()->getOps<IndexOp>())) {
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(indexOp);
    if (unitDims.count(indexOp.dim()) != 0) {
      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(indexOp, 0);
    } else {
      // Update the dimension of the index operation if needed.
      unsigned droppedDims = llvm::count_if(
          unitDims, [&](unsigned dim) { return dim < indexOp.dim(); });
      if (droppedDims != 0)
        rewriter.replaceOpWithNewOp<IndexOp>(indexOp,
                                             indexOp.dim() - droppedDims);
    }
  }
}

namespace {
/// Pattern to fold unit-trip count loops in GenericOps.
struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap, 4> indexingMaps = genericOp.getIndexingMaps();
    if (indexingMaps.empty())
      return failure();

    // Check if any of the iteration dimensions are unit-trip count. They will
    // end up being unit-trip count if they are used to index into a unit-dim
    // tensor/memref.
    AffineMap invertedMap = inversePermutation(concatAffineMaps(indexingMaps));
    if (!invertedMap)
      return failure();
    SmallVector<int64_t> dims = genericOp.getStaticShape();

    DenseSet<unsigned> unitDims;
    ArrayAttr iteratorTypes = genericOp.iterator_types();
    for (auto expr : enumerate(invertedMap.getResults())) {
      if (AffineDimExpr dimExpr = expr.value().dyn_cast<AffineDimExpr>())
        if (dims[dimExpr.getPosition()] == 1)
          unitDims.insert(expr.index());
    }

    if (unitDims.empty())
      return failure();

    // Compute the modified indexing maps.
    MLIRContext *context = rewriter.getContext();
    ArrayAttr newIndexingMapAttr =
        replaceUnitDims(unitDims, indexingMaps, context);
    if (!newIndexingMapAttr)
      return genericOp.emitError("unable to compute modified indexing_maps");

    // Compute the iterator types of the modified op by dropping the one-trip
    // count loops.
    SmallVector<Attribute, 4> newIteratorTypes;
    for (auto attr : llvm::enumerate(iteratorTypes)) {
      if (!unitDims.count(attr.index()))
        newIteratorTypes.push_back(attr.value());
    }

    rewriter.startRootUpdate(genericOp);
    genericOp.indexing_mapsAttr(newIndexingMapAttr);
    genericOp.iterator_typesAttr(ArrayAttr::get(context, newIteratorTypes));
    replaceUnitDimIndexOps(genericOp, unitDims, rewriter);
    rewriter.finalizeRootUpdate(genericOp);
    return success();
  }
};

struct UnitExtentReplacementInfo {
  Type type;
  AffineMap indexMap;
  ArrayAttr reassociation;
};
} // namespace

/// Utility function for replacing operands/results of a linalg generic
/// operation that have unit-extent dimensions. These can be replaced with
/// an operand/result with the unit-extent dimension removed. This is only done
/// if the indexing map used to access that dimension has an
/// AffineConstantExpr of value 0. Given the `type` of a result/operand of a
/// Linalg op and its `indexMap`, the utility function returns:
/// - the new type with the dimensions of size 1 removed,
/// - the modified index map that can be used to access the replaced
///   result/operand, and
/// - the reassociation that converts from the original tensor type to the
///   modified tensor type.
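///
/// For example (illustrative): for an operand of type tensor<1x5x1xf32> with
/// indexing map affine_map<(d0, d1) -> (0, d1, 0)>, this returns the type
/// tensor<5xf32>, the index map affine_map<(d0, d1) -> (d1)>, and a single
/// reassociation group covering all three original dimensions.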
static llvm::Optional<UnitExtentReplacementInfo>
replaceUnitExtents(GenericOp genericOp, OpOperand *opOperand,
                   MLIRContext *context) {
  AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
  ArrayRef<int64_t> shape = genericOp.getShape(opOperand);
  ArrayRef<AffineExpr> exprs = indexingMap.getResults();
  SmallVector<AffineExpr> reassociations;
  SmallVector<Attribute> reassociationMaps;
  SmallVector<AffineExpr> newIndexExprs;
  SmallVector<int64_t> newShape;

  int64_t origRank = genericOp.getRank(opOperand);
  AffineExpr zeroExpr = getAffineConstantExpr(0, context);
  auto isUnitExtent = [&](int64_t dim) -> bool {
    return shape[dim] == 1 && exprs[dim] == zeroExpr;
  };

  // Early return for memrefs with non-identity layout maps: they are always
  // left unchanged.
  Type actualType = opOperand->get().getType();
  if (auto memref = actualType.dyn_cast<MemRefType>()) {
    if (!memref.getLayout().isIdentity())
      return llvm::None;
  }

  int64_t dim = 0;
  // Fold dimensions that are unit-extent at the beginning of the tensor.
  while (dim < origRank && isUnitExtent(dim))
    reassociations.push_back(getAffineDimExpr(dim++, context));
  while (dim < origRank) {
    reassociations.push_back(getAffineDimExpr(dim, context));
    newIndexExprs.push_back(exprs[dim]);
    newShape.push_back(shape[dim]);
    // Fold all following dimensions that are unit-extent.
    while (dim + 1 < origRank && isUnitExtent(dim + 1)) {
      ++dim;
      reassociations.push_back(getAffineDimExpr(dim, context));
    }
    reassociationMaps.push_back(AffineMapAttr::get(AffineMap::get(
        origRank, /*symbolCount = */ 0, reassociations, context)));
    reassociations.clear();
    ++dim;
  }

  // Compute the tensor or scalar replacement type.
  Type elementType = getElementTypeOrSelf(opOperand->get());
  Type replacementType;
  if (elementType == opOperand->get().getType()) {
    replacementType = elementType;
  } else if (actualType.isa<RankedTensorType>()) {
    replacementType = RankedTensorType::get(newShape, elementType);
  } else if (actualType.isa<MemRefType>()) {
    replacementType = MemRefType::get(newShape, elementType);
  }
  assert(replacementType && "unsupported shaped type");
  UnitExtentReplacementInfo info = {replacementType,
                                    AffineMap::get(indexingMap.getNumDims(),
                                                   indexingMap.getNumSymbols(),
                                                   newIndexExprs, context),
                                    ArrayAttr::get(context, reassociationMaps)};
  return info;
}

namespace {

SmallVector<ReassociationExprs, 2>
convertAffineMapArrayToExprs(ArrayAttr affineMapArrayAttr) {
  SmallVector<ReassociationExprs, 2> reassociationExprs;
  for (auto attr : affineMapArrayAttr)
    reassociationExprs.push_back(
        llvm::to_vector<4>(attr.cast<AffineMapAttr>().getValue().getResults()));
  return reassociationExprs;
}

/// Pattern to replace tensor/buffer operands/results that have unit-extent
/// dimensions with their rank-reduced counterparts.
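///
/// For example (an illustrative sketch of the rewrite): a generic op with a
/// tensor<1x5xf32> operand accessed through affine_map<(d0, d1) -> (0, d1)>
/// is rewritten to operate on a tensor<5xf32> produced by
/// linalg.tensor_collapse_shape, using affine_map<(d0, d1) -> (d1)>; results
/// are expanded back to their original shapes with linalg.tensor_expand_shape.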
struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  // Return the original value if the type is unchanged, or reshape it. Return
  // nullptr if this is an unsupported type.
  Value maybeExpand(Value result, Type origResultType,
                    ArrayAttr reassociationMap, Location loc,
                    PatternRewriter &rewriter) const {
    if (origResultType == result.getType())
      return result;
    if (origResultType.isa<RankedTensorType>()) {
      return rewriter.create<linalg::TensorExpandShapeOp>(
          loc, origResultType, result,
          convertAffineMapArrayToExprs(reassociationMap));
    }
    if (origResultType.isa<MemRefType>()) {
      return rewriter.create<memref::ExpandShapeOp>(
          loc, origResultType, result,
          convertAffineMapArrayToExprs(reassociationMap));
    }
    return nullptr;
  };

  // Return the original value if the type is unchanged, or reshape it. Return
  // nullptr if this is an unsupported type.
  Value maybeCollapse(Value operand, Type newInputOutputType,
                      ArrayAttr reassociationMap, Location loc,
                      PatternRewriter &rewriter) const {
    auto operandType = operand.getType();
    if (operandType == newInputOutputType)
      return operand;
    if (operandType.isa<MemRefType>()) {
      return rewriter.create<memref::CollapseShapeOp>(
          loc, newInputOutputType, operand,
          convertAffineMapArrayToExprs(reassociationMap));
    }
    if (operandType.isa<RankedTensorType>()) {
      return rewriter.create<linalg::TensorCollapseShapeOp>(
          loc, newInputOutputType, operand,
          convertAffineMapArrayToExprs(reassociationMap));
    }
    return nullptr;
  };

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Skip the pattern if the op has any tensor with special encoding.
    if (llvm::any_of(genericOp->getOperandTypes(), [](Type type) {
          auto tensorType = type.dyn_cast<RankedTensorType>();
          return tensorType && tensorType.getEncoding() != nullptr;
        }))
      return failure();
    MLIRContext *context = rewriter.getContext();
    Location loc = genericOp.getLoc();

    SmallVector<AffineMap> newIndexingMaps;
    SmallVector<ArrayAttr> reassociationMaps;
    SmallVector<Type> newInputOutputTypes;
    bool doCanonicalization = false;
    for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
      auto replacementInfo = replaceUnitExtents(genericOp, opOperand, context);
      if (replacementInfo) {
        reassociationMaps.push_back(replacementInfo->reassociation);
        newIndexingMaps.push_back(replacementInfo->indexMap);
        newInputOutputTypes.push_back(replacementInfo->type);
        doCanonicalization |=
            replacementInfo->type != opOperand->get().getType();
      } else {
        // If replaceUnitExtents cannot handle this case, keep the same type
        // and indexing map, and create a reassociation that maps every
        // dimension to itself.
        newInputOutputTypes.push_back(opOperand->get().getType());
        newIndexingMaps.push_back(genericOp.getTiedIndexingMap(opOperand));
        int64_t origRank = genericOp.getRank(opOperand);
        auto maps = llvm::to_vector<8>(llvm::map_range(
            llvm::seq<int64_t>(0, origRank), [&](int64_t dim) -> Attribute {
              return AffineMapAttr::get(
                  AffineMap::get(origRank, /*symbolCount = */ 0,
                                 getAffineDimExpr(dim, context), context));
            }));
        reassociationMaps.push_back(ArrayAttr::get(context, maps));
      }
    }

    // If the indexing maps of the result operation are not invertible (i.e. not
    // legal), abort.
    if (!doCanonicalization ||
        !inversePermutation(concatAffineMaps(newIndexingMaps)))
      return failure();

    // If any operand type changes, insert a reshape to convert from the
    // original type to the new type.
    // TODO: get rid of flattenedIdx which assumes operand order and contiguity.
    unsigned flattenedIdx = 0;
    auto insertReshapes = [&](ValueRange values) {
      SmallVector<Value, 4> res;
      res.reserve(values.size());
      for (auto operand : values) {
        auto reshapedValue =
            maybeCollapse(operand, newInputOutputTypes[flattenedIdx],
                          reassociationMaps[flattenedIdx], loc, rewriter);
        assert(reshapedValue &&
               "expected ranked MemRef or Tensor operand type");
        res.push_back(reshapedValue);
        ++flattenedIdx;
      }
      return res;
    };

    SmallVector<Value, 4> newInputs = insertReshapes(genericOp.inputs());
    SmallVector<Value, 4> newOutputs = insertReshapes(genericOp.outputs());

    // If any result type changes, insert a reshape to convert from the original
    // type to the new type.
    SmallVector<Type, 4> resultTypes;
    resultTypes.reserve(genericOp.getNumResults());
    for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
      resultTypes.push_back(newInputOutputTypes[i + genericOp.getNumInputs()]);
    GenericOp replacementOp = rewriter.create<GenericOp>(
        loc, resultTypes, newInputs, newOutputs, newIndexingMaps,
        llvm::to_vector<4>(
            genericOp.iterator_types().template getAsValueRange<StringAttr>()));
    rewriter.inlineRegionBefore(genericOp.region(), replacementOp.region(),
                                replacementOp.region().begin());

    // If any result tensor has a modified shape, then add a reshape to recover
    // the original shape.
    SmallVector<Value, 4> resultReplacements;
    for (auto result : llvm::enumerate(replacementOp.getResults())) {
      unsigned index = result.index() + replacementOp.getNumInputs();
      auto origResultType = genericOp.getResult(result.index()).getType();

      auto newResult = maybeExpand(result.value(), origResultType,
                                   reassociationMaps[index], loc, rewriter);
      assert(newResult &&
             "unexpected output type other than ranked MemRef or Tensor");
      resultReplacements.push_back(newResult);
    }
    rewriter.replaceOp(genericOp, resultReplacements);
    return success();
  }
};
} // namespace

/// Get the reassociation maps to fold the result of an extract_slice (or the
/// source of an insert_slice) operation with the given offsets and sizes to
/// its rank-reduced version. This is only done for the cases where the size is
/// 1 and the offset is 0. Strictly speaking, an offset of 0 is not required in
/// general, but non-zero offsets are not handled by the SPIR-V backend at this
/// point (and potentially cannot be handled).
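///
/// For example (illustrative): for mixed sizes [1, %d, 1, 4] the returned
/// reassociation is [[0, 1], [2, 3]], i.e. each unit dimension is folded into
/// the following non-unit dimension (trailing unit dims are folded into the
/// preceding group). If all sizes are 1, the reassociation is empty and the
/// value folds to a rank-0 tensor.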
static Optional<SmallVector<ReassociationIndices>>
getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes) {
  SmallVector<ReassociationIndices> reassociation;
  ReassociationIndices curr;
  for (auto it : llvm::enumerate(mixedSizes)) {
    auto dim = it.index();
    auto size = it.value();
    curr.push_back(dim);
    auto attr = size.dyn_cast<Attribute>();
    if (attr && attr.cast<IntegerAttr>().getInt() == 1)
      continue;
    reassociation.emplace_back(ReassociationIndices{});
    std::swap(reassociation.back(), curr);
  }
  // When the reassociations are not empty, fold the remaining unit dimensions
  // into the last group. If the reassociations so far are empty, leave them
  // empty; this will fold everything to a rank-0 tensor.
  if (!curr.empty() && !reassociation.empty())
    reassociation.back().append(curr.begin(), curr.end());
  return reassociation;
}

namespace {
/// Convert `extract_slice` operations to rank-reduced versions.
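///
/// For example (an illustrative sketch): schematically,
///   %0 = tensor.extract_slice %t[0, %o, 0] [1, 4, 1] [1, 1, 1]
///        : tensor<1x?x1xf32> to tensor<1x4x1xf32>
/// is rewritten to extract a rank-reduced tensor<4xf32> slice, followed by a
/// linalg.tensor_expand_shape back to tensor<1x4x1xf32>.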
struct UseRankReducedExtractSliceOp
    : public OpRewritePattern<tensor::ExtractSliceOp> {
  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType resultType = sliceOp.getType();
    SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
    SmallVector<OpFoldResult> strides = sliceOp.getMixedStrides();
    auto reassociation = getReassociationMapForFoldingUnitDims(sizes);
    if (!reassociation ||
        reassociation->size() == static_cast<size_t>(resultType.getRank()))
      return failure();
    auto rankReducedType = tensor::ExtractSliceOp::inferRankReducedResultType(
                               reassociation->size(), sliceOp.getSourceType(),
                               offsets, sizes, strides)
                               .cast<RankedTensorType>();

    Location loc = sliceOp.getLoc();
    Value newSlice = rewriter.create<tensor::ExtractSliceOp>(
        loc, rankReducedType, sliceOp.source(), offsets, sizes, strides);
    rewriter.replaceOpWithNewOp<TensorExpandShapeOp>(sliceOp, resultType,
                                                     newSlice, *reassociation);
    return success();
  }
};

/// Convert `insert_slice` operations to rank-reduced versions.
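///
/// For example (an illustrative sketch): schematically,
///   tensor.insert_slice %src into %dst[0, %o, 0] [1, 4, 1] [1, 1, 1]
///        : tensor<1x4x1xf32> into tensor<1x?x1xf32>
/// is rewritten to first collapse %src to tensor<4xf32> with
/// linalg.tensor_collapse_shape and then insert the rank-reduced source with
/// the same offsets, sizes, and strides.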
struct UseRankReducedInsertSliceOp
    : public OpRewritePattern<tensor::InsertSliceOp> {
  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType sourceType = insertOp.getSourceType();
    SmallVector<OpFoldResult> offsets = insertOp.getMixedOffsets();
    SmallVector<OpFoldResult> sizes = insertOp.getMixedSizes();
    SmallVector<OpFoldResult> strides = insertOp.getMixedStrides();
    auto reassociation = getReassociationMapForFoldingUnitDims(sizes);
    if (!reassociation ||
        reassociation->size() == static_cast<size_t>(sourceType.getRank()))
      return failure();
    Location loc = insertOp.getLoc();
    auto reshapedSource = rewriter.create<TensorCollapseShapeOp>(
        loc, insertOp.source(), *reassociation);
    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
        insertOp, reshapedSource, insertOp.dest(), insertOp.getMixedOffsets(),
        insertOp.getMixedSizes(), insertOp.getMixedStrides());
    return success();
  }
};
} // namespace

/// Patterns that are used to canonicalize the use of unit-extent dims for
/// broadcasting.
void mlir::linalg::populateFoldUnitExtentDimsPatterns(
    RewritePatternSet &patterns) {
  auto *context = patterns.getContext();
  patterns.add<FoldUnitDimLoops, ReplaceUnitExtents,
               UseRankReducedExtractSliceOp, UseRankReducedInsertSliceOp>(
      context);
  TensorCollapseShapeOp::getCanonicalizationPatterns(patterns, context);
  TensorExpandShapeOp::getCanonicalizationPatterns(patterns, context);
}

namespace {
/// Pass that removes unit-extent dims within generic ops.
struct LinalgFoldUnitExtentDimsPass
    : public LinalgFoldUnitExtentDimsBase<LinalgFoldUnitExtentDimsPass> {
  void runOnFunction() override {
    FuncOp funcOp = getFunction();
    MLIRContext *context = funcOp.getContext();
    RewritePatternSet patterns(context);
    if (foldOneTripLoopsOnly)
      patterns.add<FoldUnitDimLoops>(context);
    else
      populateFoldUnitExtentDimsPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
  }
};
} // namespace

std::unique_ptr<OperationPass<FuncOp>>
mlir::createLinalgFoldUnitExtentDimsPass() {
  return std::make_unique<LinalgFoldUnitExtentDimsPass>();
}