//===- DropUnitDims.cpp - Pass to drop use of unit-extent for broadcasting ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns/pass to remove usage of unit-extent dimensions
// to specify broadcasting in favor of a more canonical representation of the
// computation.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
16 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
17 #include "mlir/Dialect/Linalg/IR/Linalg.h"
18 #include "mlir/Dialect/Linalg/Passes.h"
19 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
20 #include "mlir/Dialect/Linalg/Utils/Utils.h"
21 #include "mlir/Dialect/Tensor/IR/Tensor.h"
22 #include "mlir/IR/AffineExpr.h"
23 #include "mlir/IR/AffineMap.h"
24 #include "mlir/IR/BuiltinTypes.h"
25 #include "mlir/Transforms/FoldUtils.h"
26 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
27 #include "llvm/Support/CommandLine.h"
28 #include "llvm/Support/Debug.h"
29
30 #define DEBUG_TYPE "linalg-drop-unit-dims"
31
32 using namespace mlir;
33 using namespace mlir::linalg;
34
/// Implements a pass that canonicalizes the uses of unit-extent dimensions for
/// broadcasting. For example,
///
/// ```mlir
/// #accesses = [
///   affine_map<(d0, d1) -> (0, d1)>,
///   affine_map<(d0, d1) -> (d0, 0)>,
///   affine_map<(d0, d1) -> (d0, d1)>
/// ]
///
/// #trait = {
///   args_in = 2,
///   args_out = 1,
///   indexing_maps = #accesses,
///   iterator_types = ["parallel", "parallel"],
///   library_call = "some_external_fn"
/// }
///
/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
///   tensor<5x5xf32>
/// {
///   %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>] :
///        tensor<5xf32> into tensor<1x5xf32>
///   %1 = linalg.tensor_reshape %arg1 [affine_map<(d0, d1) -> (d0, d1)>] :
///        tensor<5xf32> into tensor<5x1xf32>
///   %2 = linalg.generic #trait %0, %1 {
///        ^bb0(%arg2: f32, %arg3: f32):
///          %3 = arith.addf %arg2, %arg3 : f32
///          linalg.yield %3 : f32
///        } : tensor<1x5xf32>, tensor<5x1xf32> -> tensor<5x5xf32>
///   return %2 : tensor<5x5xf32>
/// }
/// ```
///
/// would canonicalize to
///
/// ```mlir
/// #accesses = [
///   affine_map<(d0, d1) -> (d1)>,
///   affine_map<(d0, d1) -> (d0)>,
///   affine_map<(d0, d1) -> (d0, d1)>
/// ]
///
/// #trait = {
///   args_in = 2,
///   args_out = 1,
///   indexing_maps = #accesses,
///   iterator_types = ["parallel", "parallel"],
///   library_call = "some_external_fn"
/// }
///
/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
///   tensor<5x5xf32>
/// {
///   %0 = linalg.generic #trait %arg0, %arg1 {
///        ^bb0(%arg2: f32, %arg3: f32):
///          %3 = arith.addf %arg2, %arg3 : f32
///          linalg.yield %3 : f32
///        } : tensor<5xf32>, tensor<5xf32> -> tensor<5x5xf32>
///   return %0 : tensor<5x5xf32>
/// }
/// ```

/// Given dims of the iteration space of a structured op that are known to have
/// unit trip count (`unitDims`), return the indexing maps to use in the
/// canonicalized op with these dims removed, given the original `indexingMaps`.
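/// For illustration (an assumed example, not lifted from a test): with
/// `unitDims = {0}` and a single indexing map `(d0, d1) -> (d0, d1)`, the unit
/// dimension is replaced by the constant 0 and the remaining dimension is
/// renumbered, yielding `(d0) -> (0, d0)`.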
static ArrayAttr replaceUnitDims(DenseSet<unsigned> &unitDims,
                                 ArrayRef<AffineMap> indexingMaps,
                                 MLIRContext *context) {
  if (indexingMaps.empty())
    return nullptr;
  unsigned numIterationDims = indexingMaps.front().getNumDims();
  unsigned numSymbols = indexingMaps.front().getNumSymbols();

  // Compute the replacement for each dim expr.
  SmallVector<AffineExpr, 4> dimReplacements;
  dimReplacements.reserve(numIterationDims);
  unsigned numKeptDims = 0;
  for (unsigned dim : llvm::seq<unsigned>(0, numIterationDims)) {
    if (unitDims.count(dim))
      dimReplacements.push_back(getAffineConstantExpr(0, context));
    else
      dimReplacements.push_back(getAffineDimExpr(numKeptDims++, context));
  }

  // Symbols remain the same.
  SmallVector<AffineExpr, 4> symReplacements;
  symReplacements.reserve(numSymbols);
  for (unsigned symbol : llvm::seq<unsigned>(0, numSymbols))
    symReplacements.push_back(getAffineSymbolExpr(symbol, context));

  SmallVector<AffineMap, 4> newIndexingMaps;
  newIndexingMaps.reserve(indexingMaps.size());
  for (AffineMap operandMap : indexingMaps) {
    // Expected indexing maps to have no symbols.
    if (operandMap.getNumSymbols())
      return nullptr;
    newIndexingMaps.push_back(simplifyAffineMap(
        operandMap.replaceDimsAndSymbols(dimReplacements, symReplacements,
                                         numIterationDims - unitDims.size(),
                                         numSymbols)));
  }

  // Check that the new index maps are invertible. If not, something went
  // wrong, so abort.
  if (!inversePermutation(concatAffineMaps(newIndexingMaps)))
    return nullptr;
  return ArrayAttr::get(context,
                        llvm::to_vector<4>(llvm::map_range(
                            newIndexingMaps, [](AffineMap map) -> Attribute {
                              return AffineMapAttr::get(map);
                            })));
}

/// Update the index accesses of linalg operations having index semantics.
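/// For example (illustrative): with `unitDims = {1}`, a `linalg.index 1` in the
/// body is replaced by `arith.constant 0 : index`, and a `linalg.index 2` is
/// renumbered to `linalg.index 1` to account for the dropped loop.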
static void replaceUnitDimIndexOps(GenericOp genericOp,
                                   const DenseSet<unsigned> &unitDims,
                                   PatternRewriter &rewriter) {
  for (IndexOp indexOp :
       llvm::make_early_inc_range(genericOp.getBody()->getOps<IndexOp>())) {
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(indexOp);
    if (unitDims.count(indexOp.dim()) != 0) {
      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(indexOp, 0);
    } else {
      // Update the dimension of the index operation if needed.
      unsigned droppedDims = llvm::count_if(
          unitDims, [&](unsigned dim) { return dim < indexOp.dim(); });
      if (droppedDims != 0)
        rewriter.replaceOpWithNewOp<IndexOp>(indexOp,
                                             indexOp.dim() - droppedDims);
    }
  }
}

namespace {
/// Pattern to fold unit-trip count loops in GenericOps.
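/// The pattern only rewrites the `indexing_maps` and `iterator_types` of the
/// op; operand types are untouched. A sketch of the effect (illustrative, not
/// taken from a test): an op iterating over `(d0, d1)` with a
/// `tensor<1x5xf32>` operand accessed via `(d0, d1) -> (d0, d1)` becomes an op
/// with a single parallel loop and that operand accessed via `(d0) -> (0, d0)`.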
struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap, 4> indexingMaps = genericOp.getIndexingMapsArray();
    if (indexingMaps.empty())
      return failure();

    // Check if any of the iteration dimensions are unit-trip count. They will
    // end up being unit-trip count if they are used to index into a unit-dim
    // tensor/memref.
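    // For example (illustrative), an iteration dimension that only ever
    // addresses the extent-1 dimension of a memref<1x5xf32> operand is a
    // unit-trip loop and can be dropped.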
    AffineMap invertedMap = inversePermutation(concatAffineMaps(indexingMaps));
    if (!invertedMap)
      return failure();
    SmallVector<int64_t> dims = genericOp.getStaticShape();

    DenseSet<unsigned> unitDims;
    SmallVector<unsigned, 4> unitDimsReductionLoops;
    ArrayAttr iteratorTypes = genericOp.iterator_types();
    for (const auto &expr : enumerate(invertedMap.getResults())) {
      if (AffineDimExpr dimExpr = expr.value().dyn_cast<AffineDimExpr>())
        if (dims[dimExpr.getPosition()] == 1)
          unitDims.insert(expr.index());
    }

    if (unitDims.empty())
      return failure();

    // Compute the modified indexing maps.
    MLIRContext *context = rewriter.getContext();
    ArrayAttr newIndexingMapAttr =
        replaceUnitDims(unitDims, indexingMaps, context);
    if (!newIndexingMapAttr)
      return genericOp.emitError("unable to compute modified indexing_maps");

    // Compute the iterator types of the modified op by dropping the one-trip
    // count loops.
    SmallVector<Attribute, 4> newIteratorTypes;
    for (const auto &attr : llvm::enumerate(iteratorTypes)) {
      if (!unitDims.count(attr.index()))
        newIteratorTypes.push_back(attr.value());
    }

    rewriter.startRootUpdate(genericOp);
    genericOp.indexing_mapsAttr(newIndexingMapAttr);
    genericOp.iterator_typesAttr(ArrayAttr::get(context, newIteratorTypes));
    replaceUnitDimIndexOps(genericOp, unitDims, rewriter);
    rewriter.finalizeRootUpdate(genericOp);
    return success();
  }
};

struct UnitExtentReplacementInfo {
  Type type;
  AffineMap indexMap;
  ArrayAttr reassociation;
};
} // namespace

/// Utility function for replacing operands/results to a linalg generic
/// operation with unit-extent dimensions. These can be replaced with
/// an operand/result with the unit-extent dimension removed. This is only done
/// if the indexing map used to access that dimension has an
/// AffineConstantExpr of value 0. Given the `type` of a result/operand of a
/// Linalg op, and its `indexMap`, the utility function returns:
/// - the new type with dimensions of size 1 removed.
/// - modified index map that can be used to access the replaced result/operand
/// - the reassociation that converts from the original tensor type to the
///   modified tensor type.
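/// For illustration (a sketch, not lifted from a test): an operand of type
/// `tensor<1x5xf32>` accessed via `(d0, d1) -> (0, d1)` yields the replacement
/// type `tensor<5xf32>`, the index map `(d0, d1) -> (d1)`, and a reassociation
/// grouping the dimensions as `[[0, 1]]`. Leading unit dimensions are folded
/// into the first retained dimension and trailing ones into the preceding
/// retained dimension, so `tensor<1x1x5x1x4xf32>` accessed via
/// `(d0, d1, d2) -> (0, 0, d1, 0, d2)` yields `tensor<5x4xf32>` with the
/// reassociation `[[0, 1, 2, 3], [4]]`.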
static llvm::Optional<UnitExtentReplacementInfo>
replaceUnitExtents(GenericOp genericOp, OpOperand *opOperand,
                   MLIRContext *context) {
  AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
  ArrayRef<int64_t> shape = genericOp.getShape(opOperand);
  ArrayRef<AffineExpr> exprs = indexingMap.getResults();
  SmallVector<AffineExpr> reassociations;
  SmallVector<Attribute> reassociationMaps;
  SmallVector<AffineExpr> newIndexExprs;
  SmallVector<int64_t> newShape;

  int64_t origRank = genericOp.getRank(opOperand);
  AffineExpr zeroExpr = getAffineConstantExpr(0, context);
  auto isUnitExtent = [&](int64_t dim) -> bool {
    return shape[dim] == 1 && exprs[dim] == zeroExpr;
  };

  // Early return for memrefs with non-identity layout maps: they are always
  // left unchanged.
  Type actualType = opOperand->get().getType();
  if (auto memref = actualType.dyn_cast<MemRefType>()) {
    if (!memref.getLayout().isIdentity())
      return llvm::None;
  }

  int64_t dim = 0;
  // Fold dimensions that are unit-extent at the beginning of the tensor.
  while (dim < origRank && isUnitExtent(dim))
    reassociations.push_back(getAffineDimExpr(dim++, context));
  while (dim < origRank) {
    reassociations.push_back(getAffineDimExpr(dim, context));
    newIndexExprs.push_back(exprs[dim]);
    newShape.push_back(shape[dim]);
    // Fold all following dimensions that are unit-extent.
    while (dim + 1 < origRank && isUnitExtent(dim + 1)) {
      ++dim;
      reassociations.push_back(getAffineDimExpr(dim, context));
    }
    reassociationMaps.push_back(AffineMapAttr::get(AffineMap::get(
        origRank, /*symbolCount = */ 0, reassociations, context)));
    reassociations.clear();
    ++dim;
  }

  // Compute the tensor or scalar replacement type.
  Type elementType = getElementTypeOrSelf(opOperand->get());
  Type replacementType;
  if (elementType == opOperand->get().getType()) {
    replacementType = elementType;
  } else if (actualType.isa<RankedTensorType>()) {
    replacementType = RankedTensorType::get(newShape, elementType);
  } else if (actualType.isa<MemRefType>()) {
    replacementType = MemRefType::get(newShape, elementType);
  }
  assert(replacementType && "unsupported shaped type");
  UnitExtentReplacementInfo info = {replacementType,
                                    AffineMap::get(indexingMap.getNumDims(),
                                                   indexingMap.getNumSymbols(),
                                                   newIndexExprs, context),
                                    ArrayAttr::get(context, reassociationMaps)};
  return info;
}

namespace {

SmallVector<ReassociationExprs, 2>
convertAffineMapArrayToExprs(ArrayAttr affineMapArrayAttr) {
  SmallVector<ReassociationExprs, 2> reassociationExprs;
  for (auto attr : affineMapArrayAttr)
    reassociationExprs.push_back(
        llvm::to_vector<4>(attr.cast<AffineMapAttr>().getValue().getResults()));
  return reassociationExprs;
}

/// Pattern to replace tensor/buffer operands/results that are unit extents.
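/// A sketch of the rewrite (illustrative, not lifted from a test): a generic op
/// with a `tensor<1x5xf32>` input accessed via `(d0, d1) -> (0, d1)` is
/// replaced by a `tensor.collapse_shape` of that input to `tensor<5xf32>`, a
/// new generic op operating on the collapsed types, and, for any result whose
/// type changed, a `tensor.expand_shape` back to the original result type
/// (`memref.collapse_shape`/`memref.expand_shape` in the buffer case).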
struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  // Return the original value if the type is unchanged, or reshape it. Return
  // nullptr if this is an unsupported type.
  Value maybeExpand(Value result, Type origResultType,
                    ArrayAttr reassociationMap, Location loc,
                    PatternRewriter &rewriter) const {
    if (origResultType == result.getType())
      return result;
    if (origResultType.isa<RankedTensorType>()) {
      return rewriter.create<tensor::ExpandShapeOp>(
          loc, origResultType, result,
          convertAffineMapArrayToExprs(reassociationMap));
    }
    if (origResultType.isa<MemRefType>()) {
      return rewriter.create<memref::ExpandShapeOp>(
          loc, origResultType, result,
          convertAffineMapArrayToExprs(reassociationMap));
    }
    return nullptr;
  };

  // Return the original value if the type is unchanged, or reshape it. Return
  // nullptr if this is an unsupported type.
  Value maybeCollapse(Value operand, Type newInputOutputType,
                      ArrayAttr reassociationMap, Location loc,
                      PatternRewriter &rewriter) const {
    auto operandType = operand.getType();
    if (operandType == newInputOutputType)
      return operand;
    if (operandType.isa<MemRefType>()) {
      return rewriter.create<memref::CollapseShapeOp>(
          loc, newInputOutputType, operand,
          convertAffineMapArrayToExprs(reassociationMap));
    }
    if (operandType.isa<RankedTensorType>()) {
      return rewriter.create<tensor::CollapseShapeOp>(
          loc, newInputOutputType, operand,
          convertAffineMapArrayToExprs(reassociationMap));
    }
    return nullptr;
  };

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Skip the pattern if the op has any tensor with special encoding.
    if (llvm::any_of(genericOp->getOperandTypes(), [](Type type) {
          auto tensorType = type.dyn_cast<RankedTensorType>();
          return tensorType && tensorType.getEncoding() != nullptr;
        }))
      return failure();
    MLIRContext *context = rewriter.getContext();
    Location loc = genericOp.getLoc();

    SmallVector<AffineMap> newIndexingMaps;
    SmallVector<ArrayAttr> reassociationMaps;
    SmallVector<Type> newInputOutputTypes;
    bool doCanonicalization = false;
    for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
      auto replacementInfo = replaceUnitExtents(genericOp, opOperand, context);
      if (replacementInfo) {
        reassociationMaps.push_back(replacementInfo->reassociation);
        newIndexingMaps.push_back(replacementInfo->indexMap);
        newInputOutputTypes.push_back(replacementInfo->type);
        doCanonicalization |=
            replacementInfo->type != opOperand->get().getType();
      } else {
        // If replaceUnitExtents cannot handle this case, maintain the same
        // type and indexing map, and create a set of mappings representing an
        // identity reassociation (one map per dimension).
        newInputOutputTypes.push_back(opOperand->get().getType());
        newIndexingMaps.push_back(genericOp.getTiedIndexingMap(opOperand));
        int64_t origRank = genericOp.getRank(opOperand);
        auto maps = llvm::to_vector<8>(llvm::map_range(
            llvm::seq<int64_t>(0, origRank), [&](int64_t dim) -> Attribute {
              return AffineMapAttr::get(
                  AffineMap::get(origRank, /*symbolCount = */ 0,
                                 getAffineDimExpr(dim, context), context));
            }));
        reassociationMaps.push_back(ArrayAttr::get(context, maps));
      }
    }

    // If the indexing maps of the result operation are not invertible (i.e.
    // not legal), abort.
    if (!doCanonicalization ||
        !inversePermutation(concatAffineMaps(newIndexingMaps)))
      return failure();

    // If any operand type changes, insert a reshape to convert from the
    // original type to the new type.
    // TODO: get rid of flattenedIdx which assumes operand order and contiguity.
    unsigned flattenedIdx = 0;
    auto insertReshapes = [&](ValueRange values) {
      SmallVector<Value, 4> res;
      res.reserve(values.size());
      for (auto operand : values) {
        auto reshapedValue =
            maybeCollapse(operand, newInputOutputTypes[flattenedIdx],
                          reassociationMaps[flattenedIdx], loc, rewriter);
        assert(reshapedValue &&
               "expected ranked MemRef or Tensor operand type");
        res.push_back(reshapedValue);
        ++flattenedIdx;
      }
      return res;
    };

    SmallVector<Value, 4> newInputs = insertReshapes(genericOp.inputs());
    SmallVector<Value, 4> newOutputs = insertReshapes(genericOp.outputs());

    // If any result type changes, insert a reshape to convert from the
    // original type to the new type.
    SmallVector<Type, 4> resultTypes;
    resultTypes.reserve(genericOp.getNumResults());
    for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
      resultTypes.push_back(newInputOutputTypes[i + genericOp.getNumInputs()]);
    GenericOp replacementOp = rewriter.create<GenericOp>(
        loc, resultTypes, newInputs, newOutputs, newIndexingMaps,
        llvm::to_vector<4>(
            genericOp.iterator_types().template getAsValueRange<StringAttr>()));
    rewriter.inlineRegionBefore(genericOp.region(), replacementOp.region(),
                                replacementOp.region().begin());

    // If any result tensor has a modified shape, then add a reshape to recover
    // the original shape.
    SmallVector<Value, 4> resultReplacements;
    for (const auto &result : llvm::enumerate(replacementOp.getResults())) {
      unsigned index = result.index() + replacementOp.getNumInputs();
      auto origResultType = genericOp.getResult(result.index()).getType();

      auto newResult = maybeExpand(result.value(), origResultType,
                                   reassociationMaps[index], loc, rewriter);
      assert(newResult &&
             "unexpected output type other than ranked MemRef or Tensor");
      resultReplacements.push_back(newResult);
    }
    rewriter.replaceOp(genericOp, resultReplacements);
    return success();
  }
};
} // namespace

namespace {
/// Convert `extract_slice` operations to rank-reduced versions.
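/// A sketch of the rewrite (illustrative, not lifted from a test):
/// `%0 = tensor.extract_slice %t[0, 0, 0] [1, 16, 4] [1, 1, 1]
///     : tensor<8x16x4xf32> to tensor<1x16x4xf32>`
/// becomes a rank-reduced `tensor.extract_slice` producing `tensor<16x4xf32>`,
/// followed by a `tensor.expand_shape` back to `tensor<1x16x4xf32>`.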
struct RankReducedExtractSliceOp
    : public OpRewritePattern<tensor::ExtractSliceOp> {
  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType resultType = sliceOp.getType();
    SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
    SmallVector<OpFoldResult> strides = sliceOp.getMixedStrides();
    auto reassociation = getReassociationMapForFoldingUnitDims(sizes);
    if (!reassociation ||
        reassociation->size() == static_cast<size_t>(resultType.getRank()))
      return failure();
    auto rankReducedType =
        tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
            reassociation->size(), sliceOp.getSourceType(), offsets, sizes,
            strides)
            .cast<RankedTensorType>();

    Location loc = sliceOp.getLoc();
    Value newSlice = rewriter.create<tensor::ExtractSliceOp>(
        loc, rankReducedType, sliceOp.getSource(), offsets, sizes, strides);
    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
        sliceOp, resultType, newSlice, *reassociation);
    return success();
  }
};

/// Convert `insert_slice` operations to rank-reduced versions.
/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
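/// A sketch of the rewrite (illustrative): an insert of a `tensor<1x16x4xf32>`
/// source becomes a `tensor.collapse_shape` of that source to
/// `tensor<16x4xf32>` followed by a rank-reduced insert with the same offsets,
/// sizes, and strides.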
template <typename InsertOpTy>
struct RankReducedInsertSliceOp : public OpRewritePattern<InsertOpTy> {
  using OpRewritePattern<InsertOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType sourceType = insertSliceOp.getSourceType();
    SmallVector<OpFoldResult> offsets = insertSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> sizes = insertSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> strides = insertSliceOp.getMixedStrides();
    auto reassociation = getReassociationMapForFoldingUnitDims(sizes);
    if (!reassociation ||
        reassociation->size() == static_cast<size_t>(sourceType.getRank()))
      return failure();
    Location loc = insertSliceOp.getLoc();
    tensor::CollapseShapeOp reshapedSource;
    {
      OpBuilder::InsertionGuard g(rewriter);
      // The only difference between InsertSliceOp and ParallelInsertSliceOp is
      // that the insertion point is just before the ParallelCombiningOp in the
      // parallel case.
      if (std::is_same<InsertOpTy, tensor::ParallelInsertSliceOp>::value)
        rewriter.setInsertionPoint(insertSliceOp->getParentOp());
      reshapedSource = rewriter.create<tensor::CollapseShapeOp>(
          loc, insertSliceOp.getSource(), *reassociation);
    }
    rewriter.replaceOpWithNewOp<InsertOpTy>(
        insertSliceOp, reshapedSource, insertSliceOp.getDest(),
        insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
        insertSliceOp.getMixedStrides());
    return success();
  }
};
} // namespace

/// Patterns that are used to canonicalize the use of unit-extent dims for
/// broadcasting.
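/// A minimal reuse sketch for a downstream pipeline (assuming the usual greedy
/// rewrite driver; this mirrors the dedicated pass below):
///
/// ```c++
///   RewritePatternSet patterns(op->getContext());
///   linalg::populateFoldUnitExtentDimsPatterns(patterns);
///   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
/// ```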
void mlir::linalg::populateFoldUnitExtentDimsPatterns(
    RewritePatternSet &patterns) {
  auto *context = patterns.getContext();
  patterns.add<FoldUnitDimLoops, ReplaceUnitExtents, RankReducedExtractSliceOp,
               RankReducedInsertSliceOp<tensor::InsertSliceOp>,
               RankReducedInsertSliceOp<tensor::ParallelInsertSliceOp>>(
      context);
  linalg::FillOp::getCanonicalizationPatterns(patterns, context);
  linalg::InitTensorOp::getCanonicalizationPatterns(patterns, context);
  tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
  tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
}

namespace {
/// Pass that removes unit-extent dims within generic ops.
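/// When `foldOneTripLoopsOnly` is set, only the FoldUnitDimLoops pattern is
/// applied, i.e. unit-trip loops are dropped but operand shapes are kept.
/// The pass is typically driven from the command line as
/// `mlir-opt -linalg-fold-unit-extent-dims` (pass argument as registered in
/// the Linalg pass definitions of this revision).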
struct LinalgFoldUnitExtentDimsPass
    : public LinalgFoldUnitExtentDimsBase<LinalgFoldUnitExtentDimsPass> {
  void runOnOperation() override {
    Operation *op = getOperation();
    MLIRContext *context = op->getContext();
    RewritePatternSet patterns(context);
    if (foldOneTripLoopsOnly)
      patterns.add<FoldUnitDimLoops>(context);
    else
      populateFoldUnitExtentDimsPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
  }
};
} // namespace

std::unique_ptr<Pass> mlir::createLinalgFoldUnitExtentDimsPass() {
  return std::make_unique<LinalgFoldUnitExtentDimsPass>();
}