//===- DropUnitDims.cpp - Pass to drop use of unit-extent for broadcasting ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns/pass to remove usage of unit-extent dimensions
// to specify broadcasting in favor of a more canonical representation of the
// computation.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"

#include <set>

#define DEBUG_TYPE "linalg-drop-unit-dims"

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;

/// Implements a pass that canonicalizes the uses of unit-extent dimensions for
/// broadcasting. For example,
///
/// ```mlir
/// #accesses = [
///   affine_map<(d0, d1) -> (0, d1)>,
///   affine_map<(d0, d1) -> (d0, 0)>,
///   affine_map<(d0, d1) -> (d0, d1)>
/// ]
///
/// #trait = {
///   args_in = 2,
///   args_out = 1,
///   indexing_maps = #accesses,
///   iterator_types = ["parallel", "parallel"],
///   library_call = "some_external_fn"
/// }
///
/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
/// tensor<5x5xf32>
/// {
///   %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>] :
///        tensor<5xf32> into tensor<1x5xf32>
///   %1 = linalg.tensor_reshape %arg1 [affine_map<(d0, d1) -> (d0, d1)>] :
///        tensor<5xf32> into tensor<5x1xf32>
///   %2 = linalg.generic #trait %0, %1 {
///        ^bb0(%arg2: f32, %arg3: f32):
///          %3 = addf %arg2, %arg3 : f32
///          linalg.yield %3 : f32
///        } : tensor<1x5xf32>, tensor<5x1xf32> -> tensor<5x5xf32>
///   return %2 : tensor<5x5xf32>
/// }
/// ```
///
/// would canonicalize to
///
/// ```mlir
/// #accesses = [
///   affine_map<(d0, d1) -> (d1)>,
///   affine_map<(d0, d1) -> (d0)>,
///   affine_map<(d0, d1) -> (d0, d1)>
/// ]
///
/// #trait = {
///   args_in = 2,
///   args_out = 1,
///   indexing_maps = #accesses,
///   iterator_types = ["parallel", "parallel"],
///   library_call = "some_external_fn"
/// }
///
/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
/// tensor<5x5xf32>
/// {
///   %0 = linalg.generic #trait %arg0, %arg1 {
///        ^bb0(%arg2: f32, %arg3: f32):
///          %3 = addf %arg2, %arg3 : f32
///          linalg.yield %3 : f32
///        } : tensor<5xf32>, tensor<5xf32> -> tensor<5x5xf32>
///   return %0 : tensor<5x5xf32>
/// }
/// ```

/// Given dims of the iteration space of a structured op that are known to have
/// a single trip count (`unitDims`), return the indexing maps to use in the
/// canonicalized op with these dims removed, given the original `indexingMaps`.
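///
/// For example, with `unitDims = {1}` an indexing map
/// `affine_map<(d0, d1, d2) -> (d0, d1, d2)>` is rewritten to
/// `affine_map<(d0, d1) -> (d0, 0, d1)>`: the unit dim is replaced by the
/// constant 0 and the remaining dims are renumbered.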
static ArrayAttr replaceUnitDims(DenseSet<unsigned> &unitDims,
                                 ArrayRef<AffineMap> indexingMaps,
                                 MLIRContext *context) {
  if (indexingMaps.empty())
    return nullptr;
  unsigned numIterationDims = indexingMaps.front().getNumDims();
  unsigned numSymbols = indexingMaps.front().getNumSymbols();

  // Compute the replacement for each dim expr.
  SmallVector<AffineExpr, 4> dimReplacements;
  dimReplacements.reserve(numIterationDims);
  unsigned numKeptDims = 0;
  for (unsigned dim : llvm::seq<unsigned>(0, numIterationDims)) {
    if (unitDims.count(dim))
      dimReplacements.push_back(getAffineConstantExpr(0, context));
    else
      dimReplacements.push_back(getAffineDimExpr(numKeptDims++, context));
  }

  // Symbols remain the same.
  SmallVector<AffineExpr, 4> symReplacements;
  symReplacements.reserve(numSymbols);
  for (unsigned symbol : llvm::seq<unsigned>(0, numSymbols))
    symReplacements.push_back(getAffineSymbolExpr(symbol, context));

  SmallVector<AffineMap, 4> newIndexingMaps;
  newIndexingMaps.reserve(indexingMaps.size());
  for (AffineMap operandMap : indexingMaps) {
    // Expected indexing maps to have no symbols.
    if (operandMap.getNumSymbols())
      return nullptr;
    newIndexingMaps.push_back(simplifyAffineMap(
        operandMap.replaceDimsAndSymbols(dimReplacements, symReplacements,
                                         numIterationDims - unitDims.size(),
                                         numSymbols)));
  }

  // Check that the new index maps are invertible. If not, something went
  // wrong, so abort.
  if (!inversePermutation(concatAffineMaps(newIndexingMaps)))
    return nullptr;
  return ArrayAttr::get(
      llvm::to_vector<4>(llvm::map_range(
          newIndexingMaps,
          [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); })),
      context);
}

/// Modify the region of an indexed generic op to drop arguments corresponding
/// to loops that have unit trip count.
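/// For example, with `unitDims = {1}` an entry block
/// `^bb0(%i: index, %j: index, %in: f32)` becomes `^bb0(%i: index, %in: f32)`,
/// with uses of `%j` replaced by a constant 0 index.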
template <typename OpTy>
static LogicalResult
replaceBlockArgForUnitDimLoops(OpTy op, const DenseSet<unsigned> &unitDims,
                               PatternRewriter &rewriter) {
  return success();
}

template <>
LogicalResult replaceBlockArgForUnitDimLoops<IndexedGenericOp>(
    IndexedGenericOp op, const DenseSet<unsigned> &unitDims,
    PatternRewriter &rewriter) {
  OpBuilder::InsertionGuard guard(rewriter);
  Block *entryBlock = &op.getOperation()->getRegion(0).front();
  rewriter.setInsertionPointToStart(entryBlock);
  Value zero = rewriter.create<ConstantIndexOp>(op.getLoc(), 0);
  for (unsigned unitDimLoop : unitDims) {
    entryBlock->getArgument(unitDimLoop).replaceAllUsesWith(zero);
  }
  std::set<unsigned> orderedUnitDims(unitDims.begin(), unitDims.end());
  for (unsigned i : llvm::reverse(orderedUnitDims))
    entryBlock->eraseArgument(i);
  return success();
}

namespace {
/// Pattern to fold unit-trip count loops in GenericOps.
// TODO: Generalize this to indexed-generic ops by modifying the region
// arguments as well.
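//
// As an illustration, for a generic op over tensor<1x5xf32> operands with the
// identity indexing map affine_map<(d0, d1) -> (d0, d1)>, the d0 loop has a
// unit trip count; this pattern rewrites the map to
// affine_map<(d0) -> (0, d0)> and drops the corresponding "parallel" iterator.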
template <typename GenericOpTy>
struct FoldUnitDimLoops : public OpRewritePattern<GenericOpTy> {
  using OpRewritePattern<GenericOpTy>::OpRewritePattern;
  LogicalResult matchAndRewrite(GenericOpTy op,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap, 4> indexingMaps = op.getIndexingMaps();
    if (indexingMaps.empty())
      return failure();

    // Check if any of the iteration dimensions are unit-trip count. They will
    // end up being unit-trip count if they are used to index into a unit-dim
    // tensor/memref.
    AffineMap invertedMap = inversePermutation(concatAffineMaps(indexingMaps));
    if (!invertedMap)
      return failure();
    SmallVector<int64_t, 4> dims;
    for (ShapedType shapedType : op.getInputOutputShapedTypes())
      dims.append(shapedType.getShape().begin(), shapedType.getShape().end());
    DenseSet<unsigned> unitDims;
    ArrayAttr iteratorTypes = op.iterator_types();
    for (auto expr : enumerate(invertedMap.getResults())) {
      if (AffineDimExpr dimExpr = expr.value().dyn_cast<AffineDimExpr>())
        if (dims[dimExpr.getPosition()] == 1 &&
            iteratorTypes[expr.index()].dyn_cast<StringAttr>().getValue() ==
                getParallelIteratorTypeName())
          unitDims.insert(expr.index());
    }
    if (unitDims.empty())
      return failure();

    // Compute the modified indexing maps.
    MLIRContext *context = rewriter.getContext();
    ArrayAttr newIndexingMapAttr =
        replaceUnitDims(unitDims, indexingMaps, context);
    if (!newIndexingMapAttr)
      return op.emitError("unable to compute modified indexing_maps");

    // Compute the iterator types of the modified op by dropping the one-trip
    // count loops.
    SmallVector<Attribute, 4> newIteratorTypes;
    for (auto attr : llvm::enumerate(iteratorTypes)) {
      if (!unitDims.count(attr.index()))
        newIteratorTypes.push_back(attr.value());
    }

    rewriter.startRootUpdate(op);
    op.indexing_mapsAttr(newIndexingMapAttr);
    op.iterator_typesAttr(ArrayAttr::get(newIteratorTypes, context));
    replaceBlockArgForUnitDimLoops(op, unitDims, rewriter);
    rewriter.finalizeRootUpdate(op);
    return success();
  }
};

struct UnitExtentReplacementInfo {
  RankedTensorType type;
  AffineMap indexMap;
  ArrayAttr reassociation;
};
} // namespace

/// Utility function for replacing operands/results to a linalg generic
/// operation on tensors with unit-extent dimensions. These can be replaced
/// with an operand/result with the unit-extent dimension removed. This is only
/// done if the indexing map used to access that dimension has an
/// AffineConstantExpr of value 0. Given the `type` of a result/operand of a
/// Linalg op, and its `indexMap`, the utility function returns:
/// - the new type with dimensions of size 1 removed.
/// - modified index map that can be used to access the replaced result/operand
/// - the reassociation that converts from the original tensor type to the
///   modified tensor type.
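///
/// For example, a `type` of `tensor<1x5x1xf32>` accessed with
/// `indexMap = affine_map<(d0, d1) -> (0, d1, 0)>` yields the new type
/// `tensor<5xf32>`, the index map `affine_map<(d0, d1) -> (d1)>`, and the
/// reassociation `[affine_map<(d0, d1, d2) -> (d0, d1, d2)>]`.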
static UnitExtentReplacementInfo replaceUnitExtents(AffineMap indexMap,
                                                    RankedTensorType type,
                                                    MLIRContext *context) {
  ArrayRef<int64_t> shape = type.getShape();
  ArrayRef<AffineExpr> exprs = indexMap.getResults();
  SmallVector<AffineExpr, 2> reassociations;
  SmallVector<Attribute, 4> reassociationMaps;
  SmallVector<AffineExpr, 4> newIndexExprs;
  SmallVector<int64_t, 4> newShape;

  int64_t origRank = type.getRank();
  AffineExpr zeroExpr = getAffineConstantExpr(0, context);
  auto isUnitExtent = [&](int64_t dim) -> bool {
    return shape[dim] == 1 && exprs[dim] == zeroExpr;
  };

  unsigned dim = 0;
  // Fold dimensions that are unit-extent at the beginning of the tensor.
  while (dim < origRank && isUnitExtent(dim))
    reassociations.push_back(getAffineDimExpr(dim++, context));
  while (dim < origRank) {
    reassociations.push_back(getAffineDimExpr(dim, context));
    newIndexExprs.push_back(exprs[dim]);
    newShape.push_back(shape[dim]);
    // Fold all following dimensions that are unit-extent.
    while (dim + 1 < origRank && isUnitExtent(dim + 1)) {
      ++dim;
      reassociations.push_back(getAffineDimExpr(dim, context));
    }
    reassociationMaps.push_back(AffineMapAttr::get(AffineMap::get(
        origRank, /*numSymbols = */ 0, reassociations, context)));
    reassociations.clear();
    ++dim;
  }
  UnitExtentReplacementInfo info = {
      RankedTensorType::get(newShape, type.getElementType()),
      AffineMap::get(indexMap.getNumDims(), indexMap.getNumSymbols(),
                     newIndexExprs, context),
      ArrayAttr::get(reassociationMaps, context)};
  return info;
}

namespace {

/// Pattern to replace tensor operands/results that have unit-extent dimensions.
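///
/// For example, an operand of type `tensor<1x5xf32>` accessed with
/// `affine_map<(d0, d1) -> (0, d1)>` is rewritten to a `tensor<5xf32>` operand
/// accessed with `affine_map<(d0, d1) -> (d1)>`; `linalg.tensor_reshape` ops
/// are inserted to convert operands and results between the original and the
/// canonicalized types.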
template <typename GenericOpTy>
struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOpTy> {
  using OpRewritePattern<GenericOpTy>::OpRewritePattern;
  LogicalResult matchAndRewrite(GenericOpTy op,
                                PatternRewriter &rewriter) const override {
    // TODO: support init_tensors and reductions.
    if (!op.hasTensorSemantics() || !op.init_tensors().empty())
      return failure();

    MLIRContext *context = rewriter.getContext();
    Location loc = op.getLoc();

    SmallVector<AffineMap, 4> newIndexingMaps;
    SmallVector<ArrayAttr, 4> reassociationMaps;
    SmallVector<ShapedType, 4> newInputOutputTypes;
    bool doCanonicalization = false;
    for (auto it :
         llvm::zip(op.getIndexingMaps(), op.getInputOutputShapedTypes())) {
      auto replacementInfo = replaceUnitExtents(
          std::get<0>(it), std::get<1>(it).template cast<RankedTensorType>(),
          context);
      reassociationMaps.push_back(replacementInfo.reassociation);
      newIndexingMaps.push_back(replacementInfo.indexMap);
      newInputOutputTypes.push_back(replacementInfo.type);
      doCanonicalization |= replacementInfo.type != std::get<1>(it);
    }

    // If the indexing maps of the result operation are not invertible (i.e.
    // not legal), abort.
    if (!doCanonicalization ||
        !inversePermutation(concatAffineMaps(newIndexingMaps)))
      return failure();
    // If any operand type changes, insert a reshape to convert from the
    // original type to the new type.
    // TODO: get rid of flattenedIdx which assumes operand order and contiguity.
    unsigned flattenedIdx = 0;
    auto insertReshapes = [&](ValueRange values) {
      SmallVector<Value, 4> res;
      res.reserve(values.size());
      for (auto operand : llvm::enumerate(values)) {
        if (operand.value().getType() == newInputOutputTypes[flattenedIdx])
          res.push_back(operand.value());
        else
          res.push_back(rewriter.create<linalg::TensorReshapeOp>(
              loc, newInputOutputTypes[flattenedIdx], operand.value(),
              reassociationMaps[flattenedIdx]));
        ++flattenedIdx;
      }
      return res;
    };

    SmallVector<Value, 4> newInputs = insertReshapes(op.inputs());
    SmallVector<Value, 4> newOutputBuffers =
        insertReshapes(op.output_buffers());
    SmallVector<Value, 4> newInitTensors = insertReshapes(op.init_tensors());

    // If any result type changes, insert a reshape to convert from the
    // original type to the new type.
    SmallVector<Type, 4> resultTypes;
    resultTypes.reserve(op.getNumResults());
    for (unsigned i : llvm::seq<unsigned>(0, op.getNumResults()))
      resultTypes.push_back(newInputOutputTypes[i + op.getNumInputs()]);
    GenericOpTy replacementOp = rewriter.create<GenericOpTy>(
        loc, resultTypes, newInputs, newOutputBuffers, newInitTensors,
        newIndexingMaps,
        llvm::to_vector<4>(
            op.iterator_types().template getAsValueRange<StringAttr>()));
    rewriter.inlineRegionBefore(op.region(), replacementOp.region(),
                                replacementOp.region().begin());

    // If any result tensor has a modified shape, then add a reshape to recover
    // the original shape.
    SmallVector<Value, 4> resultReplacements;
    for (auto result : llvm::enumerate(replacementOp.getResults())) {
      unsigned index = result.index() + replacementOp.getNumOperands();
      RankedTensorType origResultType = op.getResult(result.index())
                                            .getType()
                                            .template cast<RankedTensorType>();
      if (origResultType != result.value().getType())
        resultReplacements.push_back(rewriter.create<linalg::TensorReshapeOp>(
            loc, origResultType, result.value(), reassociationMaps[index]));
      else
        resultReplacements.push_back(result.value());
    }
    rewriter.replaceOp(op, resultReplacements);
    return success();
  }
};
} // namespace

namespace {
/// Pattern to fold a pair of reshape ops where the intermediate tensor has
/// unit-extent dimensions. For example:
///
///  %0 = linalg.tensor_reshape %arg0
///    [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>]
///    : tensor<2048xf32> into tensor<1x4x1x512xf32>
///  %1 = linalg.tensor_reshape %0
///    [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>,
///     affine_map<(d0, d1, d2, d3) -> (d3)>]
///    : tensor<1x4x1x512xf32> into tensor<4x512xf32>
///
/// can be replaced with
///
///  %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>]
///    : tensor<2048xf32> into tensor<4x512xf32>
///
/// Similarly,
///
///  %0 = linalg.tensor_reshape %arg0
///    [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>,
///     affine_map<(d0, d1, d2, d3) -> (d3)>]
///    : tensor<4x512xf32> into tensor<1x4x1x512xf32>
///  %1 = linalg.tensor_reshape %0
///    [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>]
///    : tensor<1x4x1x512xf32> into tensor<2048xf32>
///
/// can be replaced with
///
///  %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>]
///    : tensor<4x512xf32> into tensor<2048xf32>
struct FoldReshapeOpWithUnitExtent : OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    // Check that the source operand is created from a reshape as well.
    TensorReshapeOp parentReshapeOp =
        reshapeOp.src().getDefiningOp<TensorReshapeOp>();
    if (!parentReshapeOp)
      return failure();

    RankedTensorType srcType = reshapeOp.getSrcType(),
                     dstType = reshapeOp.getResultType(),
                     parentSrcType = parentReshapeOp.getSrcType();
    if (!srcType.hasStaticShape() || !dstType.hasStaticShape() ||
        !parentSrcType.hasStaticShape() ||
        srcType.getRank() < dstType.getRank() ||
        parentSrcType.getRank() == dstType.getRank())
      return failure();

    // Check whether the combination of reshapeOp and parentReshapeOp folds or
    // expands dimensions overall.
    // If the combined tensor_reshape is folding, the parentReshapeOp is
    // introducing unit-dims, and the reshapeOp does an actual reshape.
    // If the combined tensor_reshape is expanding, the reshapeOp is
    // introducing unit-dims, and the parentReshapeOp does an actual reshape.
    bool isFoldingPattern = parentSrcType.getRank() > dstType.getRank();
    ArrayRef<int64_t> expandedShape =
        isFoldingPattern ? parentSrcType.getShape() : dstType.getShape();
    ArrayRef<int64_t> foldedShape =
        isFoldingPattern ? dstType.getShape() : parentSrcType.getShape();

    unsigned expandedDim = 0, foldedDim = 0;
    SmallVector<SmallVector<AffineExpr, 4>, 4> reassociationExprs(
        foldedShape.size());
    while (expandedDim < expandedShape.size() &&
           foldedDim < foldedShape.size()) {
      int64_t dstSize = foldedShape[foldedDim];
      int64_t srcSize = expandedShape[expandedDim];
      while (srcSize < dstSize && expandedDim < expandedShape.size()) {
        reassociationExprs[foldedDim].push_back(
            rewriter.getAffineDimExpr(expandedDim++));
        srcSize *= expandedShape[expandedDim];
      }
      if (srcSize == dstSize) {
        reassociationExprs[foldedDim].push_back(
            rewriter.getAffineDimExpr(expandedDim++));
        // If the next dim in foldedShape is not 1, treat subsequent unit dims
        // in expandedShape as collapsed into the current folded dimension.
        if (foldedDim == foldedShape.size() - 1 ||
            foldedShape[foldedDim + 1] != 1) {
          while (expandedDim < expandedShape.size() &&
                 expandedShape[expandedDim] == 1) {
            reassociationExprs[foldedDim].push_back(
                rewriter.getAffineDimExpr(expandedDim++));
          }
        }
      } else {
        return failure();
      }
      foldedDim++;
    }
    if (expandedDim != expandedShape.size())
      return failure();

    SmallVector<AffineMap, 4> reassociationMaps =
        llvm::to_vector<4>(llvm::map_range(
            reassociationExprs, [&](ArrayRef<AffineExpr> exprs) -> AffineMap {
              return AffineMap::get(expandedShape.size(), 0, exprs,
                                    rewriter.getContext());
            }));
    rewriter.replaceOpWithNewOp<TensorReshapeOp>(
        reshapeOp, dstType, parentReshapeOp.src(),
        rewriter.getAffineMapArrayAttr(reassociationMaps));
    return success();
  }
};
} // namespace


/// Patterns that are used to canonicalize the use of unit-extent dims for
/// broadcasting.
void mlir::populateLinalgFoldUnitExtentDimsPatterns(
    MLIRContext *context, OwningRewritePatternList &patterns) {
  patterns
      .insert<FoldUnitDimLoops<GenericOp>, FoldUnitDimLoops<IndexedGenericOp>,
              ReplaceUnitExtentTensors<GenericOp>,
              ReplaceUnitExtentTensors<IndexedGenericOp>>(context);
  TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
  patterns.insert<FoldReshapeOpWithUnitExtent>(context);
}

namespace {
/// Pass that removes unit-extent dims within generic ops.
struct LinalgFoldUnitExtentDimsPass
    : public LinalgFoldUnitExtentDimsBase<LinalgFoldUnitExtentDimsPass> {
  void runOnFunction() override {
    OwningRewritePatternList patterns;
    FuncOp funcOp = getFunction();
    MLIRContext *context = funcOp.getContext();
    if (foldOneTripLoopsOnly)
      patterns.insert<FoldUnitDimLoops<GenericOp>,
                      FoldUnitDimLoops<IndexedGenericOp>>(context);
    else
      populateLinalgFoldUnitExtentDimsPatterns(context, patterns);
    applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
  }
};
} // namespace

std::unique_ptr<OperationPass<FuncOp>>
mlir::createLinalgFoldUnitExtentDimsPass() {
  return std::make_unique<LinalgFoldUnitExtentDimsPass>();
}