1 //===- DropUnitDims.cpp - Pass to drop use of unit-extent for broadcasting ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements patterns and a pass that remove the use of unit-extent
// dimensions to specify broadcasting, in favor of a more canonical
// representation of the computation.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "PassDetail.h"
16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
17 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
18 #include "mlir/Dialect/Linalg/Passes.h"
19 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
20 #include "mlir/Dialect/Linalg/Utils/Utils.h"
21 #include "mlir/IR/AffineExpr.h"
22 #include "mlir/IR/AffineMap.h"
23 #include "mlir/Transforms/FoldUtils.h"
24 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
25 #include "llvm/Support/CommandLine.h"
26 #include "llvm/Support/Debug.h"
27 
28 #define DEBUG_TYPE "linalg-drop-unit-dims"
29 
30 using namespace mlir;
31 using namespace mlir::linalg;
32 
33 /// Implements a pass that canonicalizes the uses of unit-extent dimensions for
34 /// broadcasting. For example,
35 ///
36 /// ```mlir
37 /// #accesses = [
38 ///   affine_map<(d0, d1) -> (0, d1)>,
39 ///   affine_map<(d0, d1) -> (d0, 0)>,
40 ///   affine_map<(d0, d1) -> (d0, d1)>
41 /// ]
42 ///
43 /// #trait = {
44 ///   args_in = 2,
45 ///   args_out = 1,
46 ///   indexing_maps = #accesses,
47 ///   iterator_types = ["parallel", "parallel"],
48 ///   library_call = "some_external_fn"
49 /// }
50 ///
51 /// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
52 /// tensor<5x5xf32>
53 /// {
54 ///   %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>] :
55 ///        tensor<5xf32> into tensor<1x5xf32>
56 ///   %1 = linalg.tensor_reshape %arg1 [affine_map<(d0, d1) -> (d0, d1)>] :
57 ///        tensor<5xf32> into tensor<5x1xf32>
58 ///   %2 = linalg.generic #trait %0, %1 {
59 ///        ^bb0(%arg2: f32, %arg3: f32):
60 ///          %3 = addf %arg2, %arg3 : f32
61 ///          linalg.yield %3 : f32
62 ///        } : tensor<1x5xf32>, tensor<5x1xf32> -> tensor<5x5xf32>
///   return %2 : tensor<5x5xf32>
/// }
/// ```
///
/// would canonicalize to
///
/// ```mlir
69 /// #accesses = [
70 ///   affine_map<(d0, d1) -> (d1)>,
71 ///   affine_map<(d0, d1) -> (d0)>,
72 ///   affine_map<(d0, d1) -> (d0, d1)>
73 /// ]
74 ///
75 /// #trait = {
76 ///   args_in = 2,
77 ///   args_out = 1,
78 ///   indexing_maps = #accesses,
79 ///   iterator_types = ["parallel", "parallel"],
80 ///   library_call = "some_external_fn"
81 /// }
82 ///
83 /// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
84 /// tensor<5x5xf32>
85 /// {
86 ///   %0 = linalg.generic #trait %arg0, %arg1 {
87 ///        ^bb0(%arg2: f32, %arg3: f32):
88 ///          %3 = addf %arg2, %arg3 : f32
89 ///          linalg.yield %3 : f32
90 ///        } : tensor<5xf32>, tensor<5xf32> -> tensor<5x5xf32>
///   return %0 : tensor<5x5xf32>
/// }
/// ```
93 
94 /// Given dims of the iteration space of a structured op that are known to be
95 /// single trip count (`unitDims`), return the indexing maps to use in the
96 /// canonicalized op with these dims removed, given the original `indexingMaps`.
97 static ArrayAttr replaceUnitDims(DenseSet<unsigned> &unitDims,
98                                  ArrayRef<AffineMap> indexingMaps,
99                                  MLIRContext *context) {
100   if (indexingMaps.empty())
101     return nullptr;
102   unsigned numIterationDims = indexingMaps.front().getNumDims();
103   unsigned numSymbols = indexingMaps.front().getNumSymbols();
104 
105   // Compute the replacement for each dim expr.
106   SmallVector<AffineExpr, 4> dimReplacements;
107   dimReplacements.reserve(numIterationDims);
108   unsigned numKeptDims = 0;
109   for (unsigned dim : llvm::seq<unsigned>(0, numIterationDims)) {
110     if (unitDims.count(dim))
111       dimReplacements.push_back(getAffineConstantExpr(0, context));
112     else
113       dimReplacements.push_back(getAffineDimExpr(numKeptDims++, context));
114   }
115 
116   // Symbols remain the same.
117   SmallVector<AffineExpr, 4> symReplacements;
118   symReplacements.reserve(numSymbols);
119   for (unsigned symbol : llvm::seq<unsigned>(0, numSymbols))
120     symReplacements.push_back(getAffineSymbolExpr(symbol, context));
121 
122   SmallVector<AffineMap, 4> newIndexingMaps;
123   newIndexingMaps.reserve(indexingMaps.size());
124   for (AffineMap operandMap : indexingMaps) {
125     // Expected indexing maps to have no symbols.
126     if (operandMap.getNumSymbols())
127       return nullptr;
128     newIndexingMaps.push_back(simplifyAffineMap(
129         operandMap.replaceDimsAndSymbols(dimReplacements, symReplacements,
130                                          numIterationDims - unitDims.size(),
131                                          numSymbols)));
132   }
133 
134   // Check that the new index maps are invertible. If not, something went
135   // wrong, so abort.
136   if (!inversePermutation(concatAffineMaps(newIndexingMaps)))
137     return nullptr;
138   return ArrayAttr::get(context,
139                         llvm::to_vector<4>(llvm::map_range(
140                             newIndexingMaps, [](AffineMap map) -> Attribute {
141                               return AffineMapAttr::get(map);
142                             })));
143 }
144 
145 /// Update the index accesses of linalg operations having index semantics.
146 static void replaceUnitDimIndexOps(GenericOp genericOp,
147                                    const DenseSet<unsigned> &unitDims,
148                                    PatternRewriter &rewriter) {
149   assert(genericOp->getNumRegions() == 1 &&
150          genericOp->getRegion(0).getBlocks().size() == 1 &&
151          "expected generic operation to have one block.");
152   Block &block = genericOp->getRegion(0).front();
153 
154   for (IndexOp indexOp : llvm::make_early_inc_range(block.getOps<IndexOp>())) {
155     OpBuilder::InsertionGuard guard(rewriter);
156     rewriter.setInsertionPoint(indexOp);
157     if (unitDims.count(indexOp.dim()) != 0) {
158       rewriter.replaceOpWithNewOp<ConstantIndexOp>(indexOp, 0);
159     } else {
160       // Update the dimension of the index operation if needed.
161       unsigned droppedDims = llvm::count_if(
162           unitDims, [&](unsigned dim) { return dim < indexOp.dim(); });
163       if (droppedDims != 0)
164         rewriter.replaceOpWithNewOp<IndexOp>(indexOp,
165                                              indexOp.dim() - droppedDims);
166     }
167   }
168 }
169 
170 namespace {
/// Pattern to fold unit-trip count loops in GenericOps.
///
/// A loop of the iteration space is treated as unit-trip count when the
/// inverse of the concatenated indexing maps reads it from an operand
/// dimension whose static size is 1. Such loops are removed in place by
/// rewriting `indexing_maps`, `iterator_types` and the `linalg.index` ops of
/// the body; the operands themselves are not changed.
///
/// Parallel unit-trip loops are always removed. Unit-trip reduction loops are
/// only removed when the op has more than one reduction loop; if every
/// reduction loop is unit-trip, all but the last one returned by
/// `getDimsOfType` are removed. Loops of any other iterator type are kept.
struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap, 4> indexingMaps = genericOp.getIndexingMaps();
    if (indexingMaps.empty())
      return failure();

    // Check if any of the iteration dimensions are unit-trip count. They will
    // end up being unit-trip count if they are used to index into a unit-dim
    // tensor/memref.
    // The inverse maps each loop to a position in the concatenated operand
    // shapes collected in `dims` below; bail out if no such inverse exists.
    AffineMap invertedMap = inversePermutation(concatAffineMaps(indexingMaps));
    if (!invertedMap)
      return failure();
    // Static shapes of all shaped operands, concatenated in operand order.
    // Indexing `dims` with positions from `invertedMap` relies on each
    // operand's indexing map having one result per dimension of that operand.
    SmallVector<int64_t, 4> dims;
    for (ShapedType shapedType : genericOp.getShapedOperandTypes())
      dims.append(shapedType.getShape().begin(), shapedType.getShape().end());

    // Find all the reduction iterators. Those need some special consideration
    // (see below).
    auto getLoopDimsOfType =
        [&](StringRef iteratorTypeName) -> SmallVector<unsigned, 4> {
      SmallVector<AffineExpr> dimExprs;
      getDimsOfType(genericOp, iteratorTypeName, dimExprs);
      return llvm::to_vector<4>(llvm::map_range(dimExprs, [](AffineExpr expr) {
        return expr.cast<AffineDimExpr>().getPosition();
      }));
    };
    auto reductionDims = getLoopDimsOfType(getReductionIteratorTypeName());

    // Parallel unit-trip loops go straight into `unitDims`; unit-trip
    // reduction loops are set aside and decided on below.
    DenseSet<unsigned> unitDims;
    SmallVector<unsigned, 4> unitDimsReductionLoops;
    ArrayAttr iteratorTypes = genericOp.iterator_types();
    for (auto expr : enumerate(invertedMap.getResults())) {
      if (AffineDimExpr dimExpr = expr.value().dyn_cast<AffineDimExpr>())
        if (dims[dimExpr.getPosition()] == 1) {
          if (isParallelIterator(iteratorTypes[expr.index()]))
            unitDims.insert(expr.index());
          else if (isReductionIterator(iteratorTypes[expr.index()]))
            unitDimsReductionLoops.push_back(expr.index());
        }
    }

    // Reduction loops can be dropped if there is at least one other reduction
    // loop that is not dropped. This accounts for the initial value read in the
    // reduction loop.
    // When every reduction loop is unit-trip, all entries of `reductionDims`
    // except the last are dropped so that exactly one reduction loop survives.
    if (!unitDimsReductionLoops.empty() && reductionDims.size() > 1) {
      if (unitDimsReductionLoops.size() == reductionDims.size())
        unitDims.insert(reductionDims.begin(), std::prev(reductionDims.end()));
      else
        unitDims.insert(unitDimsReductionLoops.begin(),
                        unitDimsReductionLoops.end());
    }

    if (unitDims.empty())
      return failure();

    // Compute the modified indexing maps.
    MLIRContext *context = rewriter.getContext();
    ArrayAttr newIndexingMapAttr =
        replaceUnitDims(unitDims, indexingMaps, context);
    // NOTE(review): this emits an error diagnostic rather than a silent match
    // failure; confirm that is intended for maps replaceUnitDims rejects.
    if (!newIndexingMapAttr)
      return genericOp.emitError("unable to compute modified indexing_maps");

    // Compute the iterator types of the modified op by dropping the one-trip
    // count loops.
    SmallVector<Attribute, 4> newIteratorTypes;
    for (auto attr : llvm::enumerate(iteratorTypes)) {
      if (!unitDims.count(attr.index()))
        newIteratorTypes.push_back(attr.value());
    }

    // The op is updated in place; the index-op rewrite of the body happens
    // inside the same root update.
    rewriter.startRootUpdate(genericOp);
    genericOp.indexing_mapsAttr(newIndexingMapAttr);
    genericOp.iterator_typesAttr(ArrayAttr::get(context, newIteratorTypes));
    replaceUnitDimIndexOps(genericOp, unitDims, rewriter);
    rewriter.finalizeRootUpdate(genericOp);
    return success();
  }
};
252 
/// Result of replaceUnitExtents for a single operand/result of a Linalg op.
/// Field order matters: instances are built with brace initialization.
struct UnitExtentReplacementInfo {
  // Tensor type with the foldable unit-extent dimensions removed.
  RankedTensorType type;
  // Indexing map with the results for the removed dimensions dropped.
  AffineMap indexMap;
  // Array of AffineMapAttr reassociation maps relating the original tensor
  // type to `type`.
  ArrayAttr reassociation;
};
258 } // namespace
259 
260 /// Utility function for replacing operands/results to a linalg generic
261 /// operation on tensors with unit-extent dimensions. These can be replaced with
262 /// an operand/result with the unit-extent dimension removed. This is only done
263 /// if the indexing map used to access that didimensionmension has a
264 /// AffineConstantExpr of value 0. Given the `type` of an result/operand of a
265 /// Linalg op, and its `indexMap` the utility function returns:
266 /// - the new type with dimensions of size 1 removed.
267 /// - modified index map that can be used to access the replaced result/operand
268 /// - the reassociation that converts from the original tensor type to the
269 ///   modified tensor type.
270 static UnitExtentReplacementInfo replaceUnitExtents(AffineMap indexMap,
271                                                     RankedTensorType type,
272                                                     MLIRContext *context) {
273   ArrayRef<int64_t> shape = type.getShape();
274   ArrayRef<AffineExpr> exprs = indexMap.getResults();
275   SmallVector<AffineExpr, 2> reassociations;
276   SmallVector<Attribute, 4> reassociationMaps;
277   SmallVector<AffineExpr, 4> newIndexExprs;
278   SmallVector<int64_t, 4> newShape;
279 
280   int64_t origRank = type.getRank();
281   AffineExpr zeroExpr = getAffineConstantExpr(0, context);
282   auto isUnitExtent = [&](int64_t dim) -> bool {
283     return shape[dim] == 1 && exprs[dim] == zeroExpr;
284   };
285 
286   unsigned dim = 0;
287   // Fold dimensions that are unit-extent at the beginning of the tensor.
288   while (dim < origRank && isUnitExtent(dim))
289     reassociations.push_back(getAffineDimExpr(dim++, context));
290   while (dim < origRank) {
291     reassociations.push_back(getAffineDimExpr(dim, context));
292     newIndexExprs.push_back(exprs[dim]);
293     newShape.push_back(shape[dim]);
294     // Fold all following dimensions that are unit-extent.
295     while (dim + 1 < origRank && isUnitExtent(dim + 1)) {
296       ++dim;
297       reassociations.push_back(getAffineDimExpr(dim, context));
298     }
299     reassociationMaps.push_back(AffineMapAttr::get(AffineMap::get(
300         origRank, /*symbolCount = */ 0, reassociations, context)));
301     reassociations.clear();
302     ++dim;
303   }
304   UnitExtentReplacementInfo info = {
305       RankedTensorType::get(newShape, type.getElementType()),
306       AffineMap::get(indexMap.getNumDims(), indexMap.getNumSymbols(),
307                      newIndexExprs, context),
308       ArrayAttr::get(context, reassociationMaps)};
309   return info;
310 }
311 
312 namespace {
313 
314 SmallVector<ReassociationExprs, 2>
315 convertAffineMapArrayToExprs(ArrayAttr affineMapArrayAttr) {
316   SmallVector<ReassociationExprs, 2> reassociationExprs;
317   for (auto attr : affineMapArrayAttr)
318     reassociationExprs.push_back(
319         llvm::to_vector<4>(attr.cast<AffineMapAttr>().getValue().getResults()));
320   return reassociationExprs;
321 }
322 
/// Pattern to replace tensors operands/results that are unit extents.
///
/// Applies only to GenericOps with tensor semantics. Each shaped operand is
/// analyzed with replaceUnitExtents; operands whose type changes are collapsed
/// with a `linalg.tensor_reshape`. A new GenericOp is then built on the
/// collapsed operands with the collapsed indexing maps, the original iterator
/// types and the original region (moved, not cloned). Results whose type
/// changed are expanded back to the original type with another
/// `linalg.tensor_reshape`. The pattern fails when no operand type changes or
/// when the new indexing maps are not invertible.
struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (!genericOp.hasTensorSemantics())
      return failure();

    MLIRContext *context = rewriter.getContext();
    Location loc = genericOp.getLoc();

    // Per shaped operand, in the order of getShapedOperandTypes(): the
    // collapsed indexing map, the reassociation from the original type, and
    // the collapsed type.
    SmallVector<AffineMap, 4> newIndexingMaps;
    SmallVector<ArrayAttr, 4> reassociationMaps;
    SmallVector<ShapedType, 4> newInputOutputTypes;
    bool doCanonicalization = false;
    for (auto it : llvm::zip(genericOp.getIndexingMaps(),
                             genericOp.getShapedOperandTypes())) {
      auto replacementInfo = replaceUnitExtents(
          std::get<0>(it), std::get<1>(it).template cast<RankedTensorType>(),
          context);
      reassociationMaps.push_back(replacementInfo.reassociation);
      newIndexingMaps.push_back(replacementInfo.indexMap);
      newInputOutputTypes.push_back(replacementInfo.type);
      doCanonicalization |= replacementInfo.type != std::get<1>(it);
    }

    // If the indexing maps of the result operation are not invertible (i.e. not
    // legal), abort.
    if (!doCanonicalization ||
        !inversePermutation(concatAffineMaps(newIndexingMaps)))
      return failure();

    // If any operand type change, insert a reshape to convert from the original
    // type to the new type.
    // TODO: get rid of flattenedIdx which assumes operand order and contiguity.
    // `flattenedIdx` is shared across calls of `insertReshapes`, so the inputs
    // must be processed before the outputs for the indices to line up.
    unsigned flattenedIdx = 0;
    auto insertReshapes = [&](ValueRange values) {
      SmallVector<Value, 4> res;
      res.reserve(values.size());
      for (auto operand : llvm::enumerate(values)) {
        if (operand.value().getType() == newInputOutputTypes[flattenedIdx])
          res.push_back(operand.value());
        else
          res.push_back(rewriter.create<linalg::TensorReshapeOp>(
              loc, newInputOutputTypes[flattenedIdx], operand.value(),
              convertAffineMapArrayToExprs(reassociationMaps[flattenedIdx])));
        ++flattenedIdx;
      }
      return res;
    };

    SmallVector<Value, 4> newInputs = insertReshapes(genericOp.inputs());
    SmallVector<Value, 4> newOutputs = insertReshapes(genericOp.outputs());

    // If any result type changes, insert a reshape to convert from the original
    // type to the new type.
    // Result i is typed like output i, which sits after all inputs in the
    // flattened per-operand arrays.
    SmallVector<Type, 4> resultTypes;
    resultTypes.reserve(genericOp.getNumResults());
    for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
      resultTypes.push_back(newInputOutputTypes[i + genericOp.getNumInputs()]);
    GenericOp replacementOp = rewriter.create<GenericOp>(
        loc, resultTypes, newInputs, newOutputs, newIndexingMaps,
        llvm::to_vector<4>(
            genericOp.iterator_types().template getAsValueRange<StringAttr>()));
    rewriter.inlineRegionBefore(genericOp.region(), replacementOp.region(),
                                replacementOp.region().begin());

    // If any result tensor has a modified shape, then add reshape to recover
    // the original shape.
    SmallVector<Value, 4> resultReplacements;
    for (auto result : llvm::enumerate(replacementOp.getResults())) {
      unsigned index = result.index() + replacementOp.getNumInputs();
      RankedTensorType origResultType = genericOp.getResult(result.index())
                                            .getType()
                                            .template cast<RankedTensorType>();
      if (origResultType != result.value().getType())
        resultReplacements.push_back(rewriter.create<linalg::TensorReshapeOp>(
            loc, origResultType, result.value(),
            convertAffineMapArrayToExprs(reassociationMaps[index])));
      else
        resultReplacements.push_back(result.value());
    }
    rewriter.replaceOp(genericOp, resultReplacements);
    return success();
  }
};
409 } // namespace
410 
411 /// Get the reassociation maps to fold the result of a subtensor (or source of a
412 /// subtensor_insert) operation with given offsets, and sizes to its
413 /// rank-reduced version. This is only done for the cases where the size is 1
414 /// and offset is 0. Strictly speaking the offset 0 is not required in general,
415 /// but non-zero offsets are not handled by SPIR-V backend at this point (and
416 /// potentially cannot be handled).
417 static Optional<SmallVector<ReassociationIndices>>
418 getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes) {
419   SmallVector<ReassociationIndices> reassociation;
420   ReassociationIndices curr;
421   for (auto it : llvm::enumerate(mixedSizes)) {
422     auto dim = it.index();
423     auto size = it.value();
424     curr.push_back(dim);
425     auto attr = size.dyn_cast<Attribute>();
426     if (attr && attr.cast<IntegerAttr>().getInt() == 1)
427       continue;
428     reassociation.emplace_back(ReassociationIndices{});
429     std::swap(reassociation.back(), curr);
430   }
431   // When the reassociations are not empty, then fold the remaining
432   // unit-dimensions into the last dimension.  If the reassociations so far is
433   // empty, then leave it emtpy. This will fold everything to a rank-0 tensor.
434   if (!curr.empty() && !reassociation.empty())
435     reassociation.back().append(curr.begin(), curr.end());
436   return reassociation;
437 }
438 
439 namespace {
440 /// Convert `subtensor` operations to rank-reduced versions.
441 struct UseRankReducedSubTensorOp : public OpRewritePattern<SubTensorOp> {
442   using OpRewritePattern<SubTensorOp>::OpRewritePattern;
443 
444   LogicalResult matchAndRewrite(SubTensorOp subTensorOp,
445                                 PatternRewriter &rewriter) const override {
446     RankedTensorType resultType = subTensorOp.getType();
447     SmallVector<OpFoldResult> offsets = subTensorOp.getMixedOffsets();
448     SmallVector<OpFoldResult> sizes = subTensorOp.getMixedSizes();
449     SmallVector<OpFoldResult> strides = subTensorOp.getMixedStrides();
450     auto reassociation = getReassociationMapForFoldingUnitDims(sizes);
451     if (!reassociation ||
452         reassociation->size() == static_cast<size_t>(resultType.getRank()))
453       return failure();
454     auto rankReducedType =
455         SubTensorOp::inferRankReducedResultType(reassociation->size(),
456                                                 subTensorOp.getSourceType(),
457                                                 offsets, sizes, strides)
458             .cast<RankedTensorType>();
459 
460     Location loc = subTensorOp.getLoc();
461     Value newSubTensor = rewriter.create<SubTensorOp>(
462         loc, rankReducedType, subTensorOp.source(), offsets, sizes, strides);
463     rewriter.replaceOpWithNewOp<TensorReshapeOp>(subTensorOp, resultType,
464                                                  newSubTensor, *reassociation);
465     return success();
466   }
467 };
468 
469 /// Convert `subtensor_insert` operations to rank-reduced versions.
470 struct UseRankReducedSubTensorInsertOp
471     : public OpRewritePattern<SubTensorInsertOp> {
472   using OpRewritePattern<SubTensorInsertOp>::OpRewritePattern;
473 
474   LogicalResult matchAndRewrite(SubTensorInsertOp insertOp,
475                                 PatternRewriter &rewriter) const override {
476     RankedTensorType sourceType = insertOp.getSourceType();
477     SmallVector<OpFoldResult> offsets = insertOp.getMixedOffsets();
478     SmallVector<OpFoldResult> sizes = insertOp.getMixedSizes();
479     SmallVector<OpFoldResult> strides = insertOp.getMixedStrides();
480     auto reassociation = getReassociationMapForFoldingUnitDims(sizes);
481     if (!reassociation ||
482         reassociation->size() == static_cast<size_t>(sourceType.getRank()))
483       return failure();
484     Location loc = insertOp.getLoc();
485     auto reshapedSource = rewriter.create<TensorReshapeOp>(
486         loc, insertOp.source(), *reassociation);
487     rewriter.replaceOpWithNewOp<SubTensorInsertOp>(
488         insertOp, reshapedSource, insertOp.dest(), insertOp.getMixedOffsets(),
489         insertOp.getMixedSizes(), insertOp.getMixedStrides());
490     return success();
491   }
492 };
493 } // namespace
494 
495 /// Patterns that are used to canonicalize the use of unit-extent dims for
496 /// broadcasting.
497 void mlir::linalg::populateFoldUnitExtentDimsPatterns(
498     RewritePatternSet &patterns) {
499   auto *context = patterns.getContext();
500   patterns.add<FoldUnitDimLoops, ReplaceUnitExtentTensors,
501                UseRankReducedSubTensorOp, UseRankReducedSubTensorInsertOp>(
502       context);
503   TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
504 }
505 
506 namespace {
507 /// Pass that removes unit-extent dims within generic ops.
508 struct LinalgFoldUnitExtentDimsPass
509     : public LinalgFoldUnitExtentDimsBase<LinalgFoldUnitExtentDimsPass> {
510   void runOnFunction() override {
511     FuncOp funcOp = getFunction();
512     MLIRContext *context = funcOp.getContext();
513     RewritePatternSet patterns(context);
514     if (foldOneTripLoopsOnly)
515       patterns.add<FoldUnitDimLoops>(context);
516     else
517       populateFoldUnitExtentDimsPatterns(patterns);
518     (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
519   }
520 };
521 } // namespace
522 
523 std::unique_ptr<OperationPass<FuncOp>>
524 mlir::createLinalgFoldUnitExtentDimsPass() {
525   return std::make_unique<LinalgFoldUnitExtentDimsPass>();
526 }
527