//===- DropUnitDims.cpp - Pass to drop use of unit-extent for broadcasting ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns and a pass that remove the use of unit-extent
// dimensions for specifying broadcasting, in favor of a more canonical
// representation of the computation.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "linalg-drop-unit-dims"

using namespace mlir;
using namespace mlir::linalg;

/// Implements a pass that canonicalizes the uses of unit-extent dimensions for
/// broadcasting. For example,
///
/// ```mlir
/// #accesses = [
///   affine_map<(d0, d1) -> (0, d1)>,
///   affine_map<(d0, d1) -> (d0, 0)>,
///   affine_map<(d0, d1) -> (d0, d1)>
/// ]
///
/// #trait = {
///   args_in = 2,
///   args_out = 1,
///   indexing_maps = #accesses,
///   iterator_types = ["parallel", "parallel"],
///   library_call = "some_external_fn"
/// }
///
/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
/// tensor<5x5xf32>
/// {
///   %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>] :
///        tensor<5xf32> into tensor<1x5xf32>
///   %1 = linalg.tensor_reshape %arg1 [affine_map<(d0, d1) -> (d0, d1)>] :
///        tensor<5xf32> into tensor<5x1xf32>
///   %2 = linalg.generic #trait %0, %1 {
///        ^bb0(%arg2: f32, %arg3: f32):
///          %3 = addf %arg2, %arg3 : f32
///          linalg.yield %3 : f32
///        } : tensor<1x5xf32>, tensor<5x1xf32> -> tensor<5x5xf32>
///   return %2 : tensor<5x5xf32>
/// }
/// ```
///
/// would canonicalize to
///
/// ```mlir
/// #accesses = [
///   affine_map<(d0, d1) -> (d1)>,
///   affine_map<(d0, d1) -> (d0)>,
///   affine_map<(d0, d1) -> (d0, d1)>
/// ]
///
/// #trait = {
///   args_in = 2,
///   args_out = 1,
///   indexing_maps = #accesses,
///   iterator_types = ["parallel", "parallel"],
///   library_call = "some_external_fn"
/// }
///
/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
/// tensor<5x5xf32>
/// {
///   %0 = linalg.generic #trait %arg0, %arg1 {
///        ^bb0(%arg2: f32, %arg3: f32):
///          %3 = addf %arg2, %arg3 : f32
///          linalg.yield %3 : f32
///        } : tensor<5xf32>, tensor<5xf32> -> tensor<5x5xf32>
///   return %0 : tensor<5x5xf32>
/// }
/// ```

/// Given dims of the iteration space of a structured op that are known to have
/// a single trip count (`unitDims`), return the indexing maps to use in the
/// canonicalized op with these dims removed, given the original `indexingMaps`.
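///
/// For example (illustrative only): with `unitDims` = {1} on a 3-d iteration
/// space, `affine_map<(d0, d1, d2) -> (d0, d1, d2)>` becomes
/// `affine_map<(d0, d1) -> (d0, 0, d1)>` and
/// `affine_map<(d0, d1, d2) -> (d0, d2)>` becomes
/// `affine_map<(d0, d1) -> (d0, d1)>`.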
static ArrayAttr replaceUnitDims(DenseSet<unsigned> &unitDims,
                                 ArrayRef<AffineMap> indexingMaps,
                                 MLIRContext *context) {
  if (indexingMaps.empty())
    return nullptr;
  unsigned numIterationDims = indexingMaps.front().getNumDims();
  unsigned numSymbols = indexingMaps.front().getNumSymbols();

  // Compute the replacement for each dim expr.
  SmallVector<AffineExpr, 4> dimReplacements;
  dimReplacements.reserve(numIterationDims);
  unsigned numKeptDims = 0;
  for (unsigned dim : llvm::seq<unsigned>(0, numIterationDims)) {
    if (unitDims.count(dim))
      dimReplacements.push_back(getAffineConstantExpr(0, context));
    else
      dimReplacements.push_back(getAffineDimExpr(numKeptDims++, context));
  }

  // Symbols remain the same.
  SmallVector<AffineExpr, 4> symReplacements;
  symReplacements.reserve(numSymbols);
  for (unsigned symbol : llvm::seq<unsigned>(0, numSymbols))
    symReplacements.push_back(getAffineSymbolExpr(symbol, context));

  SmallVector<AffineMap, 4> newIndexingMaps;
  newIndexingMaps.reserve(indexingMaps.size());
  for (AffineMap operandMap : indexingMaps) {
    // Expected indexing maps to have no symbols.
    if (operandMap.getNumSymbols())
      return nullptr;
    newIndexingMaps.push_back(simplifyAffineMap(
        operandMap.replaceDimsAndSymbols(dimReplacements, symReplacements,
                                         numIterationDims - unitDims.size(),
                                         numSymbols)));
  }

  // Check that the new index maps are invertible. If not, something went
  // wrong, so abort.
  if (!inversePermutation(concatAffineMaps(newIndexingMaps)))
    return nullptr;
  return ArrayAttr::get(context,
                        llvm::to_vector<4>(llvm::map_range(
                            newIndexingMaps, [](AffineMap map) -> Attribute {
                              return AffineMapAttr::get(map);
                            })));
}

/// Update the index accesses of linalg operations having index semantics.
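///
/// For example (illustrative), with `unitDims` = {1}: `linalg.index 1` is
/// replaced by a constant 0, and `linalg.index 2` is renumbered to
/// `linalg.index 1` to account for the dropped loop.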
static void replaceUnitDimIndexOps(GenericOp genericOp,
                                   const DenseSet<unsigned> &unitDims,
                                   PatternRewriter &rewriter) {
  assert(genericOp->getNumRegions() == 1 &&
         genericOp->getRegion(0).getBlocks().size() == 1 &&
         "expected generic operation to have one block.");
  Block &block = genericOp->getRegion(0).front();

  for (IndexOp indexOp : llvm::make_early_inc_range(block.getOps<IndexOp>())) {
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(indexOp);
    if (unitDims.count(indexOp.dim()) != 0) {
      rewriter.replaceOpWithNewOp<ConstantIndexOp>(indexOp, 0);
    } else {
      // Update the dimension of the index operation if needed.
      unsigned droppedDims = llvm::count_if(
          unitDims, [&](unsigned dim) { return dim < indexOp.dim(); });
      if (droppedDims != 0)
        rewriter.replaceOpWithNewOp<IndexOp>(indexOp,
                                             indexOp.dim() - droppedDims);
    }
  }
}

namespace {
/// Pattern to fold unit-trip count loops in GenericOps.
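///
/// For example (illustrative), a generic op whose operands have type
/// `tensor<1x5xf32>` and identity indexing maps with two "parallel" loops is
/// updated in place to use `affine_map<(d0) -> (0, d0)>` and a single
/// "parallel" loop; the operands themselves are not reshaped (that is handled
/// by ReplaceUnitExtents below).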
struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap, 4> indexingMaps = genericOp.getIndexingMaps();
    if (indexingMaps.empty())
      return failure();

    // Check if any of the iteration dimensions are unit-trip count. They will
    // end up being unit-trip count if they are used to index into a unit-dim
    // tensor/memref.
    AffineMap invertedMap = inversePermutation(concatAffineMaps(indexingMaps));
    if (!invertedMap)
      return failure();
    SmallVector<int64_t> dims = genericOp.getStaticShape();

    // Find all the reduction iterators. Those need some special consideration
    // (see below).
    auto getLoopDimsOfType =
        [&](StringRef iteratorTypeName) -> SmallVector<unsigned, 4> {
      SmallVector<AffineExpr> dimExprs;
      getDimsOfType(genericOp, iteratorTypeName, dimExprs);
      return llvm::to_vector<4>(llvm::map_range(dimExprs, [](AffineExpr expr) {
        return expr.cast<AffineDimExpr>().getPosition();
      }));
    };
    auto reductionDims = getLoopDimsOfType(getReductionIteratorTypeName());

    DenseSet<unsigned> unitDims;
    SmallVector<unsigned, 4> unitDimsReductionLoops;
    ArrayAttr iteratorTypes = genericOp.iterator_types();
    for (auto expr : enumerate(invertedMap.getResults())) {
      if (AffineDimExpr dimExpr = expr.value().dyn_cast<AffineDimExpr>())
        if (dims[dimExpr.getPosition()] == 1) {
          if (isParallelIterator(iteratorTypes[expr.index()]))
            unitDims.insert(expr.index());
          else if (isReductionIterator(iteratorTypes[expr.index()]))
            unitDimsReductionLoops.push_back(expr.index());
        }
    }

    // Reduction loops can be dropped if there is at least one other reduction
    // loop that is not dropped. This accounts for the initial value read in the
    // reduction loop.
    if (!unitDimsReductionLoops.empty() && reductionDims.size() > 1) {
      if (unitDimsReductionLoops.size() == reductionDims.size())
        unitDims.insert(reductionDims.begin(), std::prev(reductionDims.end()));
      else
        unitDims.insert(unitDimsReductionLoops.begin(),
                        unitDimsReductionLoops.end());
    }

    if (unitDims.empty())
      return failure();

    // Compute the modified indexing maps.
    MLIRContext *context = rewriter.getContext();
    ArrayAttr newIndexingMapAttr =
        replaceUnitDims(unitDims, indexingMaps, context);
    if (!newIndexingMapAttr)
      return genericOp.emitError("unable to compute modified indexing_maps");

    // Compute the iterator types of the modified op by dropping the one-trip
    // count loops.
    SmallVector<Attribute, 4> newIteratorTypes;
    for (auto attr : llvm::enumerate(iteratorTypes)) {
      if (!unitDims.count(attr.index()))
        newIteratorTypes.push_back(attr.value());
    }

    rewriter.startRootUpdate(genericOp);
    genericOp.indexing_mapsAttr(newIndexingMapAttr);
    genericOp.iterator_typesAttr(ArrayAttr::get(context, newIteratorTypes));
    replaceUnitDimIndexOps(genericOp, unitDims, rewriter);
    rewriter.finalizeRootUpdate(genericOp);
    return success();
  }
};

struct UnitExtentReplacementInfo {
  Type type;
  AffineMap indexMap;
  ArrayAttr reassociation;
};
} // namespace

/// Utility function for replacing operands/results of a linalg generic
/// operation that have unit-extent dimensions. These can be replaced with
/// an operand/result with the unit-extent dimension removed. This is only done
/// if the indexing map used to access that dimension has an AffineConstantExpr
/// of value 0. Given the `type` of a result/operand of a Linalg op, and its
/// `indexMap`, the utility function returns:
/// - the new type with dimensions of size 1 removed.
/// - the modified index map that can be used to access the replaced
///   result/operand.
/// - the reassociation that converts from the original tensor type to the
///   modified tensor type.
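///
/// For example (illustrative), an operand of type `tensor<1x5x1xf32>` accessed
/// through `affine_map<(d0, d1) -> (0, d1, 0)>` yields the replacement type
/// `tensor<5xf32>`, the index map `affine_map<(d0, d1) -> (d1)>`, and a single
/// reassociation group collapsing all three original dimensions.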
static llvm::Optional<UnitExtentReplacementInfo>
replaceUnitExtents(GenericOp genericOp, OpOperand *opOperand,
                   MLIRContext *context) {
  AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
  ArrayRef<int64_t> shape = genericOp.getShape(opOperand);
  ArrayRef<AffineExpr> exprs = indexingMap.getResults();
  SmallVector<AffineExpr> reassociations;
  SmallVector<Attribute> reassociationMaps;
  SmallVector<AffineExpr> newIndexExprs;
  SmallVector<int64_t> newShape;

  int64_t origRank = genericOp.getRank(opOperand);
  AffineExpr zeroExpr = getAffineConstantExpr(0, context);
  auto isUnitExtent = [&](int64_t dim) -> bool {
    return shape[dim] == 1 && exprs[dim] == zeroExpr;
  };

  // Early return for memrefs with affine layout maps; those are always left
  // unchanged.
  Type actualType = opOperand->get().getType();
  if (auto memref = actualType.dyn_cast<MemRefType>()) {
    if (!memref.getAffineMaps().empty())
      return llvm::None;
  }

  int64_t dim = 0;
  // Fold dimensions that are unit-extent at the beginning of the tensor.
  while (dim < origRank && isUnitExtent(dim))
    reassociations.push_back(getAffineDimExpr(dim++, context));
  while (dim < origRank) {
    reassociations.push_back(getAffineDimExpr(dim, context));
    newIndexExprs.push_back(exprs[dim]);
    newShape.push_back(shape[dim]);
    // Fold all following dimensions that are unit-extent.
    while (dim + 1 < origRank && isUnitExtent(dim + 1)) {
      ++dim;
      reassociations.push_back(getAffineDimExpr(dim, context));
    }
    reassociationMaps.push_back(AffineMapAttr::get(AffineMap::get(
        origRank, /*symbolCount = */ 0, reassociations, context)));
    reassociations.clear();
    ++dim;
  }

  // Compute the tensor or scalar replacement type.
  Type elementType = getElementTypeOrSelf(opOperand->get());
  Type replacementType;
  if (elementType == opOperand->get().getType()) {
    replacementType = elementType;
  } else if (actualType.isa<RankedTensorType>()) {
    replacementType = RankedTensorType::get(newShape, elementType);
  } else if (actualType.isa<MemRefType>()) {
    replacementType = MemRefType::get(newShape, elementType);
  }
  assert(replacementType && "unsupported shaped type");
  UnitExtentReplacementInfo info = {replacementType,
                                    AffineMap::get(indexingMap.getNumDims(),
                                                   indexingMap.getNumSymbols(),
                                                   newIndexExprs, context),
                                    ArrayAttr::get(context, reassociationMaps)};
  return info;
}

namespace {

SmallVector<ReassociationExprs, 2>
convertAffineMapArrayToExprs(ArrayAttr affineMapArrayAttr) {
  SmallVector<ReassociationExprs, 2> reassociationExprs;
  for (auto attr : affineMapArrayAttr)
    reassociationExprs.push_back(
        llvm::to_vector<4>(attr.cast<AffineMapAttr>().getValue().getResults()));
  return reassociationExprs;
}

/// Pattern to replace tensor/buffer operands/results that have unit-extent
/// dimensions.
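///
/// This typically runs after FoldUnitDimLoops has rewritten the indexing maps.
/// For example (illustrative), a generic op operating on `memref<1x5xf32>`
/// through `affine_map<(d0) -> (0, d0)>` is rewritten to operate on
/// `memref<5xf32>`, with `memref.collapse_shape` ops inserted on the operands
/// (and, for tensors, `linalg.tensor_expand_shape` ops recovering the original
/// result shapes).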
struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  // Return the original value if the type is unchanged, or reshape it. Return
  // nullptr if the type is unsupported.
  Value maybeExpand(Value result, Type origResultType,
                    ArrayAttr reassociationMap, Location loc,
                    PatternRewriter &rewriter) const {
    if (origResultType == result.getType())
      return result;
    if (origResultType.isa<RankedTensorType>()) {
      return rewriter.create<linalg::TensorExpandShapeOp>(
          loc, origResultType, result,
          convertAffineMapArrayToExprs(reassociationMap));
    }
    if (origResultType.isa<MemRefType>()) {
      return rewriter.create<memref::ExpandShapeOp>(
          loc, origResultType, result,
          convertAffineMapArrayToExprs(reassociationMap));
    }
    return nullptr;
  }

  // Return the original value if the type is unchanged, or reshape it. Return
  // nullptr if the type is unsupported.
  Value maybeCollapse(Value operand, Type newInputOutputType,
                      ArrayAttr reassociationMap, Location loc,
                      PatternRewriter &rewriter) const {
    auto operandType = operand.getType();
    if (operandType == newInputOutputType)
      return operand;
    if (operandType.isa<MemRefType>()) {
      return rewriter.create<memref::CollapseShapeOp>(
          loc, newInputOutputType, operand,
          convertAffineMapArrayToExprs(reassociationMap));
    }
    if (operandType.isa<RankedTensorType>()) {
      return rewriter.create<linalg::TensorCollapseShapeOp>(
          loc, newInputOutputType, operand,
          convertAffineMapArrayToExprs(reassociationMap));
    }
    return nullptr;
  }

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    MLIRContext *context = rewriter.getContext();
    Location loc = genericOp.getLoc();

    SmallVector<AffineMap> newIndexingMaps;
    SmallVector<ArrayAttr> reassociationMaps;
    SmallVector<Type> newInputOutputTypes;
    bool doCanonicalization = false;
    for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
      auto replacementInfo = replaceUnitExtents(genericOp, opOperand, context);
      if (replacementInfo) {
        reassociationMaps.push_back(replacementInfo->reassociation);
        newIndexingMaps.push_back(replacementInfo->indexMap);
        newInputOutputTypes.push_back(replacementInfo->type);
        doCanonicalization |=
            replacementInfo->type != opOperand->get().getType();
      } else {
        // If replaceUnitExtents cannot handle this case, maintain the same
        // type and indexing map, and create an identity reassociation (one
        // map per dimension).
        newInputOutputTypes.push_back(opOperand->get().getType());
        newIndexingMaps.push_back(genericOp.getTiedIndexingMap(opOperand));
        int64_t origRank = genericOp.getRank(opOperand);
        auto maps = llvm::to_vector<8>(llvm::map_range(
            llvm::seq<int64_t>(0, origRank), [&](int64_t dim) -> Attribute {
              return AffineMapAttr::get(
                  AffineMap::get(origRank, /*symbolCount = */ 0,
                                 getAffineDimExpr(dim, context), context));
            }));
        reassociationMaps.push_back(ArrayAttr::get(context, maps));
      }
    }

    // If the indexing maps of the result operation are not invertible (i.e. not
    // legal), abort.
    if (!doCanonicalization ||
        !inversePermutation(concatAffineMaps(newIndexingMaps)))
      return failure();

    // If any operand type changes, insert a reshape to convert from the
    // original type to the new type.
    // TODO: get rid of flattenedIdx which assumes operand order and contiguity.
    unsigned flattenedIdx = 0;
    auto insertReshapes = [&](ValueRange values) {
      SmallVector<Value, 4> res;
      res.reserve(values.size());
      for (auto operand : values) {
        auto reshapedValue =
            maybeCollapse(operand, newInputOutputTypes[flattenedIdx],
                          reassociationMaps[flattenedIdx], loc, rewriter);
        assert(reshapedValue &&
               "expected ranked MemRef or Tensor operand type");
        res.push_back(reshapedValue);
        ++flattenedIdx;
      }
      return res;
    };

    SmallVector<Value, 4> newInputs = insertReshapes(genericOp.inputs());
    SmallVector<Value, 4> newOutputs = insertReshapes(genericOp.outputs());

    // If any result type changes, insert a reshape to convert from the original
    // type to the new type.
    SmallVector<Type, 4> resultTypes;
    resultTypes.reserve(genericOp.getNumResults());
    for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
      resultTypes.push_back(newInputOutputTypes[i + genericOp.getNumInputs()]);
    GenericOp replacementOp = rewriter.create<GenericOp>(
        loc, resultTypes, newInputs, newOutputs, newIndexingMaps,
        llvm::to_vector<4>(
            genericOp.iterator_types().template getAsValueRange<StringAttr>()));
    rewriter.inlineRegionBefore(genericOp.region(), replacementOp.region(),
                                replacementOp.region().begin());

    // If any result tensor has a modified shape, then add a reshape to recover
    // the original shape.
    SmallVector<Value, 4> resultReplacements;
    for (auto result : llvm::enumerate(replacementOp.getResults())) {
      unsigned index = result.index() + replacementOp.getNumInputs();
      auto origResultType = genericOp.getResult(result.index()).getType();

      auto newResult = maybeExpand(result.value(), origResultType,
                                   reassociationMaps[index], loc, rewriter);
      assert(newResult &&
             "unexpected output type other than ranked MemRef or Tensor");
      resultReplacements.push_back(newResult);
    }
    rewriter.replaceOp(genericOp, resultReplacements);
    return success();
  }
};
} // namespace

/// Get the reassociation maps to fold the result of an extract_slice (or the
/// source of an insert_slice) operation with the given offsets and sizes into
/// its rank-reduced version. This is only done for the cases where the size is
/// 1 and the offset is 0. Strictly speaking, offset 0 is not required in
/// general, but non-zero offsets are not handled by the SPIR-V backend at this
/// point (and potentially cannot be handled).
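///
/// For example (illustrative), sizes `[1, 4, 1, 8]` produce the reassociation
/// `[[0, 1], [2, 3]]`, while sizes `[4, 1, 8, 1]` produce `[[0], [1, 2, 3]]`
/// since trailing unit dimensions are folded into the last group.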
static Optional<SmallVector<ReassociationIndices>>
getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes) {
  SmallVector<ReassociationIndices> reassociation;
  ReassociationIndices curr;
  for (auto it : llvm::enumerate(mixedSizes)) {
    auto dim = it.index();
    auto size = it.value();
    curr.push_back(dim);
    auto attr = size.dyn_cast<Attribute>();
    if (attr && attr.cast<IntegerAttr>().getInt() == 1)
      continue;
    reassociation.emplace_back(ReassociationIndices{});
    std::swap(reassociation.back(), curr);
  }
  // When the reassociation is not empty, fold the remaining unit dimensions
  // into the last group. If the reassociation so far is empty, leave it empty;
  // this folds everything to a rank-0 tensor.
  if (!curr.empty() && !reassociation.empty())
    reassociation.back().append(curr.begin(), curr.end());
  return reassociation;
}

namespace {
/// Convert `extract_slice` operations to rank-reduced versions.
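///
/// For example (illustrative; exact syntax may vary):
///
/// ```mlir
/// %0 = tensor.extract_slice %arg0[0, 0, 0] [1, 4, 1] [1, 1, 1]
///      : tensor<8x4x2xf32> to tensor<1x4x1xf32>
/// ```
///
/// is rewritten to extract a rank-reduced `tensor<4xf32>` slice, followed by a
/// `linalg.tensor_expand_shape` back to `tensor<1x4x1xf32>`.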
struct UseRankReducedExtractSliceOp
    : public OpRewritePattern<tensor::ExtractSliceOp> {
  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType resultType = sliceOp.getType();
    SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
    SmallVector<OpFoldResult> strides = sliceOp.getMixedStrides();
    auto reassociation = getReassociationMapForFoldingUnitDims(sizes);
    if (!reassociation ||
        reassociation->size() == static_cast<size_t>(resultType.getRank()))
      return failure();
    auto rankReducedType = tensor::ExtractSliceOp::inferRankReducedResultType(
                               reassociation->size(), sliceOp.getSourceType(),
                               offsets, sizes, strides)
                               .cast<RankedTensorType>();

    Location loc = sliceOp.getLoc();
    Value newSlice = rewriter.create<tensor::ExtractSliceOp>(
        loc, rankReducedType, sliceOp.source(), offsets, sizes, strides);
    rewriter.replaceOpWithNewOp<TensorExpandShapeOp>(sliceOp, resultType,
                                                     newSlice, *reassociation);
    return success();
  }
};

/// Convert `insert_slice` operations to rank-reduced versions.
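///
/// For example (illustrative), inserting a `tensor<1x4x1xf32>` source is
/// rewritten to first collapse the source to `tensor<4xf32>` with
/// `linalg.tensor_collapse_shape` and then insert the rank-reduced slice.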
struct UseRankReducedInsertSliceOp
    : public OpRewritePattern<tensor::InsertSliceOp> {
  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType sourceType = insertOp.getSourceType();
    SmallVector<OpFoldResult> offsets = insertOp.getMixedOffsets();
    SmallVector<OpFoldResult> sizes = insertOp.getMixedSizes();
    SmallVector<OpFoldResult> strides = insertOp.getMixedStrides();
    auto reassociation = getReassociationMapForFoldingUnitDims(sizes);
    if (!reassociation ||
        reassociation->size() == static_cast<size_t>(sourceType.getRank()))
      return failure();
    Location loc = insertOp.getLoc();
    auto reshapedSource = rewriter.create<TensorCollapseShapeOp>(
        loc, insertOp.source(), *reassociation);
    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
        insertOp, reshapedSource, insertOp.dest(), insertOp.getMixedOffsets(),
        insertOp.getMixedSizes(), insertOp.getMixedStrides());
    return success();
  }
};
} // namespace

/// Patterns that are used to canonicalize the use of unit-extent dims for
/// broadcasting.
void mlir::linalg::populateFoldUnitExtentDimsPatterns(
    RewritePatternSet &patterns) {
  auto *context = patterns.getContext();
  patterns.add<FoldUnitDimLoops, ReplaceUnitExtents,
               UseRankReducedExtractSliceOp, UseRankReducedInsertSliceOp>(
      context);
  TensorCollapseShapeOp::getCanonicalizationPatterns(patterns, context);
  TensorExpandShapeOp::getCanonicalizationPatterns(patterns, context);
}

namespace {
/// Pass that removes unit-extent dims within generic ops.
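///
/// Normally registered as the `-linalg-fold-unit-extent-dims` pass, with a
/// `fold-one-trip-loops-only` option that restricts it to FoldUnitDimLoops;
/// see the Linalg Passes.td definition for the authoritative names.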
struct LinalgFoldUnitExtentDimsPass
    : public LinalgFoldUnitExtentDimsBase<LinalgFoldUnitExtentDimsPass> {
  void runOnFunction() override {
    FuncOp funcOp = getFunction();
    MLIRContext *context = funcOp.getContext();
    RewritePatternSet patterns(context);
    if (foldOneTripLoopsOnly)
      patterns.add<FoldUnitDimLoops>(context);
    else
      populateFoldUnitExtentDimsPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
  }
};
} // namespace

std::unique_ptr<OperationPass<FuncOp>>
mlir::createLinalgFoldUnitExtentDimsPass() {
  return std::make_unique<LinalgFoldUnitExtentDimsPass>();
}