1 //===- DropUnitDims.cpp - Pass to drop use of unit-extent for broadcasting ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements patterns/passes that remove the use of unit-extent
// dimensions to specify broadcasting in favor of a more canonical
// representation of the computation.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "PassDetail.h"
16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
17 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
18 #include "mlir/Dialect/Linalg/Passes.h"
19 #include "mlir/Dialect/Linalg/Utils/Utils.h"
20 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
21 #include "mlir/IR/AffineExpr.h"
22 #include "mlir/IR/AffineMap.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/Support/LLVM.h"
25 #include "mlir/Transforms/FoldUtils.h"
26 #include "llvm/Support/CommandLine.h"
27 #include "llvm/Support/Debug.h"
28 
29 #define DEBUG_TYPE "linalg-drop-unit-dims"
30 
31 using namespace mlir;
32 using namespace mlir::edsc;
33 using namespace mlir::edsc::intrinsics;
34 using namespace mlir::linalg;
35 
36 /// Implements a pass that canonicalizes the uses of unit-extent dimensions for
37 /// broadcasting. For example,
38 ///
39 /// ```mlir
40 /// #accesses = [
41 ///   affine_map<(d0, d1) -> (0, d1)>,
42 ///   affine_map<(d0, d1) -> (d0, 0)>,
43 ///   affine_map<(d0, d1) -> (d0, d1)>
44 /// ]
45 ///
46 /// #trait = {
47 ///   args_in = 2,
48 ///   args_out = 1,
49 ///   indexing_maps = #accesses,
50 ///   iterator_types = ["parallel", "parallel"],
51 ///   library_call = "some_external_fn"
52 /// }
53 ///
54 /// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
55 /// tensor<5x5xf32>
56 /// {
57 ///   %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>] :
58 ///        tensor<5xf32> into tensor<1x5xf32>
59 ///   %1 = linalg.tensor_reshape %arg1 [affine_map<(d0, d1) -> (d0, d1)>] :
60 ///        tensor<5xf32> into tensor<5x1xf32>
61 ///   %2 = linalg.generic #trait %0, %1 {
62 ///        ^bb0(%arg2: f32, %arg3: f32):
63 ///          %3 = addf %arg2, %arg3 : f32
64 ///          linalg.yield %3 : f32
65 ///        } : tensor<1x5xf32>, tensor<5x1xf32> -> tensor<5x5xf32>
66 ///   return %2 : tensor<5x5xf32>
67 /// }
68 ///
69 /// would canonicalize to
70 ///
71 /// ```mlir
72 /// #accesses = [
73 ///   affine_map<(d0, d1) -> (d1)>,
74 ///   affine_map<(d0, d1) -> (d0)>,
75 ///   affine_map<(d0, d1) -> (d0, d1)>
76 /// ]
77 ///
78 /// #trait = {
79 ///   args_in = 2,
80 ///   args_out = 1,
81 ///   indexing_maps = #accesses,
82 ///   iterator_types = ["parallel", "parallel"],
83 ///   library_call = "some_external_fn"
84 /// }
85 ///
86 /// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
87 /// tensor<5x5xf32>
88 /// {
89 ///   %0 = linalg.generic #trait %arg0, %arg1 {
90 ///        ^bb0(%arg2: f32, %arg3: f32):
91 ///          %3 = addf %arg2, %arg3 : f32
92 ///          linalg.yield %3 : f32
93 ///        } : tensor<5xf32>, tensor<5xf32> -> tensor<5x5xf32>
94 ///   return %0 : tensor<5x5xf32>
95 /// }
96 
/// Given the dims of the iteration space of a structured op that are known to
/// have a unit trip count (`unitDims`), and the original `indexingMaps`, return
/// the indexing maps to use in the canonicalized op with these dims removed.
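/// For example (an illustrative case, not taken from a test): with
/// `unitDims = {0}` and a single indexing map `affine_map<(d0, d1) -> (d0, d1)>`,
/// the returned array attribute holds `affine_map<(d0) -> (0, d0)>`: the unit
/// dimension d0 is replaced by the constant 0 and the remaining dimension is
/// renumbered.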
100 static ArrayAttr replaceUnitDims(DenseSet<unsigned> &unitDims,
101                                  ArrayRef<AffineMap> indexingMaps,
102                                  MLIRContext *context) {
103   if (indexingMaps.empty())
104     return nullptr;
105   unsigned numIterationDims = indexingMaps.front().getNumDims();
106   unsigned numSymbols = indexingMaps.front().getNumSymbols();
107 
108   // Compute the replacement for each dim expr.
109   SmallVector<AffineExpr, 4> dimReplacements;
110   dimReplacements.reserve(numIterationDims);
111   unsigned numKeptDims = 0;
112   for (unsigned dim : llvm::seq<unsigned>(0, numIterationDims)) {
113     if (unitDims.count(dim))
114       dimReplacements.push_back(getAffineConstantExpr(0, context));
115     else
116       dimReplacements.push_back(getAffineDimExpr(numKeptDims++, context));
117   }
118 
119   // Symbols remain the same.
120   SmallVector<AffineExpr, 4> symReplacements;
121   symReplacements.reserve(numSymbols);
122   for (unsigned symbol : llvm::seq<unsigned>(0, numSymbols))
123     symReplacements.push_back(getAffineSymbolExpr(symbol, context));
124 
125   SmallVector<AffineMap, 4> newIndexingMaps;
126   newIndexingMaps.reserve(indexingMaps.size());
127   for (AffineMap operandMap : indexingMaps) {
128     // Expected indexing maps to have no symbols.
129     if (operandMap.getNumSymbols())
130       return nullptr;
131     newIndexingMaps.push_back(simplifyAffineMap(
132         operandMap.replaceDimsAndSymbols(dimReplacements, symReplacements,
133                                          numIterationDims - unitDims.size(),
134                                          numSymbols)));
135   }
136 
137   // Check that the new index maps are invertible. If not, something went
138   // wrong, so abort.
139   if (!inversePermutation(concatAffineMaps(newIndexingMaps)))
140     return nullptr;
141   return ArrayAttr::get(
142       llvm::to_vector<4>(llvm::map_range(
143           newIndexingMaps,
144           [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); })),
145       context);
146 }
147 
148 namespace {
149 /// Pattern to fold unit-trip count loops in GenericOps.
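/// For example (a schematic before/after, illustrative only): a linalg.generic
/// whose operands are all of type tensor<1x5xf32> with identity indexing maps
/// and iterator types ["parallel", "parallel"] is rewritten in place to use the
/// indexing map `affine_map<(d0) -> (0, d0)>` and iterator types ["parallel"];
/// the unit trip-count loop d0 is dropped while the operand types are left
/// unchanged.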
// TODO: Generalize this to indexed-generic ops by also modifying the region
// arguments.
152 struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
153   using OpRewritePattern<GenericOp>::OpRewritePattern;
154   LogicalResult matchAndRewrite(GenericOp genericOp,
155                                 PatternRewriter &rewriter) const override {
156     SmallVector<AffineMap, 4> indexingMaps = genericOp.getIndexingMaps();
157     if (indexingMaps.empty())
158       return failure();
159 
    // Check if any of the iteration dimensions have a unit trip count. They
    // will end up with a unit trip count if they are used to index into a
    // unit-dim tensor/memref.
163     AffineMap invertedMap = inversePermutation(concatAffineMaps(indexingMaps));
164     if (!invertedMap)
165       return failure();
166     SmallVector<int64_t, 4> dims;
167     for (ShapedType shapedType : genericOp.getInputOutputShapedTypes())
168       dims.append(shapedType.getShape().begin(), shapedType.getShape().end());
169     DenseSet<unsigned> unitDims;
170     ArrayAttr iteratorTypes = genericOp.iterator_types();
171     for (auto expr : enumerate(invertedMap.getResults())) {
172       if (AffineDimExpr dimExpr = expr.value().dyn_cast<AffineDimExpr>())
173         if (dims[dimExpr.getPosition()] == 1 &&
174             iteratorTypes[expr.index()].dyn_cast<StringAttr>().getValue() ==
175                 getParallelIteratorTypeName())
176           unitDims.insert(expr.index());
177     }
178     if (unitDims.empty())
179       return failure();
180 
181     // Compute the modified indexing maps.
182     MLIRContext *context = rewriter.getContext();
183     ArrayAttr newIndexingMapAttr =
184         replaceUnitDims(unitDims, indexingMaps, context);
185     if (!newIndexingMapAttr)
186       return genericOp.emitError("unable to compute modified indexing_maps");
187 
188     // Compute the iterator types of the modified op by dropping the one-trip
189     // count loops.
190     SmallVector<Attribute, 4> newIteratorTypes;
191     for (auto attr : llvm::enumerate(iteratorTypes)) {
192       if (!unitDims.count(attr.index()))
193         newIteratorTypes.push_back(attr.value());
194     }
195 
196     rewriter.startRootUpdate(genericOp);
197     genericOp.indexing_mapsAttr(newIndexingMapAttr);
198     genericOp.iterator_typesAttr(ArrayAttr::get(newIteratorTypes, context));
199     rewriter.finalizeRootUpdate(genericOp);
200     return success();
201   }
202 };
203 
struct UnitExtentReplacementInfo {
  // Type of the replaced operand/result with the unit-extent dims removed.
  RankedTensorType type;
  // Indexing map to use for the replaced operand/result.
  AffineMap indexMap;
  // Reassociation maps that convert between the original and modified types.
  ArrayAttr reassociation;
};
209 } // namespace
210 
/// Utility function for replacing operands/results of a linalg generic
/// operation on tensors with unit-extent dimensions. Such an operand/result can
/// be replaced with one that has the unit-extent dimensions removed. This is
/// only done if the indexing map used to access that dimension has an
/// AffineConstantExpr of value 0. Given the `type` of a result/operand of a
/// Linalg op and its `indexMap`, the utility function returns:
/// - the new type with dimensions of size 1 removed,
/// - the modified index map that can be used to access the replaced
///   result/operand, and
/// - the reassociation that converts from the original tensor type to the
///   modified tensor type.
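/// For example (an illustrative case, not taken from a test): given
/// `indexMap = affine_map<(d0, d1) -> (0, d1)>` and `type = tensor<1x5xf32>`,
/// the returned info is the type `tensor<5xf32>`, the index map
/// `affine_map<(d0, d1) -> (d1)>`, and the single reassociation map
/// `affine_map<(d0, d1) -> (d0, d1)>` that collapses both original dimensions
/// into the one remaining dimension.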
221 static UnitExtentReplacementInfo replaceUnitExtents(AffineMap indexMap,
222                                                     RankedTensorType type,
223                                                     MLIRContext *context) {
224   ArrayRef<int64_t> shape = type.getShape();
225   ArrayRef<AffineExpr> exprs = indexMap.getResults();
226   SmallVector<AffineExpr, 2> reassociations;
227   SmallVector<Attribute, 4> reassociationMaps;
228   SmallVector<AffineExpr, 4> newIndexExprs;
229   SmallVector<int64_t, 4> newShape;
230 
231   int64_t origRank = type.getRank();
232   AffineExpr zeroExpr = getAffineConstantExpr(0, context);
233   auto isUnitExtent = [&](int64_t dim) -> bool {
234     return shape[dim] == 1 && exprs[dim] == zeroExpr;
235   };
236 
237   unsigned dim = 0;
238   // Fold dimensions that are unit-extent at the beginning of the tensor.
239   while (dim < origRank && isUnitExtent(dim))
240     reassociations.push_back(getAffineDimExpr(dim++, context));
241   while (dim < origRank) {
242     reassociations.push_back(getAffineDimExpr(dim, context));
243     newIndexExprs.push_back(exprs[dim]);
244     newShape.push_back(shape[dim]);
245     // Fold all following dimensions that are unit-extent.
246     while (dim + 1 < origRank && isUnitExtent(dim + 1)) {
247       ++dim;
248       reassociations.push_back(getAffineDimExpr(dim, context));
249     }
250     reassociationMaps.push_back(AffineMapAttr::get(AffineMap::get(
251         origRank, /*numSymbols = */ 0, reassociations, context)));
252     reassociations.clear();
253     ++dim;
254   }
255   UnitExtentReplacementInfo info = {
256       RankedTensorType::get(newShape, type.getElementType()),
257       AffineMap::get(indexMap.getNumDims(), indexMap.getNumSymbols(),
258                      newIndexExprs, context),
259       ArrayAttr::get(reassociationMaps, context)};
260   return info;
261 }
262 
263 namespace {
264 
/// Pattern to replace tensor operands/results that have unit-extent dimensions.
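/// For example (schematic, illustrative only): an operand of type
/// tensor<1x5xf32> accessed with `affine_map<(d0, d1) -> (0, d1)>` is replaced
/// by a linalg.tensor_reshape of that operand to tensor<5xf32>, and the new
/// generic op accesses it with `affine_map<(d0, d1) -> (d1)>`. Together with
/// reshape folding this yields the rewrite shown in the example at the top of
/// this file.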
266 struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> {
267   using OpRewritePattern<GenericOp>::OpRewritePattern;
268   LogicalResult matchAndRewrite(GenericOp genericOp,
269                                 PatternRewriter &rewriter) const override {
270     // TODO: support init_tensors and reductions.
271     if (!genericOp.hasTensorSemantics() || !genericOp.init_tensors().empty())
272       return failure();
273 
274     MLIRContext *context = rewriter.getContext();
275     Location loc = genericOp.getLoc();
276 
277     SmallVector<AffineMap, 4> newIndexingMaps;
278     SmallVector<ArrayAttr, 4> reassociationMaps;
279     SmallVector<ShapedType, 4> newInputOutputTypes;
280     bool doCanonicalization = false;
281     for (auto it : llvm::zip(genericOp.getIndexingMaps(),
282                              genericOp.getInputOutputShapedTypes())) {
283       auto replacementInfo = replaceUnitExtents(
284           std::get<0>(it), std::get<1>(it).cast<RankedTensorType>(), context);
285       reassociationMaps.push_back(replacementInfo.reassociation);
286       newIndexingMaps.push_back(replacementInfo.indexMap);
287       newInputOutputTypes.push_back(replacementInfo.type);
288       doCanonicalization |= replacementInfo.type != std::get<1>(it);
289     }
290 
    // If there is nothing to canonicalize, or if the indexing maps of the
    // resulting operation are not invertible (i.e. not legal), abort.
293     if (!doCanonicalization ||
294         !inversePermutation(concatAffineMaps(newIndexingMaps)))
295       return failure();
296 
    // If any operand type changes, insert a reshape to convert from the
    // original type to the new type.
299     // TODO: get rid of flattenedIdx which assumes operand order and contiguity.
300     unsigned flattenedIdx = 0;
301     auto insertReshapes = [&](ValueRange values) {
302       SmallVector<Value, 4> res;
303       res.reserve(values.size());
304       for (auto operand : llvm::enumerate(values)) {
305         if (operand.value().getType() == newInputOutputTypes[flattenedIdx])
306           res.push_back(operand.value());
307         else
308           res.push_back(rewriter.create<linalg::TensorReshapeOp>(
309               loc, newInputOutputTypes[flattenedIdx], operand.value(),
310               reassociationMaps[flattenedIdx]));
311         ++flattenedIdx;
312       }
313       return res;
314     };
315 
316     SmallVector<Value, 4> newInputs = insertReshapes(genericOp.inputs());
317     SmallVector<Value, 4> newOutputBuffers =
318         insertReshapes(genericOp.output_buffers());
319     SmallVector<Value, 4> newInitTensors =
320         insertReshapes(genericOp.init_tensors());
321 
    // Compute the result types of the replacement op; results whose type
    // changed are reshaped back to the original type below.
324     SmallVector<Type, 4> resultTypes;
325     resultTypes.reserve(genericOp.getNumResults());
326     for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
327       resultTypes.push_back(newInputOutputTypes[i + genericOp.getNumInputs()]);
328     GenericOp replacementOp = rewriter.create<GenericOp>(
329         loc, resultTypes, newInputs, newOutputBuffers, newInitTensors,
330         newIndexingMaps,
331         llvm::to_vector<4>(
332             genericOp.iterator_types().getAsValueRange<StringAttr>()));
333     rewriter.inlineRegionBefore(genericOp.region(), replacementOp.region(),
334                                 replacementOp.region().begin());
335 
    // If any result tensor has a modified shape, add a reshape to recover the
    // original shape.
338     SmallVector<Value, 4> resultReplacements;
339     for (auto result : llvm::enumerate(replacementOp.getResults())) {
340       unsigned index = result.index() + replacementOp.getNumOperands();
341       RankedTensorType origResultType = genericOp.getResult(result.index())
342                                             .getType()
343                                             .cast<RankedTensorType>();
344       if (origResultType != result.value().getType())
345         resultReplacements.push_back(rewriter.create<linalg::TensorReshapeOp>(
346             loc, origResultType, result.value(), reassociationMaps[index]));
347       else
348         resultReplacements.push_back(result.value());
349     }
350     rewriter.replaceOp(genericOp, resultReplacements);
351     return success();
352   }
353 };
354 } // namespace
355 
356 namespace {
/// Pattern to fold a pair of reshape ops where the intermediate has unit-dims,
/// for example:
359 ///
360 ///  %0 = linalg.tensor_reshape %arg0
361 ///    [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>]
362 ///    : tensor<2048xf32> into tensor<1x4x1x512xf32>
363 ///  %1 = linalg.tensor_reshape %0
364 ///    [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>,
365 ///     affine_map<(d0, d1, d2, d3) -> (d3)>]
366 ///    : tensor<1x4x1x512xf32> into tensor<4x512xf32>
367 ///
368 /// can be replaced with
369 ///
370 ///  %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>]
371 ///    : tensor<2048xf32> into tensor<4x512xf32>
372 ///
373 /// Similarly,
374 ///
375 ///  %0 = linalg.tensor_reshape %arg0
376 ///    [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>,
377 ///     affine_map<(d0, d1, d2, d3) -> (d3)>]
378 ///    : tensor<4x512xf32> into tensor<1x4x1x512xf32>
379 ///  %1 = linalg.tensor_reshape %0
380 ///   [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>]
381 ///    : tensor<1x4x1x512xf32> into tensor<2048xf32>
382 ///
383 /// can be replaced with
384 ///
385 ///  %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>]
386 ///    : tensor<4x512xf32> into tensor<2048xf32>
387 struct FoldReshapeOpWithUnitExtent : OpRewritePattern<TensorReshapeOp> {
388   using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
389 
390   LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
391                                 PatternRewriter &rewriter) const override {
392     // Check that the source operand is created from a reshape as well.
393     TensorReshapeOp parentReshapeOp =
394         reshapeOp.src().getDefiningOp<TensorReshapeOp>();
395     if (!parentReshapeOp)
396       return failure();
397 
398     RankedTensorType srcType = reshapeOp.getSrcType(),
399                      dstType = reshapeOp.getResultType(),
400                      parentSrcType = parentReshapeOp.getSrcType();
401     if (!srcType.hasStaticShape() || !dstType.hasStaticShape() ||
402         !parentSrcType.hasStaticShape() ||
403         srcType.getRank() < dstType.getRank() ||
404         parentSrcType.getRank() == dstType.getRank())
405       return failure();
406 
    // Check whether the combined tensor_reshape, obtained by folding reshapeOp
    // into parentReshapeOp, is folding (collapsing) or expanding.
    // If the combined tensor_reshape is folding, the parentReshapeOp is
    // introducing unit-dims, and the reshapeOp does an actual reshape.
    // If the combined tensor_reshape is expanding, the reshapeOp is
    // introducing unit-dims, and the parentReshapeOp does an actual reshape.
413     bool isFoldingPattern = parentSrcType.getRank() > dstType.getRank();
414     ArrayRef<int64_t> expandedShape =
415         isFoldingPattern ? parentSrcType.getShape() : dstType.getShape();
416     ArrayRef<int64_t> foldedShape =
417         isFoldingPattern ? dstType.getShape() : parentSrcType.getShape();
418 
419     unsigned expandedDim = 0, foldedDim = 0;
420     SmallVector<SmallVector<AffineExpr, 4>, 4> reassociationExprs(
421         foldedShape.size());
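    // Walk both shapes in lockstep and group the expanded dims that fold into
    // each dim of foldedShape. For example (illustrative): with
    // expandedShape = [1, 4, 1, 512] and foldedShape = [4, 512], the groups
    // become [d0, d1, d2] and [d3].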
422     while (expandedDim < expandedShape.size() &&
423            foldedDim < foldedShape.size()) {
424       int64_t dstSize = foldedShape[foldedDim];
425       int64_t srcSize = expandedShape[expandedDim];
426       while (srcSize < dstSize && expandedDim < expandedShape.size()) {
427         reassociationExprs[foldedDim].push_back(
428             rewriter.getAffineDimExpr(expandedDim++));
429         srcSize *= expandedShape[expandedDim];
430       }
431       if (srcSize == dstSize) {
432         reassociationExprs[foldedDim].push_back(
433             rewriter.getAffineDimExpr(expandedDim++));
        // If the next dim in foldedShape is not 1, collapse any subsequent
        // unit dims in expandedShape into the current group.
436         if (foldedDim == foldedShape.size() - 1 ||
437             foldedShape[foldedDim + 1] != 1) {
438           while (expandedDim < expandedShape.size() &&
439                  expandedShape[expandedDim] == 1) {
440             reassociationExprs[foldedDim].push_back(
441                 rewriter.getAffineDimExpr(expandedDim++));
442           }
443         }
444       } else {
445         return failure();
446       }
447       foldedDim++;
448     }
449     if (expandedDim != expandedShape.size())
450       return failure();
451 
452     SmallVector<AffineMap, 4> reassociationMaps =
453         llvm::to_vector<4>(llvm::map_range(
454             reassociationExprs, [&](ArrayRef<AffineExpr> exprs) -> AffineMap {
455               return AffineMap::get(expandedShape.size(), 0, exprs,
456                                     rewriter.getContext());
457             }));
458     rewriter.replaceOpWithNewOp<TensorReshapeOp>(
459         reshapeOp, dstType, parentReshapeOp.src(),
460         rewriter.getAffineMapArrayAttr(reassociationMaps));
461     return success();
462   }
463 };
464 } // namespace
465 
466 /// Patterns that are used to canonicalize the use of unit-extent dims for
467 /// broadcasting.
468 void mlir::populateLinalgFoldUnitExtentDimsPatterns(
469     MLIRContext *context, OwningRewritePatternList &patterns) {
470   patterns.insert<FoldUnitDimLoops, ReplaceUnitExtentTensors>(context);
471   TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
472   patterns.insert<FoldReshapeOpWithUnitExtent>(context);
473 }
474 
475 namespace {
476 /// Pass that removes unit-extent dims within generic ops.
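/// When the `foldOneTripLoopsOnly` option is set, only the FoldUnitDimLoops
/// pattern is applied; otherwise the full pattern set from
/// populateLinalgFoldUnitExtentDimsPatterns is used. (The pass is expected to
/// be registered as `linalg-fold-unit-extent-dims` in the Linalg passes
/// tablegen; the exact registration lives outside this file.)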
477 struct LinalgFoldUnitExtentDimsPass
478     : public LinalgFoldUnitExtentDimsBase<LinalgFoldUnitExtentDimsPass> {
479   void runOnFunction() override {
480     OwningRewritePatternList patterns;
481     FuncOp funcOp = getFunction();
482     MLIRContext *context = funcOp.getContext();
483     if (foldOneTripLoopsOnly)
484       patterns.insert<FoldUnitDimLoops>(context);
485     else
486       populateLinalgFoldUnitExtentDimsPatterns(context, patterns);
487     applyPatternsAndFoldGreedily(funcOp.getBody(), patterns);
488   }
489 };
490 } // namespace
491 
492 std::unique_ptr<OperationPass<FuncOp>>
493 mlir::createLinalgFoldUnitExtentDimsPass() {
494   return std::make_unique<LinalgFoldUnitExtentDimsPass>();
495 }
496