1 //===- DropUnitDims.cpp - Pass to drop use of unit-extent for broadcasting ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns/pass to remove usage of unit-extent dimensions
10 // to specify broadcasting in favor of more canonical representation of the
11 // computation
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "PassDetail.h"
16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
17 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
18 #include "mlir/Dialect/Linalg/Passes.h"
19 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
20 #include "mlir/Dialect/Linalg/Utils/Utils.h"
21 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
22 #include "mlir/IR/AffineExpr.h"
23 #include "mlir/IR/AffineMap.h"
24 #include "mlir/Transforms/FoldUtils.h"
25 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
26 #include "llvm/Support/CommandLine.h"
27 #include "llvm/Support/Debug.h"
28 
29 #define DEBUG_TYPE "linalg-drop-unit-dims"
30 
31 using namespace mlir;
32 using namespace mlir::edsc;
33 using namespace mlir::edsc::intrinsics;
34 using namespace mlir::linalg;
35 
36 /// Implements a pass that canonicalizes the uses of unit-extent dimensions for
37 /// broadcasting. For example,
38 ///
39 /// ```mlir
40 /// #accesses = [
41 ///   affine_map<(d0, d1) -> (0, d1)>,
42 ///   affine_map<(d0, d1) -> (d0, 0)>,
43 ///   affine_map<(d0, d1) -> (d0, d1)>
44 /// ]
45 ///
46 /// #trait = {
47 ///   args_in = 2,
48 ///   args_out = 1,
49 ///   indexing_maps = #accesses,
50 ///   iterator_types = ["parallel", "parallel"],
51 ///   library_call = "some_external_fn"
52 /// }
53 ///
54 /// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
55 /// tensor<5x5xf32>
56 /// {
57 ///   %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>] :
58 ///        tensor<5xf32> into tensor<1x5xf32>
59 ///   %1 = linalg.tensor_reshape %arg1 [affine_map<(d0, d1) -> (d0, d1)>] :
60 ///        tensor<5xf32> into tensor<5x1xf32>
61 ///   %2 = linalg.generic #trait %0, %1 {
62 ///        ^bb0(%arg2: f32, %arg3: f32):
63 ///          %3 = addf %arg2, %arg3 : f32
64 ///          linalg.yield %3 : f32
65 ///        } : tensor<1x5xf32>, tensor<5x1xf32> -> tensor<5x5xf32>
66 ///   return %2 : tensor<5x5xf32>
67 /// }
68 ///
69 /// would canonicalize to
70 ///
71 /// ```mlir
72 /// #accesses = [
73 ///   affine_map<(d0, d1) -> (d1)>,
74 ///   affine_map<(d0, d1) -> (d0)>,
75 ///   affine_map<(d0, d1) -> (d0, d1)>
76 /// ]
77 ///
78 /// #trait = {
79 ///   args_in = 2,
80 ///   args_out = 1,
81 ///   indexing_maps = #accesses,
82 ///   iterator_types = ["parallel", "parallel"],
83 ///   library_call = "some_external_fn"
84 /// }
85 ///
86 /// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
87 /// tensor<5x5xf32>
88 /// {
89 ///   %0 = linalg.generic #trait %arg0, %arg1 {
90 ///        ^bb0(%arg2: f32, %arg3: f32):
91 ///          %3 = addf %arg2, %arg3 : f32
92 ///          linalg.yield %3 : f32
93 ///        } : tensor<5xf32>, tensor<5xf32> -> tensor<5x5xf32>
94 ///   return %0 : tensor<5x5xf32>
95 /// }
96 
97 /// Given dims of the iteration space of a structured op that are known to be
98 /// single trip count (`unitDims`), return the indexing maps to use in the
99 /// canonicalized op with these dims removed, given the original `indexingMaps`.
100 static ArrayAttr replaceUnitDims(DenseSet<unsigned> &unitDims,
101                                  ArrayRef<AffineMap> indexingMaps,
102                                  MLIRContext *context) {
103   if (indexingMaps.empty())
104     return nullptr;
105   unsigned numIterationDims = indexingMaps.front().getNumDims();
106   unsigned numSymbols = indexingMaps.front().getNumSymbols();
107 
108   // Compute the replacement for each dim expr.
109   SmallVector<AffineExpr, 4> dimReplacements;
110   dimReplacements.reserve(numIterationDims);
111   unsigned numKeptDims = 0;
112   for (unsigned dim : llvm::seq<unsigned>(0, numIterationDims)) {
113     if (unitDims.count(dim))
114       dimReplacements.push_back(getAffineConstantExpr(0, context));
115     else
116       dimReplacements.push_back(getAffineDimExpr(numKeptDims++, context));
117   }
118 
119   // Symbols remain the same.
120   SmallVector<AffineExpr, 4> symReplacements;
121   symReplacements.reserve(numSymbols);
122   for (unsigned symbol : llvm::seq<unsigned>(0, numSymbols))
123     symReplacements.push_back(getAffineSymbolExpr(symbol, context));
124 
125   SmallVector<AffineMap, 4> newIndexingMaps;
126   newIndexingMaps.reserve(indexingMaps.size());
127   for (AffineMap operandMap : indexingMaps) {
128     // Expected indexing maps to have no symbols.
129     if (operandMap.getNumSymbols())
130       return nullptr;
131     newIndexingMaps.push_back(simplifyAffineMap(
132         operandMap.replaceDimsAndSymbols(dimReplacements, symReplacements,
133                                          numIterationDims - unitDims.size(),
134                                          numSymbols)));
135   }
136 
137   // Check that the new index maps are invertible. If not, something went
138   // wrong, so abort.
139   if (!inversePermutation(concatAffineMaps(newIndexingMaps)))
140     return nullptr;
141   return ArrayAttr::get(context,
142                         llvm::to_vector<4>(llvm::map_range(
143                             newIndexingMaps, [](AffineMap map) -> Attribute {
144                               return AffineMapAttr::get(map);
145                             })));
146 }
147 
148 /// Update the index accesses of linalg operations having index semantics.
149 static void replaceUnitDimIndexOps(GenericOp genericOp,
150                                    const DenseSet<unsigned> &unitDims,
151                                    PatternRewriter &rewriter) {
152   assert(genericOp->getNumRegions() == 1 &&
153          genericOp->getRegion(0).getBlocks().size() == 1 &&
154          "expected generic operation to have one block.");
155   Block &block = genericOp->getRegion(0).front();
156 
157   for (IndexOp indexOp : llvm::make_early_inc_range(block.getOps<IndexOp>())) {
158     OpBuilder::InsertionGuard guard(rewriter);
159     rewriter.setInsertionPoint(indexOp);
160     if (unitDims.count(indexOp.dim()) != 0) {
161       rewriter.replaceOpWithNewOp<ConstantIndexOp>(indexOp, 0);
162     } else {
163       // Update the dimension of the index operation if needed.
164       unsigned droppedDims = llvm::count_if(
165           unitDims, [&](unsigned dim) { return dim < indexOp.dim(); });
166       if (droppedDims != 0)
167         rewriter.replaceOpWithNewOp<IndexOp>(indexOp,
168                                              indexOp.dim() - droppedDims);
169     }
170   }
171 }
172 
173 namespace {
174 /// Pattern to fold unit-trip count loops in GenericOps.
struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;
  /// Rewrites `genericOp` in place when some of its iteration dimensions are
  /// provably single-trip: those dims are dropped from the indexing maps and
  /// iterator types, and body `linalg.index` ops are fixed up accordingly.
  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap, 4> indexingMaps = genericOp.getIndexingMaps();
    if (indexingMaps.empty())
      return failure();

    // Check if any of the iteration dimensions are unit-trip count. They will
    // end up being unit-trip count if they are used to index into a unit-dim
    // tensor/memref.
    AffineMap invertedMap = inversePermutation(concatAffineMaps(indexingMaps));
    if (!invertedMap)
      return failure();
    // Flattened list of all operand dimension sizes, in operand order; the
    // inverted map's results index into this list.
    SmallVector<int64_t, 4> dims;
    for (ShapedType shapedType : genericOp.getShapedOperandTypes())
      dims.append(shapedType.getShape().begin(), shapedType.getShape().end());

    // Find all the reduction iterators. Those need some special consideration
    // (see below).
    auto getLoopDimsOfType =
        [&](StringRef iteratorTypeName) -> SmallVector<unsigned, 4> {
      SmallVector<AffineExpr> dimExprs;
      getDimsOfType(genericOp, iteratorTypeName, dimExprs);
      return llvm::to_vector<4>(llvm::map_range(dimExprs, [](AffineExpr expr) {
        return expr.cast<AffineDimExpr>().getPosition();
      }));
    };
    auto reductionDims = getLoopDimsOfType(getReductionIteratorTypeName());

    // Classify each unit-trip loop by its iterator type: parallel loops can
    // always be dropped; reduction loops need the check below.
    DenseSet<unsigned> unitDims;
    SmallVector<unsigned, 4> unitDimsReductionLoops;
    ArrayAttr iteratorTypes = genericOp.iterator_types();
    for (auto expr : enumerate(invertedMap.getResults())) {
      if (AffineDimExpr dimExpr = expr.value().dyn_cast<AffineDimExpr>())
        if (dims[dimExpr.getPosition()] == 1) {
          if (isParallelIterator(iteratorTypes[expr.index()]))
            unitDims.insert(expr.index());
          else if (isReductionIterator(iteratorTypes[expr.index()]))
            unitDimsReductionLoops.push_back(expr.index());
        }
    }

    // Reduction loops can be dropped if there is at least one other reduction
    // loop that is not dropped. This accounts for the initial value read in the
    // reduction loop.
    if (!unitDimsReductionLoops.empty() && reductionDims.size() > 1) {
      if (unitDimsReductionLoops.size() == reductionDims.size())
        unitDims.insert(reductionDims.begin(), std::prev(reductionDims.end()));
      else
        unitDims.insert(unitDimsReductionLoops.begin(),
                        unitDimsReductionLoops.end());
    }

    if (unitDims.empty())
      return failure();

    // Compute the modified indexing maps.
    MLIRContext *context = rewriter.getContext();
    ArrayAttr newIndexingMapAttr =
        replaceUnitDims(unitDims, indexingMaps, context);
    if (!newIndexingMapAttr)
      return genericOp.emitError("unable to compute modified indexing_maps");

    // Compute the iterator types of the modified op by dropping the one-trip
    // count loops.
    SmallVector<Attribute, 4> newIteratorTypes;
    for (auto attr : llvm::enumerate(iteratorTypes)) {
      if (!unitDims.count(attr.index()))
        newIteratorTypes.push_back(attr.value());
    }

    // In-place update: swap in the new maps/iterator types, then renumber any
    // linalg.index ops in the body to match the reduced loop nest.
    rewriter.startRootUpdate(genericOp);
    genericOp.indexing_mapsAttr(newIndexingMapAttr);
    genericOp.iterator_typesAttr(ArrayAttr::get(context, newIteratorTypes));
    replaceUnitDimIndexOps(genericOp, unitDims, rewriter);
    rewriter.finalizeRootUpdate(genericOp);
    return success();
  }
};
255 
/// Result triple describing how one operand/result of a generic op is
/// replaced when its unit-extent dimensions are folded away.
struct UnitExtentReplacementInfo {
  // Replacement tensor type with the foldable unit dimensions removed.
  RankedTensorType type;
  // Indexing map rewritten to address the reduced-rank type.
  AffineMap indexMap;
  // Array of AffineMapAttr giving the original->reduced reassociation.
  ArrayAttr reassociation;
};
261 } // namespace
262 
263 /// Utility function for replacing operands/results to a linalg generic
264 /// operation on tensors with unit-extent dimensions. These can be replaced with
265 /// an operand/result with the unit-extent dimension removed. This is only done
/// if the indexing map used to access that dimension has a
267 /// AffineConstantExpr of value 0. Given the `type` of an result/operand of a
268 /// Linalg op, and its `indexMap` the utility function returns:
269 /// - the new type with dimensions of size 1 removed.
270 /// - modified index map that can be used to access the replaced result/operand
271 /// - the reassociation that converts from the original tensor type to the
272 ///   modified tensor type.
static UnitExtentReplacementInfo replaceUnitExtents(AffineMap indexMap,
                                                    RankedTensorType type,
                                                    MLIRContext *context) {
  ArrayRef<int64_t> shape = type.getShape();
  ArrayRef<AffineExpr> exprs = indexMap.getResults();
  // Dim exprs accumulating in the current reassociation group.
  SmallVector<AffineExpr, 2> reassociations;
  // One AffineMapAttr per reassociation group.
  SmallVector<Attribute, 4> reassociationMaps;
  SmallVector<AffineExpr, 4> newIndexExprs;
  SmallVector<int64_t, 4> newShape;

  int64_t origRank = type.getRank();
  AffineExpr zeroExpr = getAffineConstantExpr(0, context);
  // A dimension is foldable only if it is statically 1 AND accessed via the
  // constant-0 expression in the indexing map.
  auto isUnitExtent = [&](int64_t dim) -> bool {
    return shape[dim] == 1 && exprs[dim] == zeroExpr;
  };

  unsigned dim = 0;
  // Fold dimensions that are unit-extent at the beginning of the tensor.
  // These are grouped together with the first kept dimension.
  while (dim < origRank && isUnitExtent(dim))
    reassociations.push_back(getAffineDimExpr(dim++, context));
  // Each outer iteration keeps one non-unit dimension and folds all
  // immediately-following unit dimensions into its group.
  // NOTE(review): if every dimension is unit-extent this loop never runs, so
  // the result is a rank-0 tensor with an empty reassociation.
  while (dim < origRank) {
    reassociations.push_back(getAffineDimExpr(dim, context));
    newIndexExprs.push_back(exprs[dim]);
    newShape.push_back(shape[dim]);
    // Fold all following dimensions that are unit-extent.
    while (dim + 1 < origRank && isUnitExtent(dim + 1)) {
      ++dim;
      reassociations.push_back(getAffineDimExpr(dim, context));
    }
    // Close the current reassociation group.
    reassociationMaps.push_back(AffineMapAttr::get(AffineMap::get(
        origRank, /*symbolCount = */ 0, reassociations, context)));
    reassociations.clear();
    ++dim;
  }
  // Assemble the (reduced type, rewritten index map, reassociation) triple.
  UnitExtentReplacementInfo info = {
      RankedTensorType::get(newShape, type.getElementType()),
      AffineMap::get(indexMap.getNumDims(), indexMap.getNumSymbols(),
                     newIndexExprs, context),
      ArrayAttr::get(context, reassociationMaps)};
  return info;
}
314 
315 namespace {
316 
317 SmallVector<ReassociationExprs, 2>
318 convertAffineMapArrayToExprs(ArrayAttr affineMapArrayAttr) {
319   SmallVector<ReassociationExprs, 2> reassociationExprs;
320   for (auto attr : affineMapArrayAttr)
321     reassociationExprs.push_back(
322         llvm::to_vector<4>(attr.cast<AffineMapAttr>().getValue().getResults()));
323   return reassociationExprs;
324 }
325 
326 /// Pattern to replace tensors operands/results that are unit extents.
struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;
  /// Replaces `genericOp` with an equivalent op whose operand/result tensors
  /// have their foldable unit-extent dimensions removed, inserting
  /// `linalg.tensor_reshape` ops to convert to/from the original types.
  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Only ops on tensors are handled here.
    if (!genericOp.hasTensorSemantics())
      return failure();

    MLIRContext *context = rewriter.getContext();
    Location loc = genericOp.getLoc();

    // Compute, per operand/result, the reduced type, rewritten indexing map
    // and reassociation. `doCanonicalization` records whether any type
    // actually changed.
    SmallVector<AffineMap, 4> newIndexingMaps;
    SmallVector<ArrayAttr, 4> reassociationMaps;
    SmallVector<ShapedType, 4> newInputOutputTypes;
    bool doCanonicalization = false;
    for (auto it : llvm::zip(genericOp.getIndexingMaps(),
                             genericOp.getShapedOperandTypes())) {
      auto replacementInfo = replaceUnitExtents(
          std::get<0>(it), std::get<1>(it).template cast<RankedTensorType>(),
          context);
      reassociationMaps.push_back(replacementInfo.reassociation);
      newIndexingMaps.push_back(replacementInfo.indexMap);
      newInputOutputTypes.push_back(replacementInfo.type);
      doCanonicalization |= replacementInfo.type != std::get<1>(it);
    }

    // If the indexing maps of the result operation are not invertible (i.e. not
    // legal), abort.
    if (!doCanonicalization ||
        !inversePermutation(concatAffineMaps(newIndexingMaps)))
      return failure();

    // If any operand type change, insert a reshape to convert from the original
    // type to the new type.
    // TODO: get rid of flattenedIdx which assumes operand order and contiguity.
    unsigned flattenedIdx = 0;
    auto insertReshapes = [&](ValueRange values) {
      SmallVector<Value, 4> res;
      res.reserve(values.size());
      for (auto operand : llvm::enumerate(values)) {
        // Reshape only when the operand's type differs from the reduced one.
        if (operand.value().getType() == newInputOutputTypes[flattenedIdx])
          res.push_back(operand.value());
        else
          res.push_back(rewriter.create<linalg::TensorReshapeOp>(
              loc, newInputOutputTypes[flattenedIdx], operand.value(),
              convertAffineMapArrayToExprs(reassociationMaps[flattenedIdx])));
        // `flattenedIdx` advances across BOTH calls below, indexing the
        // combined input+output lists computed above.
        ++flattenedIdx;
      }
      return res;
    };

    SmallVector<Value, 4> newInputs = insertReshapes(genericOp.inputs());
    SmallVector<Value, 4> newOutputs = insertReshapes(genericOp.outputs());

    // If any result type changes, insert a reshape to convert from the original
    // type to the new type.
    SmallVector<Type, 4> resultTypes;
    resultTypes.reserve(genericOp.getNumResults());
    for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
      resultTypes.push_back(newInputOutputTypes[i + genericOp.getNumInputs()]);
    GenericOp replacementOp = rewriter.create<GenericOp>(
        loc, resultTypes, newInputs, newOutputs, newIndexingMaps,
        llvm::to_vector<4>(
            genericOp.iterator_types().template getAsValueRange<StringAttr>()));
    // The body is unchanged; move it over wholesale.
    rewriter.inlineRegionBefore(genericOp.region(), replacementOp.region(),
                                replacementOp.region().begin());

    // If any result tensor has a modified shape, then add reshape to recover
    // the original shape.
    SmallVector<Value, 4> resultReplacements;
    for (auto result : llvm::enumerate(replacementOp.getResults())) {
      unsigned index = result.index() + replacementOp.getNumInputs();
      RankedTensorType origResultType = genericOp.getResult(result.index())
                                            .getType()
                                            .template cast<RankedTensorType>();
      if (origResultType != result.value().getType())
        resultReplacements.push_back(rewriter.create<linalg::TensorReshapeOp>(
            loc, origResultType, result.value(),
            convertAffineMapArrayToExprs(reassociationMaps[index])));
      else
        resultReplacements.push_back(result.value());
    }
    rewriter.replaceOp(genericOp, resultReplacements);
    return success();
  }
};
412 } // namespace
413 
414 /// Get the reassociation maps to fold the result of a subtensor (or source of a
415 /// subtensor_insert) operation with given offsets, and sizes to its
416 /// rank-reduced version. This is only done for the cases where the size is 1
417 /// and offset is 0. Strictly speaking the offset 0 is not required in general,
418 /// but non-zero offsets are not handled by SPIR-V backend at this point (and
419 /// potentially cannot be handled).
420 static Optional<SmallVector<ReassociationIndices>>
421 getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes) {
422   SmallVector<ReassociationIndices> reassociation;
423   ReassociationIndices curr;
424   for (auto it : llvm::enumerate(mixedSizes)) {
425     auto dim = it.index();
426     auto size = it.value();
427     curr.push_back(dim);
428     auto attr = size.dyn_cast<Attribute>();
429     if (attr && attr.cast<IntegerAttr>().getInt() == 1)
430       continue;
431     reassociation.emplace_back(ReassociationIndices{});
432     std::swap(reassociation.back(), curr);
433   }
434   // When the reassociations are not empty, then fold the remaining
435   // unit-dimensions into the last dimension.  If the reassociations so far is
436   // empty, then leave it emtpy. This will fold everything to a rank-0 tensor.
437   if (!curr.empty() && !reassociation.empty())
438     reassociation.back().append(curr.begin(), curr.end());
439   return reassociation;
440 }
441 
442 namespace {
443 /// Convert `subtensor` operations to rank-reduced versions.
444 struct UseRankReducedSubTensorOp : public OpRewritePattern<SubTensorOp> {
445   using OpRewritePattern<SubTensorOp>::OpRewritePattern;
446 
447   LogicalResult matchAndRewrite(SubTensorOp subTensorOp,
448                                 PatternRewriter &rewriter) const override {
449     RankedTensorType resultType = subTensorOp.getType();
450     SmallVector<OpFoldResult> offsets = subTensorOp.getMixedOffsets();
451     SmallVector<OpFoldResult> sizes = subTensorOp.getMixedSizes();
452     SmallVector<OpFoldResult> strides = subTensorOp.getMixedStrides();
453     auto reassociation = getReassociationMapForFoldingUnitDims(sizes);
454     if (!reassociation ||
455         reassociation->size() == static_cast<size_t>(resultType.getRank()))
456       return failure();
457     auto rankReducedType =
458         SubTensorOp::inferRankReducedResultType(reassociation->size(),
459                                                 subTensorOp.getSourceType(),
460                                                 offsets, sizes, strides)
461             .cast<RankedTensorType>();
462 
463     Location loc = subTensorOp.getLoc();
464     Value newSubTensor = rewriter.create<SubTensorOp>(
465         loc, rankReducedType, subTensorOp.source(), offsets, sizes, strides);
466     rewriter.replaceOpWithNewOp<TensorReshapeOp>(subTensorOp, resultType,
467                                                  newSubTensor, *reassociation);
468     return success();
469   }
470 };
471 
472 /// Convert `subtensor_insert` operations to rank-reduced versions.
473 struct UseRankReducedSubTensorInsertOp
474     : public OpRewritePattern<SubTensorInsertOp> {
475   using OpRewritePattern<SubTensorInsertOp>::OpRewritePattern;
476 
477   LogicalResult matchAndRewrite(SubTensorInsertOp insertOp,
478                                 PatternRewriter &rewriter) const override {
479     RankedTensorType sourceType = insertOp.getSourceType();
480     SmallVector<OpFoldResult> offsets = insertOp.getMixedOffsets();
481     SmallVector<OpFoldResult> sizes = insertOp.getMixedSizes();
482     SmallVector<OpFoldResult> strides = insertOp.getMixedStrides();
483     auto reassociation = getReassociationMapForFoldingUnitDims(sizes);
484     if (!reassociation ||
485         reassociation->size() == static_cast<size_t>(sourceType.getRank()))
486       return failure();
487     Location loc = insertOp.getLoc();
488     auto reshapedSource = rewriter.create<TensorReshapeOp>(
489         loc, insertOp.source(), *reassociation);
490     rewriter.replaceOpWithNewOp<SubTensorInsertOp>(
491         insertOp, reshapedSource, insertOp.dest(), insertOp.getMixedOffsets(),
492         insertOp.getMixedSizes(), insertOp.getMixedStrides());
493     return success();
494   }
495 };
496 } // namespace
497 
498 /// Patterns that are used to canonicalize the use of unit-extent dims for
499 /// broadcasting.
500 void mlir::linalg::populateFoldUnitExtentDimsPatterns(
501     RewritePatternSet &patterns) {
502   auto *context = patterns.getContext();
503   patterns.add<FoldUnitDimLoops, ReplaceUnitExtentTensors,
504                UseRankReducedSubTensorOp, UseRankReducedSubTensorInsertOp>(
505       context);
506   TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
507 }
508 
509 namespace {
510 /// Pass that removes unit-extent dims within generic ops.
511 struct LinalgFoldUnitExtentDimsPass
512     : public LinalgFoldUnitExtentDimsBase<LinalgFoldUnitExtentDimsPass> {
513   void runOnFunction() override {
514     FuncOp funcOp = getFunction();
515     MLIRContext *context = funcOp.getContext();
516     RewritePatternSet patterns(context);
517     if (foldOneTripLoopsOnly)
518       patterns.add<FoldUnitDimLoops>(context);
519     else
520       populateFoldUnitExtentDimsPatterns(patterns);
521     (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
522   }
523 };
524 } // namespace
525 
/// Public factory for the pass that folds unit-extent dimensions in linalg
/// operations within a function.
std::unique_ptr<OperationPass<FuncOp>>
mlir::createLinalgFoldUnitExtentDimsPass() {
  return std::make_unique<LinalgFoldUnitExtentDimsPass>();
}
530