//===- DropUnitDims.cpp - Pass to drop use of unit-extent for broadcasting ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns and a pass to remove the use of unit-extent
// dimensions for specifying broadcasting, in favor of a more canonical
// representation of the computation.
//
//===----------------------------------------------------------------------===//

15 #include "PassDetail.h"
16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
17 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
18 #include "mlir/Dialect/Linalg/Passes.h"
19 #include "mlir/Dialect/Linalg/Utils/Utils.h"
20 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
21 #include "mlir/IR/AffineExpr.h"
22 #include "mlir/IR/AffineMap.h"
23 #include "mlir/Transforms/FoldUtils.h"
24 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
25 #include "llvm/Support/CommandLine.h"
26 #include "llvm/Support/Debug.h"
27 
28 #define DEBUG_TYPE "linalg-drop-unit-dims"
29 
30 using namespace mlir;
31 using namespace mlir::edsc;
32 using namespace mlir::edsc::intrinsics;
33 using namespace mlir::linalg;
34 
/// Implements a pass that canonicalizes the uses of unit-extent dimensions for
/// broadcasting. For example,
///
/// ```mlir
/// #accesses = [
///   affine_map<(d0, d1) -> (0, d1)>,
///   affine_map<(d0, d1) -> (d0, 0)>,
///   affine_map<(d0, d1) -> (d0, d1)>
/// ]
///
/// #trait = {
///   args_in = 2,
///   args_out = 1,
///   indexing_maps = #accesses,
///   iterator_types = ["parallel", "parallel"],
///   library_call = "some_external_fn"
/// }
///
/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
/// tensor<5x5xf32>
/// {
///   %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>] :
///        tensor<5xf32> into tensor<1x5xf32>
///   %1 = linalg.tensor_reshape %arg1 [affine_map<(d0, d1) -> (d0, d1)>] :
///        tensor<5xf32> into tensor<5x1xf32>
///   %2 = linalg.generic #trait %0, %1 {
///        ^bb0(%arg2: f32, %arg3: f32):
///          %3 = addf %arg2, %arg3 : f32
///          linalg.yield %3 : f32
///        } : tensor<1x5xf32>, tensor<5x1xf32> -> tensor<5x5xf32>
///   return %2 : tensor<5x5xf32>
/// }
/// ```
///
/// would canonicalize to
///
/// ```mlir
/// #accesses = [
///   affine_map<(d0, d1) -> (d1)>,
///   affine_map<(d0, d1) -> (d0)>,
///   affine_map<(d0, d1) -> (d0, d1)>
/// ]
///
/// #trait = {
///   args_in = 2,
///   args_out = 1,
///   indexing_maps = #accesses,
///   iterator_types = ["parallel", "parallel"],
///   library_call = "some_external_fn"
/// }
///
/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
/// tensor<5x5xf32>
/// {
///   %0 = linalg.generic #trait %arg0, %arg1 {
///        ^bb0(%arg2: f32, %arg3: f32):
///          %3 = addf %arg2, %arg3 : f32
///          linalg.yield %3 : f32
///        } : tensor<5xf32>, tensor<5xf32> -> tensor<5x5xf32>
///   return %0 : tensor<5x5xf32>
/// }
/// ```

/// Given dims of the iteration space of a structured op that are known to be
/// single trip count (`unitDims`), return the indexing maps to use in the
/// canonicalized op with these dims removed, given the original `indexingMaps`.
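/// As an illustrative sketch (not drawn from a test): with `unitDims` = {1}
/// and an original indexing map affine_map<(d0, d1, d2) -> (d0, d1, d2)>, the
/// returned map is affine_map<(d0, d1) -> (d0, 0, d1)>; the unit dimension is
/// pinned to the constant 0 and the remaining dimensions are renumbered.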
static ArrayAttr replaceUnitDims(DenseSet<unsigned> &unitDims,
                                 ArrayRef<AffineMap> indexingMaps,
                                 MLIRContext *context) {
  if (indexingMaps.empty())
    return nullptr;
  unsigned numIterationDims = indexingMaps.front().getNumDims();
  unsigned numSymbols = indexingMaps.front().getNumSymbols();

  // Compute the replacement for each dim expr.
  SmallVector<AffineExpr, 4> dimReplacements;
  dimReplacements.reserve(numIterationDims);
  unsigned numKeptDims = 0;
  for (unsigned dim : llvm::seq<unsigned>(0, numIterationDims)) {
    if (unitDims.count(dim))
      dimReplacements.push_back(getAffineConstantExpr(0, context));
    else
      dimReplacements.push_back(getAffineDimExpr(numKeptDims++, context));
  }

  // Symbols remain the same.
  SmallVector<AffineExpr, 4> symReplacements;
  symReplacements.reserve(numSymbols);
  for (unsigned symbol : llvm::seq<unsigned>(0, numSymbols))
    symReplacements.push_back(getAffineSymbolExpr(symbol, context));

  SmallVector<AffineMap, 4> newIndexingMaps;
  newIndexingMaps.reserve(indexingMaps.size());
  for (AffineMap operandMap : indexingMaps) {
    // Expected indexing maps to have no symbols.
    if (operandMap.getNumSymbols())
      return nullptr;
    newIndexingMaps.push_back(simplifyAffineMap(
        operandMap.replaceDimsAndSymbols(dimReplacements, symReplacements,
                                         numIterationDims - unitDims.size(),
                                         numSymbols)));
  }

  // Check that the new index maps are invertible. If not, something went
  // wrong, so abort.
  if (!inversePermutation(concatAffineMaps(newIndexingMaps)))
    return nullptr;
  return ArrayAttr::get(context,
                        llvm::to_vector<4>(llvm::map_range(
                            newIndexingMaps, [](AffineMap map) -> Attribute {
                              return AffineMapAttr::get(map);
                            })));
}

/// Modify the region of an indexed generic op to drop the block arguments
/// corresponding to loops that have unit trip count.
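/// For example (schematic): if loop d1 of a linalg.indexed_generic has unit
/// trip count, the entry block ^bb0(%i0: index, %i1: index, %v: f32) is
/// rewritten so that all uses of %i1 are replaced by a constant 0 : index and
/// the %i1 argument is erased, leaving ^bb0(%i0: index, %v: f32). Ops without
/// index block arguments need no region change.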
template <typename OpTy>
static LogicalResult
replaceBlockArgForUnitDimLoops(OpTy op, const DenseSet<unsigned> &unitDims,
                               PatternRewriter &rewriter) {
  return success();
}

template <>
LogicalResult replaceBlockArgForUnitDimLoops<IndexedGenericOp>(
    IndexedGenericOp op, const DenseSet<unsigned> &unitDims,
    PatternRewriter &rewriter) {
  OpBuilder::InsertionGuard guard(rewriter);
  Block *entryBlock = &op->getRegion(0).front();
  rewriter.setInsertionPointToStart(entryBlock);
  Value zero = rewriter.create<ConstantIndexOp>(op.getLoc(), 0);
  for (unsigned unitDimLoop : unitDims) {
    entryBlock->getArgument(unitDimLoop).replaceAllUsesWith(zero);
  }
  SmallVector<unsigned, 8> unitDimsToErase(unitDims.begin(), unitDims.end());
  entryBlock->eraseArguments(unitDimsToErase);
  return success();
}

namespace {
/// Pattern to fold unit-trip count loops in GenericOps.
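/// As a schematic example (only the relevant attributes shown): for a generic
/// op with an input of type memref<1x5xf32> and an output of type
/// memref<5xf32> that uses
///   indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
///                    affine_map<(d0, d1) -> (d1)>]
///   iterator_types = ["parallel", "parallel"]
/// the d0 loop has a single iteration, so the op is updated in place to use
///   indexing_maps = [affine_map<(d0) -> (0, d0)>,
///                    affine_map<(d0) -> (d0)>]
///   iterator_types = ["parallel"]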
template <typename GenericOpTy>
struct FoldUnitDimLoops : public OpRewritePattern<GenericOpTy> {
  using OpRewritePattern<GenericOpTy>::OpRewritePattern;
  LogicalResult matchAndRewrite(GenericOpTy op,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap, 4> indexingMaps = op.getIndexingMaps();
    if (indexingMaps.empty())
      return failure();

    // Check if any of the iteration dimensions are unit-trip count. They will
    // end up being unit-trip count if they are used to index into a unit-dim
    // tensor/memref.
    AffineMap invertedMap = inversePermutation(concatAffineMaps(indexingMaps));
    if (!invertedMap)
      return failure();
    SmallVector<int64_t, 4> dims;
    for (ShapedType shapedType : op.getShapedOperandTypes())
      dims.append(shapedType.getShape().begin(), shapedType.getShape().end());
    DenseSet<unsigned> unitDims;
    ArrayAttr iteratorTypes = op.iterator_types();
    for (auto expr : enumerate(invertedMap.getResults())) {
      if (AffineDimExpr dimExpr = expr.value().dyn_cast<AffineDimExpr>())
        if (dims[dimExpr.getPosition()] == 1 &&
            iteratorTypes[expr.index()].dyn_cast<StringAttr>().getValue() ==
                getParallelIteratorTypeName())
          unitDims.insert(expr.index());
    }
    if (unitDims.empty())
      return failure();

    // Compute the modified indexing maps.
    MLIRContext *context = rewriter.getContext();
    ArrayAttr newIndexingMapAttr =
        replaceUnitDims(unitDims, indexingMaps, context);
    if (!newIndexingMapAttr)
      return op.emitError("unable to compute modified indexing_maps");

    // Compute the iterator types of the modified op by dropping the one-trip
    // count loops.
    SmallVector<Attribute, 4> newIteratorTypes;
    for (auto attr : llvm::enumerate(iteratorTypes)) {
      if (!unitDims.count(attr.index()))
        newIteratorTypes.push_back(attr.value());
    }

    rewriter.startRootUpdate(op);
    op.indexing_mapsAttr(newIndexingMapAttr);
    op.iterator_typesAttr(ArrayAttr::get(context, newIteratorTypes));
    (void)replaceBlockArgForUnitDimLoops(op, unitDims, rewriter);
    rewriter.finalizeRootUpdate(op);
    return success();
  }
};

struct UnitExtentReplacementInfo {
  RankedTensorType type;
  AffineMap indexMap;
  ArrayAttr reassociation;
};
} // namespace

/// Utility function for replacing operands/results to a linalg generic
/// operation on tensors with unit-extent dimensions. These can be replaced
/// with an operand/result with the unit-extent dimension removed. This is only
/// done if the indexing map used to access that dimension has an
/// AffineConstantExpr of value 0. Given the `type` of a result/operand of a
/// Linalg op and its `indexMap`, the utility function returns:
/// - the new type with dimensions of size 1 removed,
/// - a modified index map that can be used to access the replaced
///   result/operand, and
/// - the reassociation that converts from the original tensor type to the
///   modified tensor type.
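/// As an illustrative sketch: for an operand of type tensor<1x5x1xf32>
/// accessed through indexMap affine_map<(d0, d1) -> (0, d1, 0)>, the returned
/// info is
/// - type: tensor<5xf32>
/// - indexMap: affine_map<(d0, d1) -> (d1)>
/// - reassociation: [affine_map<(d0, d1, d2) -> (d0, d1, d2)>]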
static UnitExtentReplacementInfo replaceUnitExtents(AffineMap indexMap,
                                                    RankedTensorType type,
                                                    MLIRContext *context) {
  ArrayRef<int64_t> shape = type.getShape();
  ArrayRef<AffineExpr> exprs = indexMap.getResults();
  SmallVector<AffineExpr, 2> reassociations;
  SmallVector<Attribute, 4> reassociationMaps;
  SmallVector<AffineExpr, 4> newIndexExprs;
  SmallVector<int64_t, 4> newShape;

  int64_t origRank = type.getRank();
  AffineExpr zeroExpr = getAffineConstantExpr(0, context);
  auto isUnitExtent = [&](int64_t dim) -> bool {
    return shape[dim] == 1 && exprs[dim] == zeroExpr;
  };

  unsigned dim = 0;
  // Fold dimensions that are unit-extent at the beginning of the tensor.
  while (dim < origRank && isUnitExtent(dim))
    reassociations.push_back(getAffineDimExpr(dim++, context));
  while (dim < origRank) {
    reassociations.push_back(getAffineDimExpr(dim, context));
    newIndexExprs.push_back(exprs[dim]);
    newShape.push_back(shape[dim]);
    // Fold all following dimensions that are unit-extent.
    while (dim + 1 < origRank && isUnitExtent(dim + 1)) {
      ++dim;
      reassociations.push_back(getAffineDimExpr(dim, context));
    }
    reassociationMaps.push_back(AffineMapAttr::get(AffineMap::get(
        origRank, /*numSymbols = */ 0, reassociations, context)));
    reassociations.clear();
    ++dim;
  }
  UnitExtentReplacementInfo info = {
      RankedTensorType::get(newShape, type.getElementType()),
      AffineMap::get(indexMap.getNumDims(), indexMap.getNumSymbols(),
                     newIndexExprs, context),
      ArrayAttr::get(context, reassociationMaps)};
  return info;
}

namespace {

/// Pattern to replace tensor operands/results that have unit-extent
/// dimensions.
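/// Schematic flow (assuming the unit dimensions are already addressed with the
/// constant 0 in the indexing maps, e.g. after FoldUnitDimLoops has run):
///
///   tensor<1x5xf32> --tensor_reshape--> tensor<5xf32> --generic-->
///   tensor<5xf32> --tensor_reshape--> tensor<1x5xf32>
///
/// i.e. operands/results are reshaped to drop their unit extents, the generic
/// op is recreated on the smaller types, and reshapes back to the original
/// result types reconnect the new results with existing users.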
template <typename GenericOpTy>
struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOpTy> {
  using OpRewritePattern<GenericOpTy>::OpRewritePattern;
  LogicalResult matchAndRewrite(GenericOpTy op,
                                PatternRewriter &rewriter) const override {
    // TODO: support reductions.
    if (!op.hasTensorSemantics())
      return failure();

    MLIRContext *context = rewriter.getContext();
    Location loc = op.getLoc();

    SmallVector<AffineMap, 4> newIndexingMaps;
    SmallVector<ArrayAttr, 4> reassociationMaps;
    SmallVector<ShapedType, 4> newInputOutputTypes;
    bool doCanonicalization = false;
    for (auto it :
         llvm::zip(op.getIndexingMaps(), op.getShapedOperandTypes())) {
      auto replacementInfo = replaceUnitExtents(
          std::get<0>(it), std::get<1>(it).template cast<RankedTensorType>(),
          context);
      reassociationMaps.push_back(replacementInfo.reassociation);
      newIndexingMaps.push_back(replacementInfo.indexMap);
      newInputOutputTypes.push_back(replacementInfo.type);
      doCanonicalization |= replacementInfo.type != std::get<1>(it);
    }

    // If there is nothing to canonicalize, or if the indexing maps of the
    // result operation are not invertible (i.e. not legal), abort.
    if (!doCanonicalization ||
        !inversePermutation(concatAffineMaps(newIndexingMaps)))
      return failure();

    // If any operand type changes, insert a reshape to convert from the
    // original type to the new type.
    // TODO: get rid of flattenedIdx which assumes operand order and contiguity.
    unsigned flattenedIdx = 0;
    auto insertReshapes = [&](ValueRange values) {
      SmallVector<Value, 4> res;
      res.reserve(values.size());
      for (auto operand : llvm::enumerate(values)) {
        if (operand.value().getType() == newInputOutputTypes[flattenedIdx])
          res.push_back(operand.value());
        else
          res.push_back(rewriter.create<linalg::TensorReshapeOp>(
              loc, newInputOutputTypes[flattenedIdx], operand.value(),
              reassociationMaps[flattenedIdx]));
        ++flattenedIdx;
      }
      return res;
    };

    SmallVector<Value, 4> newInputs = insertReshapes(op.inputs());
    SmallVector<Value, 4> newOutputs = insertReshapes(op.outputs());

    // Compute the result types with the unit-extent dimensions removed and
    // create the replacement op on the new operand/result types.
    SmallVector<Type, 4> resultTypes;
    resultTypes.reserve(op.getNumResults());
    for (unsigned i : llvm::seq<unsigned>(0, op.getNumResults()))
      resultTypes.push_back(newInputOutputTypes[i + op.getNumInputs()]);
    GenericOpTy replacementOp = rewriter.create<GenericOpTy>(
        loc, resultTypes, newInputs, newOutputs, newIndexingMaps,
        llvm::to_vector<4>(
            op.iterator_types().template getAsValueRange<StringAttr>()));
    rewriter.inlineRegionBefore(op.region(), replacementOp.region(),
                                replacementOp.region().begin());

    // If any result tensor has a modified shape, then add a reshape to recover
    // the original shape.
    SmallVector<Value, 4> resultReplacements;
    for (auto result : llvm::enumerate(replacementOp.getResults())) {
      unsigned index = result.index() + replacementOp.getNumInputs();
      RankedTensorType origResultType = op.getResult(result.index())
                                            .getType()
                                            .template cast<RankedTensorType>();
      if (origResultType != result.value().getType())
        resultReplacements.push_back(rewriter.create<linalg::TensorReshapeOp>(
            loc, origResultType, result.value(), reassociationMaps[index]));
      else
        resultReplacements.push_back(result.value());
    }
    rewriter.replaceOp(op, resultReplacements);
    return success();
  }
};

/// Pattern to fold a pair of reshape ops where the intermediate has unit-dims.
/// For example:
///
///  %0 = linalg.tensor_reshape %arg0
///    [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>]
///    : tensor<2048xf32> into tensor<1x4x1x512xf32>
///  %1 = linalg.tensor_reshape %0
///    [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>,
///     affine_map<(d0, d1, d2, d3) -> (d3)>]
///    : tensor<1x4x1x512xf32> into tensor<4x512xf32>
///
/// can be replaced with
///
///  %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>]
///    : tensor<2048xf32> into tensor<4x512xf32>
///
/// Similarly,
///
///  %0 = linalg.tensor_reshape %arg0
///    [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>,
///     affine_map<(d0, d1, d2, d3) -> (d3)>]
///    : tensor<4x512xf32> into tensor<1x4x1x512xf32>
///  %1 = linalg.tensor_reshape %0
///    [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>]
///    : tensor<1x4x1x512xf32> into tensor<2048xf32>
///
/// can be replaced with
///
///  %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>]
///    : tensor<4x512xf32> into tensor<2048xf32>
struct FoldReshapeOpWithUnitExtent : OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    // Check that the source operand is created from a reshape as well.
    TensorReshapeOp parentReshapeOp =
        reshapeOp.src().getDefiningOp<TensorReshapeOp>();
    if (!parentReshapeOp)
      return failure();

    RankedTensorType srcType = reshapeOp.getSrcType(),
                     dstType = reshapeOp.getResultType(),
                     parentSrcType = parentReshapeOp.getSrcType();
    if (!srcType.hasStaticShape() || !dstType.hasStaticShape() ||
        !parentSrcType.hasStaticShape() ||
        srcType.getRank() < dstType.getRank() ||
        parentSrcType.getRank() == dstType.getRank())
      return failure();

    // Check whether the combined tensor_reshape is folding or expanding once
    // reshapeOp and parentReshapeOp are composed. If the combined reshape is
    // folding, the parentReshapeOp introduces the unit-dims and the reshapeOp
    // does the actual reshape. If it is expanding, the reshapeOp introduces
    // the unit-dims and the parentReshapeOp does the actual reshape.
    bool isFoldingPattern = parentSrcType.getRank() > dstType.getRank();
    ArrayRef<int64_t> expandedShape =
        isFoldingPattern ? parentSrcType.getShape() : dstType.getShape();
    ArrayRef<int64_t> foldedShape =
        isFoldingPattern ? dstType.getShape() : parentSrcType.getShape();

    unsigned expandedDim = 0, foldedDim = 0;
    SmallVector<SmallVector<AffineExpr, 4>, 4> reassociationExprs(
        foldedShape.size());
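    // Greedily group expanded dimensions until their product matches the
    // current folded dimension, then (unless the next folded dimension is
    // itself 1) absorb any trailing unit dimensions into the same group.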
    while (expandedDim < expandedShape.size() &&
           foldedDim < foldedShape.size()) {
      int64_t dstSize = foldedShape[foldedDim];
      int64_t srcSize = expandedShape[expandedDim];
      while (srcSize < dstSize && expandedDim < expandedShape.size()) {
        reassociationExprs[foldedDim].push_back(
            rewriter.getAffineDimExpr(expandedDim++));
        srcSize *= expandedShape[expandedDim];
      }
      if (srcSize == dstSize) {
        reassociationExprs[foldedDim].push_back(
            rewriter.getAffineDimExpr(expandedDim++));
        // If the next dim in foldedShape is not 1, fold any subsequent unit
        // dims in expandedShape into the current group.
        if (foldedDim == foldedShape.size() - 1 ||
            foldedShape[foldedDim + 1] != 1) {
          while (expandedDim < expandedShape.size() &&
                 expandedShape[expandedDim] == 1) {
            reassociationExprs[foldedDim].push_back(
                rewriter.getAffineDimExpr(expandedDim++));
          }
        }
      } else {
        return failure();
      }
      foldedDim++;
    }
    if (expandedDim != expandedShape.size())
      return failure();

    SmallVector<AffineMap, 4> reassociationMaps =
        llvm::to_vector<4>(llvm::map_range(
            reassociationExprs, [&](ArrayRef<AffineExpr> exprs) -> AffineMap {
              return AffineMap::get(expandedShape.size(), 0, exprs,
                                    rewriter.getContext());
            }));
    rewriter.replaceOpWithNewOp<TensorReshapeOp>(
        reshapeOp, dstType, parentReshapeOp.src(),
        rewriter.getAffineMapArrayAttr(reassociationMaps));
    return success();
  }
};

/// Pattern to fold subtensors that are just taking a slice of a unit-dimension
/// tensor. For example,
///
/// %1 = subtensor %0[0, %o1, 0] [1, %s1, 1] [1, 1, 1]
///     : tensor<1x?x1xf32> to tensor<1x?x1xf32>
///
/// can be replaced with
///
/// %1 = linalg.tensor_reshape %0 [affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
///     : tensor<1x?x1xf32> into tensor<?xf32>
/// %2 = subtensor %1[%o1] [%s1] [1] : tensor<?xf32> to tensor<?xf32>
/// %3 = linalg.tensor_reshape %2 [affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
///     : tensor<?xf32> into tensor<1x?x1xf32>
///
/// The additional tensor_reshapes will hopefully get canonicalized away with
/// other reshapes that drop unit dimensions. Three conditions must hold to
/// fold a dimension:
/// - The offset must be 0.
/// - The size must be 1.
/// - The dimension of the source type must be 1.
struct FoldUnitDimSubTensorOp : public OpRewritePattern<SubTensorOp> {
  using OpRewritePattern<SubTensorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubTensorOp subTensorOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<OpFoldResult> mixedOffsets = subTensorOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = subTensorOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = subTensorOp.getMixedStrides();
    auto hasValue = [](OpFoldResult valueOrAttr, int64_t val) {
      auto attr = valueOrAttr.dyn_cast<Attribute>();
      return attr && attr.cast<IntegerAttr>().getInt() == val;
    };

    if (llvm::any_of(mixedStrides, [&](OpFoldResult valueOrAttr) {
          return !hasValue(valueOrAttr, 1);
        }))
      return failure();

    // Find the expanded unit dimensions.
    SmallVector<ReassociationIndices> reassociation;
    SmallVector<OpFoldResult> newOffsets, newSizes;
    ArrayRef<int64_t> sourceShape = subTensorOp.getSourceType().getShape();
    ReassociationIndices curr;
    for (int64_t dim : llvm::seq<int64_t>(0, mixedOffsets.size())) {
      curr.push_back(dim);
      if (sourceShape[dim] == 1 && hasValue(mixedOffsets[dim], 0) &&
          hasValue(mixedSizes[dim], 1)) {
        continue;
      }
      newOffsets.push_back(mixedOffsets[dim]);
      newSizes.push_back(mixedSizes[dim]);
      reassociation.emplace_back(ReassociationIndices{});
      std::swap(reassociation.back(), curr);
    }
    if (newOffsets.size() == mixedOffsets.size())
      return failure();
    reassociation.back().append(curr.begin(), curr.end());
    SmallVector<OpFoldResult> newStrides(newOffsets.size(),
                                         rewriter.getI64IntegerAttr(1));
    Location loc = subTensorOp->getLoc();
    auto srcReshape = rewriter.create<TensorReshapeOp>(
        loc, subTensorOp.source(), reassociation);
    auto newSubTensorOp = rewriter.create<SubTensorOp>(
        loc, srcReshape, newOffsets, newSizes, newStrides);
    rewriter.replaceOpWithNewOp<TensorReshapeOp>(
        subTensorOp, subTensorOp.getType(), newSubTensorOp, reassociation);
    return success();
  }
};

} // namespace

/// Patterns that are used to canonicalize the use of unit-extent dims for
/// broadcasting.
void mlir::populateLinalgFoldUnitExtentDimsPatterns(
    RewritePatternSet &patterns) {
  auto *context = patterns.getContext();
  patterns.add<FoldUnitDimLoops<GenericOp>, FoldUnitDimLoops<IndexedGenericOp>,
               FoldUnitDimSubTensorOp, ReplaceUnitExtentTensors<GenericOp>,
               ReplaceUnitExtentTensors<IndexedGenericOp>>(context);
  TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
  patterns.add<FoldReshapeOpWithUnitExtent>(context);
  populateFoldUnitDimsReshapeOpsByLinearizationPatterns(patterns);
}

namespace {
/// Pass that removes unit-extent dims within generic ops.
struct LinalgFoldUnitExtentDimsPass
    : public LinalgFoldUnitExtentDimsBase<LinalgFoldUnitExtentDimsPass> {
  void runOnFunction() override {
    FuncOp funcOp = getFunction();
    MLIRContext *context = funcOp.getContext();
    RewritePatternSet patterns(context);
    if (foldOneTripLoopsOnly)
      patterns
          .add<FoldUnitDimLoops<GenericOp>, FoldUnitDimLoops<IndexedGenericOp>>(
              context);
    else
      populateLinalgFoldUnitExtentDimsPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
  }
};
} // namespace

std::unique_ptr<OperationPass<FuncOp>>
mlir::createLinalgFoldUnitExtentDimsPass() {
  return std::make_unique<LinalgFoldUnitExtentDimsPass>();
}