1 //===- DropUnitDims.cpp - Pass to drop use of unit-extent for broadcasting ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements a set of patterns (and a pass that applies them) to
// remove the use of unit-extent dimensions for specifying broadcasting, in
// favor of a more canonical representation of the computation.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "PassDetail.h"
16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
17 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
18 #include "mlir/Dialect/Linalg/Passes.h"
19 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
20 #include "mlir/Dialect/Linalg/Utils/Utils.h"
21 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
22 #include "mlir/IR/AffineExpr.h"
23 #include "mlir/IR/AffineMap.h"
24 #include "mlir/Transforms/FoldUtils.h"
25 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
26 #include "llvm/Support/CommandLine.h"
27 #include "llvm/Support/Debug.h"
28 
29 #define DEBUG_TYPE "linalg-drop-unit-dims"
30 
31 using namespace mlir;
32 using namespace mlir::edsc;
33 using namespace mlir::edsc::intrinsics;
34 using namespace mlir::linalg;
35 
36 /// Implements a pass that canonicalizes the uses of unit-extent dimensions for
37 /// broadcasting. For example,
38 ///
39 /// ```mlir
40 /// #accesses = [
41 ///   affine_map<(d0, d1) -> (0, d1)>,
42 ///   affine_map<(d0, d1) -> (d0, 0)>,
43 ///   affine_map<(d0, d1) -> (d0, d1)>
44 /// ]
45 ///
46 /// #trait = {
47 ///   args_in = 2,
48 ///   args_out = 1,
49 ///   indexing_maps = #accesses,
50 ///   iterator_types = ["parallel", "parallel"],
51 ///   library_call = "some_external_fn"
52 /// }
53 ///
54 /// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
55 /// tensor<5x5xf32>
56 /// {
57 ///   %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>] :
58 ///        tensor<5xf32> into tensor<1x5xf32>
59 ///   %1 = linalg.tensor_reshape %arg1 [affine_map<(d0, d1) -> (d0, d1)>] :
60 ///        tensor<5xf32> into tensor<5x1xf32>
61 ///   %2 = linalg.generic #trait %0, %1 {
62 ///        ^bb0(%arg2: f32, %arg3: f32):
63 ///          %3 = addf %arg2, %arg3 : f32
64 ///          linalg.yield %3 : f32
65 ///        } : tensor<1x5xf32>, tensor<5x1xf32> -> tensor<5x5xf32>
66 ///   return %2 : tensor<5x5xf32>
67 /// }
68 ///
69 /// would canonicalize to
70 ///
71 /// ```mlir
72 /// #accesses = [
73 ///   affine_map<(d0, d1) -> (d1)>,
74 ///   affine_map<(d0, d1) -> (d0)>,
75 ///   affine_map<(d0, d1) -> (d0, d1)>
76 /// ]
77 ///
78 /// #trait = {
79 ///   args_in = 2,
80 ///   args_out = 1,
81 ///   indexing_maps = #accesses,
82 ///   iterator_types = ["parallel", "parallel"],
83 ///   library_call = "some_external_fn"
84 /// }
85 ///
86 /// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
87 /// tensor<5x5xf32>
88 /// {
89 ///   %0 = linalg.generic #trait %arg0, %arg1 {
90 ///        ^bb0(%arg2: f32, %arg3: f32):
91 ///          %3 = addf %arg2, %arg3 : f32
92 ///          linalg.yield %3 : f32
93 ///        } : tensor<5xf32>, tensor<5xf32> -> tensor<5x5xf32>
94 ///   return %0 : tensor<5x5xf32>
95 /// }
96 
/// Given dims of the iteration space of a structured op that are known to
/// have a single trip count (`unitDims`), return the indexing maps to use in
/// the canonicalized op with these dims removed, given the original
/// `indexingMaps`.
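/// For example (an illustrative sketch, not taken from a test): with three
/// iteration dims and `unitDims = {1}`, d1 is replaced by the constant 0 and
/// d2 is renumbered to d1, so
///   affine_map<(d0, d1, d2) -> (d0, d1, d2)>
/// becomes
///   affine_map<(d0, d1) -> (d0, 0, d1)>.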
100 static ArrayAttr replaceUnitDims(DenseSet<unsigned> &unitDims,
101                                  ArrayRef<AffineMap> indexingMaps,
102                                  MLIRContext *context) {
103   if (indexingMaps.empty())
104     return nullptr;
105   unsigned numIterationDims = indexingMaps.front().getNumDims();
106   unsigned numSymbols = indexingMaps.front().getNumSymbols();
107 
108   // Compute the replacement for each dim expr.
109   SmallVector<AffineExpr, 4> dimReplacements;
110   dimReplacements.reserve(numIterationDims);
111   unsigned numKeptDims = 0;
112   for (unsigned dim : llvm::seq<unsigned>(0, numIterationDims)) {
113     if (unitDims.count(dim))
114       dimReplacements.push_back(getAffineConstantExpr(0, context));
115     else
116       dimReplacements.push_back(getAffineDimExpr(numKeptDims++, context));
117   }
118 
119   // Symbols remain the same.
120   SmallVector<AffineExpr, 4> symReplacements;
121   symReplacements.reserve(numSymbols);
122   for (unsigned symbol : llvm::seq<unsigned>(0, numSymbols))
123     symReplacements.push_back(getAffineSymbolExpr(symbol, context));
124 
125   SmallVector<AffineMap, 4> newIndexingMaps;
126   newIndexingMaps.reserve(indexingMaps.size());
127   for (AffineMap operandMap : indexingMaps) {
128     // Expected indexing maps to have no symbols.
129     if (operandMap.getNumSymbols())
130       return nullptr;
131     newIndexingMaps.push_back(simplifyAffineMap(
132         operandMap.replaceDimsAndSymbols(dimReplacements, symReplacements,
133                                          numIterationDims - unitDims.size(),
134                                          numSymbols)));
135   }
136 
137   // Check that the new index maps are invertible. If not, something went
138   // wrong, so abort.
139   if (!inversePermutation(concatAffineMaps(newIndexingMaps)))
140     return nullptr;
141   return ArrayAttr::get(context,
142                         llvm::to_vector<4>(llvm::map_range(
143                             newIndexingMaps, [](AffineMap map) -> Attribute {
144                               return AffineMapAttr::get(map);
145                             })));
146 }
147 
/// Modify the region of an indexed generic op to drop the block arguments
/// that correspond to loops with unit trip count.
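/// For indexed_generic ops this means (illustrative sketch): the index block
/// argument of each unit-trip loop has its uses replaced by a constant 0 and
/// is then erased, e.g. a body `^bb0(%i, %j, %v)` where the `%j` loop has a
/// unit trip count becomes `^bb0(%i, %v)` with uses of `%j` rewritten to use
/// a newly created constant 0.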
150 template <typename OpTy>
151 static LogicalResult
152 replaceBlockArgForUnitDimLoops(OpTy op, const DenseSet<unsigned> &unitDims,
                               PatternRewriter &rewriter) {
154   return success();
155 }
156 
157 template <>
158 LogicalResult replaceBlockArgForUnitDimLoops<IndexedGenericOp>(
159     IndexedGenericOp op, const DenseSet<unsigned> &unitDims,
160     PatternRewriter &rewriter) {
161   OpBuilder::InsertionGuard guard(rewriter);
162   Block *entryBlock = &op->getRegion(0).front();
163   rewriter.setInsertionPointToStart(entryBlock);
164   Value zero = rewriter.create<ConstantIndexOp>(op.getLoc(), 0);
165   for (unsigned unitDimLoop : unitDims) {
166     entryBlock->getArgument(unitDimLoop).replaceAllUsesWith(zero);
167   }
168   SmallVector<unsigned, 8> unitDimsToErase(unitDims.begin(), unitDims.end());
169   entryBlock->eraseArguments(unitDimsToErase);
170   return success();
171 }
172 
173 namespace {
174 /// Pattern to fold unit-trip count loops in GenericOps.
175 template <typename GenericOpTy>
176 struct FoldUnitDimLoops : public OpRewritePattern<GenericOpTy> {
177   using OpRewritePattern<GenericOpTy>::OpRewritePattern;
178   LogicalResult matchAndRewrite(GenericOpTy op,
179                                 PatternRewriter &rewriter) const override {
180     // TODO: remove once index ops are supported.
181     if (op.hasIndexSemantics())
182       return failure();
183 
184     SmallVector<AffineMap, 4> indexingMaps = op.getIndexingMaps();
185     if (indexingMaps.empty())
186       return failure();
187 
188     // Check if any of the iteration dimensions are unit-trip count. They will
189     // end up being unit-trip count if they are used to index into a unit-dim
190     // tensor/memref.
191     AffineMap invertedMap = inversePermutation(concatAffineMaps(indexingMaps));
192     if (!invertedMap)
193       return failure();
194     SmallVector<int64_t, 4> dims;
195     for (ShapedType shapedType : op.getShapedOperandTypes())
196       dims.append(shapedType.getShape().begin(), shapedType.getShape().end());
197 
198     // Find all the reduction iterators. Those need some special consideration
199     // (see below).
200     auto getLoopDimsOfType =
201         [&](StringRef iteratorTypeName) -> SmallVector<unsigned, 4> {
202       SmallVector<AffineExpr> dimExprs;
203       getDimsOfType(op, iteratorTypeName, dimExprs);
204       return llvm::to_vector<4>(llvm::map_range(dimExprs, [](AffineExpr expr) {
205         return expr.cast<AffineDimExpr>().getPosition();
206       }));
207     };
208     auto reductionDims = getLoopDimsOfType(getReductionIteratorTypeName());
209 
210     DenseSet<unsigned> unitDims;
211     SmallVector<unsigned, 4> unitDimsReductionLoops;
212     ArrayAttr iteratorTypes = op.iterator_types();
213     for (auto expr : enumerate(invertedMap.getResults())) {
214       if (AffineDimExpr dimExpr = expr.value().dyn_cast<AffineDimExpr>())
215         if (dims[dimExpr.getPosition()] == 1) {
216           if (isParallelIterator(iteratorTypes[expr.index()]))
217             unitDims.insert(expr.index());
218           else if (isReductionIterator(iteratorTypes[expr.index()]))
219             unitDimsReductionLoops.push_back(expr.index());
220         }
221     }
222 
223     // Reduction loops can be dropped if there is at least one other reduction
224     // loop that is not dropped. This accounts for the initial value read in the
225     // reduction loop.
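    // For example (illustrative), if the op has iterator_types
    // ["reduction", "reduction"] and both loops have unit trip count, only
    // the first reduction dim is dropped and the last one is kept.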
226     if (!unitDimsReductionLoops.empty() && reductionDims.size() > 1) {
227       if (unitDimsReductionLoops.size() == reductionDims.size())
228         unitDims.insert(reductionDims.begin(), std::prev(reductionDims.end()));
229       else
230         unitDims.insert(unitDimsReductionLoops.begin(),
231                         unitDimsReductionLoops.end());
232     }
233 
234     if (unitDims.empty())
235       return failure();
236 
237     // Compute the modified indexing maps.
238     MLIRContext *context = rewriter.getContext();
239     ArrayAttr newIndexingMapAttr =
240         replaceUnitDims(unitDims, indexingMaps, context);
241     if (!newIndexingMapAttr)
242       return op.emitError("unable to compute modified indexing_maps");
243 
244     // Compute the iterator types of the modified op by dropping the one-trip
245     // count loops.
246     SmallVector<Attribute, 4> newIteratorTypes;
247     for (auto attr : llvm::enumerate(iteratorTypes)) {
248       if (!unitDims.count(attr.index()))
249         newIteratorTypes.push_back(attr.value());
250     }
251 
252     rewriter.startRootUpdate(op);
253     op.indexing_mapsAttr(newIndexingMapAttr);
254     op.iterator_typesAttr(ArrayAttr::get(context, newIteratorTypes));
255     (void)replaceBlockArgForUnitDimLoops(op, unitDims, rewriter);
256     rewriter.finalizeRootUpdate(op);
257     return success();
258   }
259 };
260 
261 struct UnitExtentReplacementInfo {
262   RankedTensorType type;
263   AffineMap indexMap;
264   ArrayAttr reassociation;
265 };
266 } // namespace
267 
/// Utility function for replacing operands/results of a linalg generic
/// operation on tensors with unit-extent dimensions. These can be replaced
/// with an operand/result that has the unit-extent dimension removed. This is
/// only done if the indexing map used to access that dimension is an
/// AffineConstantExpr of value 0. Given the `type` of a result/operand of a
/// Linalg op and its `indexMap`, the utility function returns:
/// - the new type with dimensions of size 1 removed,
/// - the modified index map that can be used to access the replaced
///   result/operand, and
/// - the reassociation that converts from the original tensor type to the
///   modified tensor type.
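/// For example (an illustrative sketch): for an operand of type
/// tensor<1x5x1xf32> accessed through
///   affine_map<(d0, d1) -> (0, d1, 0)>
/// the utility returns the type tensor<5xf32>, the index map
///   affine_map<(d0, d1) -> (d1)>
/// and the reassociation [affine_map<(d0, d1, d2) -> (d0, d1, d2)>], which
/// collapses all three original dimensions into one.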
278 static UnitExtentReplacementInfo replaceUnitExtents(AffineMap indexMap,
279                                                     RankedTensorType type,
280                                                     MLIRContext *context) {
281   ArrayRef<int64_t> shape = type.getShape();
282   ArrayRef<AffineExpr> exprs = indexMap.getResults();
283   SmallVector<AffineExpr, 2> reassociations;
284   SmallVector<Attribute, 4> reassociationMaps;
285   SmallVector<AffineExpr, 4> newIndexExprs;
286   SmallVector<int64_t, 4> newShape;
287 
288   int64_t origRank = type.getRank();
289   AffineExpr zeroExpr = getAffineConstantExpr(0, context);
290   auto isUnitExtent = [&](int64_t dim) -> bool {
291     return shape[dim] == 1 && exprs[dim] == zeroExpr;
292   };
293 
294   unsigned dim = 0;
295   // Fold dimensions that are unit-extent at the beginning of the tensor.
296   while (dim < origRank && isUnitExtent(dim))
297     reassociations.push_back(getAffineDimExpr(dim++, context));
298   while (dim < origRank) {
299     reassociations.push_back(getAffineDimExpr(dim, context));
300     newIndexExprs.push_back(exprs[dim]);
301     newShape.push_back(shape[dim]);
302     // Fold all following dimensions that are unit-extent.
303     while (dim + 1 < origRank && isUnitExtent(dim + 1)) {
304       ++dim;
305       reassociations.push_back(getAffineDimExpr(dim, context));
306     }
307     reassociationMaps.push_back(AffineMapAttr::get(AffineMap::get(
308         origRank, /*numSymbols = */ 0, reassociations, context)));
309     reassociations.clear();
310     ++dim;
311   }
312   UnitExtentReplacementInfo info = {
313       RankedTensorType::get(newShape, type.getElementType()),
314       AffineMap::get(indexMap.getNumDims(), indexMap.getNumSymbols(),
315                      newIndexExprs, context),
316       ArrayAttr::get(context, reassociationMaps)};
317   return info;
318 }
319 
320 namespace {
321 
/// Pattern to replace tensor operands/results that have unit-extent
/// dimensions.
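/// For example (an illustrative sketch): an operand of type tensor<1x5xf32>
/// accessed through affine_map<(d0, d1) -> (0, d1)> is first collapsed to
/// tensor<5xf32> with a linalg.tensor_reshape, the generic op is rebuilt on
/// the rank-reduced operands with the updated indexing maps, and any result
/// whose type changed is reshaped back to its original type.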
323 template <typename GenericOpTy>
324 struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOpTy> {
325   using OpRewritePattern<GenericOpTy>::OpRewritePattern;
326   LogicalResult matchAndRewrite(GenericOpTy op,
327                                 PatternRewriter &rewriter) const override {
328     // TODO: remove once index ops are supported.
329     if (op.hasIndexSemantics())
330       return failure();
331 
332     if (!op.hasTensorSemantics())
333       return failure();
334 
335     MLIRContext *context = rewriter.getContext();
336     Location loc = op.getLoc();
337 
338     SmallVector<AffineMap, 4> newIndexingMaps;
339     SmallVector<ArrayAttr, 4> reassociationMaps;
340     SmallVector<ShapedType, 4> newInputOutputTypes;
341     bool doCanonicalization = false;
342     for (auto it :
343          llvm::zip(op.getIndexingMaps(), op.getShapedOperandTypes())) {
344       auto replacementInfo = replaceUnitExtents(
345           std::get<0>(it), std::get<1>(it).template cast<RankedTensorType>(),
346           context);
347       reassociationMaps.push_back(replacementInfo.reassociation);
348       newIndexingMaps.push_back(replacementInfo.indexMap);
349       newInputOutputTypes.push_back(replacementInfo.type);
350       doCanonicalization |= replacementInfo.type != std::get<1>(it);
351     }
352 
353     // If the indexing maps of the result operation are not invertible (i.e. not
354     // legal), abort.
355     if (!doCanonicalization ||
356         !inversePermutation(concatAffineMaps(newIndexingMaps)))
357       return failure();
358 
    // If any operand type changes, insert a reshape to convert from the
    // original type to the new type.
361     // TODO: get rid of flattenedIdx which assumes operand order and contiguity.
362     unsigned flattenedIdx = 0;
363     auto insertReshapes = [&](ValueRange values) {
364       SmallVector<Value, 4> res;
365       res.reserve(values.size());
366       for (auto operand : llvm::enumerate(values)) {
367         if (operand.value().getType() == newInputOutputTypes[flattenedIdx])
368           res.push_back(operand.value());
369         else
370           res.push_back(rewriter.create<linalg::TensorReshapeOp>(
371               loc, newInputOutputTypes[flattenedIdx], operand.value(),
372               reassociationMaps[flattenedIdx]));
373         ++flattenedIdx;
374       }
375       return res;
376     };
377 
378     SmallVector<Value, 4> newInputs = insertReshapes(op.inputs());
379     SmallVector<Value, 4> newOutputs = insertReshapes(op.outputs());
380 
    // Create the replacement op with the new (possibly rank-reduced) result
    // types.
383     SmallVector<Type, 4> resultTypes;
384     resultTypes.reserve(op.getNumResults());
385     for (unsigned i : llvm::seq<unsigned>(0, op.getNumResults()))
386       resultTypes.push_back(newInputOutputTypes[i + op.getNumInputs()]);
387     GenericOpTy replacementOp = rewriter.create<GenericOpTy>(
388         loc, resultTypes, newInputs, newOutputs, newIndexingMaps,
389         llvm::to_vector<4>(
390             op.iterator_types().template getAsValueRange<StringAttr>()));
391     rewriter.inlineRegionBefore(op.region(), replacementOp.region(),
392                                 replacementOp.region().begin());
393 
    // If any result tensor has a modified shape, then add a reshape to recover
    // the original shape.
396     SmallVector<Value, 4> resultReplacements;
397     for (auto result : llvm::enumerate(replacementOp.getResults())) {
398       unsigned index = result.index() + replacementOp.getNumInputs();
399       RankedTensorType origResultType = op.getResult(result.index())
400                                             .getType()
401                                             .template cast<RankedTensorType>();
402       if (origResultType != result.value().getType())
403         resultReplacements.push_back(rewriter.create<linalg::TensorReshapeOp>(
404             loc, origResultType, result.value(), reassociationMaps[index]));
405       else
406         resultReplacements.push_back(result.value());
407     }
408     rewriter.replaceOp(op, resultReplacements);
409     return success();
410   }
411 };
412 
/// Pattern to fold a pair of reshape ops where the intermediate tensor has
/// unit-dims. For example:
415 ///
416 ///  %0 = linalg.tensor_reshape %arg0
417 ///    [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>]
418 ///    : tensor<2048xf32> into tensor<1x4x1x512xf32>
419 ///  %1 = linalg.tensor_reshape %0
420 ///    [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>,
421 ///     affine_map<(d0, d1, d2, d3) -> (d3)>]
422 ///    : tensor<1x4x1x512xf32> into tensor<4x512xf32>
423 ///
424 /// can be replaced with
425 ///
426 ///  %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>]
427 ///    : tensor<2048xf32> into tensor<4x512xf32>
428 ///
429 /// Similarly,
430 ///
431 ///  %0 = linalg.tensor_reshape %arg0
432 ///    [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>,
433 ///     affine_map<(d0, d1, d2, d3) -> (d3)>]
434 ///    : tensor<4x512xf32> into tensor<1x4x1x512xf32>
435 ///  %1 = linalg.tensor_reshape %0
436 ///   [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>]
437 ///    : tensor<1x4x1x512xf32> into tensor<2048xf32>
438 ///
439 /// can be replaced with
440 ///
441 ///  %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>]
442 ///    : tensor<4x512xf32> into tensor<2048xf32>
443 struct FoldReshapeOpWithUnitExtent : OpRewritePattern<TensorReshapeOp> {
444   using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
445 
446   LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
447                                 PatternRewriter &rewriter) const override {
448     // Check that the source operand is created from a reshape as well.
449     TensorReshapeOp parentReshapeOp =
450         reshapeOp.src().getDefiningOp<TensorReshapeOp>();
451     if (!parentReshapeOp)
452       return failure();
453 
454     RankedTensorType srcType = reshapeOp.getSrcType(),
455                      dstType = reshapeOp.getResultType(),
456                      parentSrcType = parentReshapeOp.getSrcType();
457     if (!srcType.hasStaticShape() || !dstType.hasStaticShape() ||
458         !parentSrcType.hasStaticShape() ||
459         srcType.getRank() < dstType.getRank() ||
460         parentSrcType.getRank() == dstType.getRank())
461       return failure();
462 
    // Check whether the combined tensor_reshape (reshapeOp and parentReshapeOp
    // taken together) is folding or expanding. If the combined reshape is
    // folding, parentReshapeOp introduces the unit-dims and reshapeOp does the
    // actual reshape. If it is expanding, reshapeOp introduces the unit-dims
    // and parentReshapeOp does the actual reshape.
469     bool isFoldingPattern = parentSrcType.getRank() > dstType.getRank();
470     ArrayRef<int64_t> expandedShape =
471         isFoldingPattern ? parentSrcType.getShape() : dstType.getShape();
472     ArrayRef<int64_t> foldedShape =
473         isFoldingPattern ? dstType.getShape() : parentSrcType.getShape();
474 
475     unsigned expandedDim = 0, foldedDim = 0;
476     SmallVector<SmallVector<AffineExpr, 4>, 4> reassociationExprs(
477         foldedShape.size());
478     while (expandedDim < expandedShape.size() &&
479            foldedDim < foldedShape.size()) {
480       int64_t dstSize = foldedShape[foldedDim];
481       int64_t srcSize = expandedShape[expandedDim];
482       while (srcSize < dstSize && expandedDim < expandedShape.size()) {
483         reassociationExprs[foldedDim].push_back(
484             rewriter.getAffineDimExpr(expandedDim++));
485         srcSize *= expandedShape[expandedDim];
486       }
487       if (srcSize == dstSize) {
488         reassociationExprs[foldedDim].push_back(
489             rewriter.getAffineDimExpr(expandedDim++));
        // If the next dim in foldedShape is not 1, also fold any subsequent
        // unit dims of expandedShape into the current group.
492         if (foldedDim == foldedShape.size() - 1 ||
493             foldedShape[foldedDim + 1] != 1) {
494           while (expandedDim < expandedShape.size() &&
495                  expandedShape[expandedDim] == 1) {
496             reassociationExprs[foldedDim].push_back(
497                 rewriter.getAffineDimExpr(expandedDim++));
498           }
499         }
500       } else {
501         return failure();
502       }
503       foldedDim++;
504     }
505     if (expandedDim != expandedShape.size())
506       return failure();
507 
508     SmallVector<AffineMap, 4> reassociationMaps =
509         llvm::to_vector<4>(llvm::map_range(
510             reassociationExprs, [&](ArrayRef<AffineExpr> exprs) -> AffineMap {
511               return AffineMap::get(expandedShape.size(), 0, exprs,
512                                     rewriter.getContext());
513             }));
514     rewriter.replaceOpWithNewOp<TensorReshapeOp>(
515         reshapeOp, dstType, parentReshapeOp.src(),
516         rewriter.getAffineMapArrayAttr(reassociationMaps));
517     return success();
518   }
519 };
520 
/// Pattern to fold subtensor ops that just take a slice of a unit-dimension
/// tensor. For example,
523 ///
524 /// %1 = subtensor %0[0, %o1, 0] [1, %s1, 1] [1, 1, 1]
525 ///     : tensor<1x?x1xf32> to tensor<1x?x1xf32>
526 ///
527 /// can be replaced with
528 ///
529 /// %0 = linalg.tensor_reshape %0 [affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
530 ///     : tensor<1x?x1xf32> into tensor<?xf32>
531 /// %1 = subtensor %0[%o1] [%s1] [1] : tensor<?xf32> to tensor<?xf32>
532 /// %2 = linalg.tensor_reshape %1 [affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
533 ///     : tensor<?xf32> into tensor<1x?x1xf32>
534 ///
/// The additional tensor_reshapes will hopefully get canonicalized away with
/// other reshapes that drop unit dimensions. Three conditions must hold to
/// fold a dimension:
/// - The offset must be 0.
/// - The size must be 1.
/// - The dimension of the source type must be 1.
541 struct FoldUnitDimSubTensorOp : public OpRewritePattern<SubTensorOp> {
542   using OpRewritePattern<SubTensorOp>::OpRewritePattern;
543 
544   LogicalResult matchAndRewrite(SubTensorOp subTensorOp,
545                                 PatternRewriter &rewriter) const override {
546     SmallVector<OpFoldResult> mixedOffsets = subTensorOp.getMixedOffsets();
547     SmallVector<OpFoldResult> mixedSizes = subTensorOp.getMixedSizes();
548     SmallVector<OpFoldResult> mixedStrides = subTensorOp.getMixedStrides();
549     auto hasValue = [](OpFoldResult valueOrAttr, int64_t val) {
550       auto attr = valueOrAttr.dyn_cast<Attribute>();
551       return attr && attr.cast<IntegerAttr>().getInt() == val;
552     };
553 
554     if (llvm::any_of(mixedStrides, [&](OpFoldResult valueOrAttr) {
555           return !hasValue(valueOrAttr, 1);
556         }))
557       return failure();
558 
559     // Find the expanded unit dimensions.
560     SmallVector<ReassociationIndices> reassociation;
561     SmallVector<OpFoldResult> newOffsets, newSizes;
562     ArrayRef<int64_t> sourceShape = subTensorOp.getSourceType().getShape();
563     ReassociationIndices curr;
564     for (int64_t dim : llvm::seq<int64_t>(0, mixedOffsets.size())) {
565       curr.push_back(dim);
566       if (sourceShape[dim] == 1 && hasValue(mixedOffsets[dim], 0) &&
567           hasValue(mixedSizes[dim], 1)) {
568         continue;
569       }
570       newOffsets.push_back(mixedOffsets[dim]);
571       newSizes.push_back(mixedSizes[dim]);
572       reassociation.emplace_back(ReassociationIndices{});
573       std::swap(reassociation.back(), curr);
574     }
575     if (newOffsets.size() == mixedOffsets.size())
576       return failure();
577     reassociation.back().append(curr.begin(), curr.end());
578     SmallVector<OpFoldResult> newStrides(newOffsets.size(),
579                                          rewriter.getI64IntegerAttr(1));
580     Location loc = subTensorOp->getLoc();
581     auto srcReshape = rewriter.create<TensorReshapeOp>(
582         loc, subTensorOp.source(), reassociation);
583     auto newSubTensorOp = rewriter.create<SubTensorOp>(
584         loc, srcReshape, newOffsets, newSizes, newStrides);
585     rewriter.replaceOpWithNewOp<TensorReshapeOp>(
586         subTensorOp, subTensorOp.getType(), newSubTensorOp, reassociation);
587     return success();
588   }
589 };
590 
591 } // namespace
592 
593 /// Patterns that are used to canonicalize the use of unit-extent dims for
594 /// broadcasting.
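///
/// A minimal usage sketch (assuming the caller already has a FuncOp
/// `funcOp`; this mirrors the pass defined below):
///
/// ```c++
/// RewritePatternSet patterns(funcOp.getContext());
/// linalg::populateFoldUnitExtentDimsPatterns(patterns);
/// (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
/// ```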
595 void mlir::linalg::populateFoldUnitExtentDimsPatterns(
596     RewritePatternSet &patterns) {
597   auto *context = patterns.getContext();
598   patterns.add<FoldUnitDimLoops<GenericOp>, FoldUnitDimLoops<IndexedGenericOp>,
599                FoldUnitDimSubTensorOp, ReplaceUnitExtentTensors<GenericOp>,
600                ReplaceUnitExtentTensors<IndexedGenericOp>>(context);
601   TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
602   patterns.add<FoldReshapeOpWithUnitExtent>(context);
603 }
604 
605 namespace {
606 /// Pass that removes unit-extent dims within generic ops.
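/// The pass is typically exercised from mlir-opt; the registered flag and the
/// `foldOneTripLoopsOnly` option come from the tablegen pass definition
/// (expected to be `-linalg-fold-unit-extent-dims` and
/// `fold-one-trip-loops-only`, but Passes.td is authoritative).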
607 struct LinalgFoldUnitExtentDimsPass
608     : public LinalgFoldUnitExtentDimsBase<LinalgFoldUnitExtentDimsPass> {
609   void runOnFunction() override {
610     FuncOp funcOp = getFunction();
611     MLIRContext *context = funcOp.getContext();
612     RewritePatternSet patterns(context);
613     if (foldOneTripLoopsOnly)
614       patterns
615           .add<FoldUnitDimLoops<GenericOp>, FoldUnitDimLoops<IndexedGenericOp>>(
616               context);
617     else
618       populateFoldUnitExtentDimsPatterns(patterns);
619     (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
620   }
621 };
622 } // namespace
623 
624 std::unique_ptr<OperationPass<FuncOp>>
625 mlir::createLinalgFoldUnitExtentDimsPass() {
626   return std::make_unique<LinalgFoldUnitExtentDimsPass>();
627 }
628