12b0c8546SMaheshRavishankar //===- DropUnitDims.cpp - Pass to drop use of unit-extent for broadcasting ===//
22b0c8546SMaheshRavishankar //
32b0c8546SMaheshRavishankar // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
42b0c8546SMaheshRavishankar // See https://llvm.org/LICENSE.txt for license information.
52b0c8546SMaheshRavishankar // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
62b0c8546SMaheshRavishankar //
72b0c8546SMaheshRavishankar //===----------------------------------------------------------------------===//
82b0c8546SMaheshRavishankar //
92b0c8546SMaheshRavishankar // This file implements patterns/pass to remove usage of unit-extent dimensions
102b0c8546SMaheshRavishankar // to specify broadcasting in favor of more canonical representation of the
// computation.
122b0c8546SMaheshRavishankar //
132b0c8546SMaheshRavishankar //===----------------------------------------------------------------------===//
142b0c8546SMaheshRavishankar 
152b0c8546SMaheshRavishankar #include "PassDetail.h"
16a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
17b7f2c108Sgysit #include "mlir/Dialect/Linalg/IR/Linalg.h"
182b0c8546SMaheshRavishankar #include "mlir/Dialect/Linalg/Passes.h"
19ea069aebSMaheshRavishankar #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
202b0c8546SMaheshRavishankar #include "mlir/Dialect/Linalg/Utils/Utils.h"
21060208b4SMatthias Springer #include "mlir/Dialect/Tensor/IR/Tensor.h"
222b0c8546SMaheshRavishankar #include "mlir/IR/AffineExpr.h"
232b0c8546SMaheshRavishankar #include "mlir/IR/AffineMap.h"
246c7be417STres Popp #include "mlir/IR/BuiltinTypes.h"
252b0c8546SMaheshRavishankar #include "mlir/Transforms/FoldUtils.h"
26b6eb26fdSRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
272b0c8546SMaheshRavishankar #include "llvm/Support/CommandLine.h"
282b0c8546SMaheshRavishankar #include "llvm/Support/Debug.h"
292b0c8546SMaheshRavishankar 
302b0c8546SMaheshRavishankar #define DEBUG_TYPE "linalg-drop-unit-dims"
312b0c8546SMaheshRavishankar 
322b0c8546SMaheshRavishankar using namespace mlir;
332b0c8546SMaheshRavishankar using namespace mlir::linalg;
342b0c8546SMaheshRavishankar 
352b0c8546SMaheshRavishankar /// Implements a pass that canonicalizes the uses of unit-extent dimensions for
362b0c8546SMaheshRavishankar /// broadcasting. For example,
372b0c8546SMaheshRavishankar ///
382b0c8546SMaheshRavishankar /// ```mlir
392b0c8546SMaheshRavishankar /// #accesses = [
402b0c8546SMaheshRavishankar ///   affine_map<(d0, d1) -> (0, d1)>,
412b0c8546SMaheshRavishankar ///   affine_map<(d0, d1) -> (d0, 0)>,
422b0c8546SMaheshRavishankar ///   affine_map<(d0, d1) -> (d0, d1)>
432b0c8546SMaheshRavishankar /// ]
442b0c8546SMaheshRavishankar ///
452b0c8546SMaheshRavishankar /// #trait = {
462b0c8546SMaheshRavishankar ///   args_in = 2,
472b0c8546SMaheshRavishankar ///   args_out = 1,
482b0c8546SMaheshRavishankar ///   indexing_maps = #accesses,
492b0c8546SMaheshRavishankar ///   iterator_types = ["parallel", "parallel"],
502b0c8546SMaheshRavishankar ///   library_call = "some_external_fn"
512b0c8546SMaheshRavishankar /// }
522b0c8546SMaheshRavishankar ///
532b0c8546SMaheshRavishankar /// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
542b0c8546SMaheshRavishankar /// tensor<5x5xf32>
552b0c8546SMaheshRavishankar /// {
562b0c8546SMaheshRavishankar ///   %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>] :
572b0c8546SMaheshRavishankar ///        tensor<5xf32> into tensor<1x5xf32>
582b0c8546SMaheshRavishankar ///   %1 = linalg.tensor_reshape %arg1 [affine_map<(d0, d1) -> (d0, d1)>] :
592b0c8546SMaheshRavishankar ///        tensor<5xf32> into tensor<5x1xf32>
602b0c8546SMaheshRavishankar ///   %2 = linalg.generic #trait %0, %1 {
612b0c8546SMaheshRavishankar ///        ^bb0(%arg2: f32, %arg3: f32):
62a54f4eaeSMogball ///          %3 = arith.addf %arg2, %arg3 : f32
632b0c8546SMaheshRavishankar ///          linalg.yield %3 : f32
642b0c8546SMaheshRavishankar ///        } : tensor<1x5xf32>, tensor<5x1xf32> -> tensor<5x5xf32>
652b0c8546SMaheshRavishankar ///   return %2 : tensor<5x5xf32>
/// }
/// ```
///
/// would canonicalize to
///
/// ```mlir
712b0c8546SMaheshRavishankar /// #accesses = [
722b0c8546SMaheshRavishankar ///   affine_map<(d0, d1) -> (d1)>,
732b0c8546SMaheshRavishankar ///   affine_map<(d0, d1) -> (d0)>,
742b0c8546SMaheshRavishankar ///   affine_map<(d0, d1) -> (d0, d1)>
752b0c8546SMaheshRavishankar /// ]
762b0c8546SMaheshRavishankar ///
772b0c8546SMaheshRavishankar /// #trait = {
782b0c8546SMaheshRavishankar ///   args_in = 2,
792b0c8546SMaheshRavishankar ///   args_out = 1,
802b0c8546SMaheshRavishankar ///   indexing_maps = #accesses,
812b0c8546SMaheshRavishankar ///   iterator_types = ["parallel", "parallel"],
822b0c8546SMaheshRavishankar ///   library_call = "some_external_fn"
832b0c8546SMaheshRavishankar /// }
842b0c8546SMaheshRavishankar ///
852b0c8546SMaheshRavishankar /// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
862b0c8546SMaheshRavishankar /// tensor<5x5xf32>
872b0c8546SMaheshRavishankar /// {
882b0c8546SMaheshRavishankar ///   %0 = linalg.generic #trait %arg0, %arg1 {
892b0c8546SMaheshRavishankar ///        ^bb0(%arg2: f32, %arg3: f32):
90a54f4eaeSMogball ///          %3 = arith.addf %arg2, %arg3 : f32
912b0c8546SMaheshRavishankar ///          linalg.yield %3 : f32
922b0c8546SMaheshRavishankar ///        } : tensor<5xf32>, tensor<5xf32> -> tensor<5x5xf32>
932b0c8546SMaheshRavishankar ///   return %0 : tensor<5x5xf32>
/// }
/// ```
952b0c8546SMaheshRavishankar 
962b0c8546SMaheshRavishankar /// Given dims of the iteration space of a structured op that are known to be
972b0c8546SMaheshRavishankar /// single trip count (`unitDims`), return the indexing maps to use in the
982b0c8546SMaheshRavishankar /// canonicalized op with these dims removed, given the original `indexingMaps`.
replaceUnitDims(DenseSet<unsigned> & unitDims,ArrayRef<AffineMap> indexingMaps,MLIRContext * context)992b0c8546SMaheshRavishankar static ArrayAttr replaceUnitDims(DenseSet<unsigned> &unitDims,
1002b0c8546SMaheshRavishankar                                  ArrayRef<AffineMap> indexingMaps,
1012b0c8546SMaheshRavishankar                                  MLIRContext *context) {
1022b0c8546SMaheshRavishankar   if (indexingMaps.empty())
1032b0c8546SMaheshRavishankar     return nullptr;
1042b0c8546SMaheshRavishankar   unsigned numIterationDims = indexingMaps.front().getNumDims();
1052b0c8546SMaheshRavishankar   unsigned numSymbols = indexingMaps.front().getNumSymbols();
1062b0c8546SMaheshRavishankar 
1072b0c8546SMaheshRavishankar   // Compute the replacement for each dim expr.
1082b0c8546SMaheshRavishankar   SmallVector<AffineExpr, 4> dimReplacements;
1092b0c8546SMaheshRavishankar   dimReplacements.reserve(numIterationDims);
1102b0c8546SMaheshRavishankar   unsigned numKeptDims = 0;
1112b0c8546SMaheshRavishankar   for (unsigned dim : llvm::seq<unsigned>(0, numIterationDims)) {
1122b0c8546SMaheshRavishankar     if (unitDims.count(dim))
1132b0c8546SMaheshRavishankar       dimReplacements.push_back(getAffineConstantExpr(0, context));
1142b0c8546SMaheshRavishankar     else
1152b0c8546SMaheshRavishankar       dimReplacements.push_back(getAffineDimExpr(numKeptDims++, context));
1162b0c8546SMaheshRavishankar   }
1172b0c8546SMaheshRavishankar 
1182b0c8546SMaheshRavishankar   // Symbols remain the same.
1192b0c8546SMaheshRavishankar   SmallVector<AffineExpr, 4> symReplacements;
1202b0c8546SMaheshRavishankar   symReplacements.reserve(numSymbols);
1212b0c8546SMaheshRavishankar   for (unsigned symbol : llvm::seq<unsigned>(0, numSymbols))
1222b0c8546SMaheshRavishankar     symReplacements.push_back(getAffineSymbolExpr(symbol, context));
1232b0c8546SMaheshRavishankar 
1242b0c8546SMaheshRavishankar   SmallVector<AffineMap, 4> newIndexingMaps;
1252b0c8546SMaheshRavishankar   newIndexingMaps.reserve(indexingMaps.size());
1262b0c8546SMaheshRavishankar   for (AffineMap operandMap : indexingMaps) {
1272b0c8546SMaheshRavishankar     // Expected indexing maps to have no symbols.
1282b0c8546SMaheshRavishankar     if (operandMap.getNumSymbols())
1292b0c8546SMaheshRavishankar       return nullptr;
1302b0c8546SMaheshRavishankar     newIndexingMaps.push_back(simplifyAffineMap(
1312b0c8546SMaheshRavishankar         operandMap.replaceDimsAndSymbols(dimReplacements, symReplacements,
1322b0c8546SMaheshRavishankar                                          numIterationDims - unitDims.size(),
1332b0c8546SMaheshRavishankar                                          numSymbols)));
1342b0c8546SMaheshRavishankar   }
1352b0c8546SMaheshRavishankar 
1362b0c8546SMaheshRavishankar   // Check that the new index maps are invertible. If not, something went
1372b0c8546SMaheshRavishankar   // wrong, so abort.
1382b0c8546SMaheshRavishankar   if (!inversePermutation(concatAffineMaps(newIndexingMaps)))
1392b0c8546SMaheshRavishankar     return nullptr;
140c2c83e97STres Popp   return ArrayAttr::get(context,
1412b0c8546SMaheshRavishankar                         llvm::to_vector<4>(llvm::map_range(
142c2c83e97STres Popp                             newIndexingMaps, [](AffineMap map) -> Attribute {
143c2c83e97STres Popp                               return AffineMapAttr::get(map);
144c2c83e97STres Popp                             })));
1452b0c8546SMaheshRavishankar }
1462b0c8546SMaheshRavishankar 
147d0774f7fSTobias Gysi /// Update the index accesses of linalg operations having index semantics.
replaceUnitDimIndexOps(GenericOp genericOp,const DenseSet<unsigned> & unitDims,PatternRewriter & rewriter)148f358c372STobias Gysi static void replaceUnitDimIndexOps(GenericOp genericOp,
149d0774f7fSTobias Gysi                                    const DenseSet<unsigned> &unitDims,
150d0774f7fSTobias Gysi                                    PatternRewriter &rewriter) {
151eaa52750STobias Gysi   for (IndexOp indexOp :
152eaa52750STobias Gysi        llvm::make_early_inc_range(genericOp.getBody()->getOps<IndexOp>())) {
153d0774f7fSTobias Gysi     OpBuilder::InsertionGuard guard(rewriter);
154d0774f7fSTobias Gysi     rewriter.setInsertionPoint(indexOp);
155d0774f7fSTobias Gysi     if (unitDims.count(indexOp.dim()) != 0) {
156a54f4eaeSMogball       rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(indexOp, 0);
157d0774f7fSTobias Gysi     } else {
158d0774f7fSTobias Gysi       // Update the dimension of the index operation if needed.
159d0774f7fSTobias Gysi       unsigned droppedDims = llvm::count_if(
160d0774f7fSTobias Gysi           unitDims, [&](unsigned dim) { return dim < indexOp.dim(); });
161d0774f7fSTobias Gysi       if (droppedDims != 0)
162d0774f7fSTobias Gysi         rewriter.replaceOpWithNewOp<IndexOp>(indexOp,
163d0774f7fSTobias Gysi                                              indexOp.dim() - droppedDims);
164d0774f7fSTobias Gysi     }
165d0774f7fSTobias Gysi   }
166d0774f7fSTobias Gysi }
167d0774f7fSTobias Gysi 
1682b0c8546SMaheshRavishankar namespace {
1692b0c8546SMaheshRavishankar /// Pattern to fold unit-trip count loops in GenericOps.
170f358c372STobias Gysi struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
171f358c372STobias Gysi   using OpRewritePattern<GenericOp>::OpRewritePattern;
matchAndRewrite__anon1bf028040311::FoldUnitDimLoops172f358c372STobias Gysi   LogicalResult matchAndRewrite(GenericOp genericOp,
1732b0c8546SMaheshRavishankar                                 PatternRewriter &rewriter) const override {
174*d2c0572bSJacques Pienaar     SmallVector<AffineMap, 4> indexingMaps = genericOp.getIndexingMapsArray();
1752b0c8546SMaheshRavishankar     if (indexingMaps.empty())
1762b0c8546SMaheshRavishankar       return failure();
1772b0c8546SMaheshRavishankar 
1782b0c8546SMaheshRavishankar     // Check if any of the iteration dimensions are unit-trip count. They will
1792b0c8546SMaheshRavishankar     // end up being unit-trip count if they are used to index into a unit-dim
1802b0c8546SMaheshRavishankar     // tensor/memref.
1812b0c8546SMaheshRavishankar     AffineMap invertedMap = inversePermutation(concatAffineMaps(indexingMaps));
1822b0c8546SMaheshRavishankar     if (!invertedMap)
1832b0c8546SMaheshRavishankar       return failure();
184c6985052STobias Gysi     SmallVector<int64_t> dims = genericOp.getStaticShape();
185f4eb681dSMaheshRavishankar 
1862b0c8546SMaheshRavishankar     DenseSet<unsigned> unitDims;
187f4eb681dSMaheshRavishankar     SmallVector<unsigned, 4> unitDimsReductionLoops;
188f358c372STobias Gysi     ArrayAttr iteratorTypes = genericOp.iterator_types();
189e4853be2SMehdi Amini     for (const auto &expr : enumerate(invertedMap.getResults())) {
1902b0c8546SMaheshRavishankar       if (AffineDimExpr dimExpr = expr.value().dyn_cast<AffineDimExpr>())
19136aac53bSthomasraoux         if (dims[dimExpr.getPosition()] == 1)
1922b0c8546SMaheshRavishankar           unitDims.insert(expr.index());
193f4eb681dSMaheshRavishankar     }
194f4eb681dSMaheshRavishankar 
1952b0c8546SMaheshRavishankar     if (unitDims.empty())
1962b0c8546SMaheshRavishankar       return failure();
1972b0c8546SMaheshRavishankar 
1982b0c8546SMaheshRavishankar     // Compute the modified indexing maps.
1992b0c8546SMaheshRavishankar     MLIRContext *context = rewriter.getContext();
2002b0c8546SMaheshRavishankar     ArrayAttr newIndexingMapAttr =
2012b0c8546SMaheshRavishankar         replaceUnitDims(unitDims, indexingMaps, context);
2022b0c8546SMaheshRavishankar     if (!newIndexingMapAttr)
203f358c372STobias Gysi       return genericOp.emitError("unable to compute modified indexing_maps");
2042b0c8546SMaheshRavishankar 
2052b0c8546SMaheshRavishankar     // Compute the iterator types of the modified op by dropping the one-trip
2062b0c8546SMaheshRavishankar     // count loops.
2072b0c8546SMaheshRavishankar     SmallVector<Attribute, 4> newIteratorTypes;
208e4853be2SMehdi Amini     for (const auto &attr : llvm::enumerate(iteratorTypes)) {
2092b0c8546SMaheshRavishankar       if (!unitDims.count(attr.index()))
2102b0c8546SMaheshRavishankar         newIteratorTypes.push_back(attr.value());
2112b0c8546SMaheshRavishankar     }
2122b0c8546SMaheshRavishankar 
213f358c372STobias Gysi     rewriter.startRootUpdate(genericOp);
214f358c372STobias Gysi     genericOp.indexing_mapsAttr(newIndexingMapAttr);
215f358c372STobias Gysi     genericOp.iterator_typesAttr(ArrayAttr::get(context, newIteratorTypes));
216f358c372STobias Gysi     replaceUnitDimIndexOps(genericOp, unitDims, rewriter);
217f358c372STobias Gysi     rewriter.finalizeRootUpdate(genericOp);
2182b0c8546SMaheshRavishankar     return success();
2192b0c8546SMaheshRavishankar   }
2202b0c8546SMaheshRavishankar };
2212b0c8546SMaheshRavishankar 
2222b0c8546SMaheshRavishankar struct UnitExtentReplacementInfo {
223f6b4e081STobias Gysi   Type type;
2242b0c8546SMaheshRavishankar   AffineMap indexMap;
2252b0c8546SMaheshRavishankar   ArrayAttr reassociation;
2262b0c8546SMaheshRavishankar };
2272b0c8546SMaheshRavishankar } // namespace
2282b0c8546SMaheshRavishankar 
2292b0c8546SMaheshRavishankar /// Utility function for replacing operands/results to a linalg generic
2306c7be417STres Popp /// operation with unit-extent dimensions. These can be replaced with
2312b0c8546SMaheshRavishankar /// an operand/result with the unit-extent dimension removed. This is only done
2322b0c8546SMaheshRavishankar /// if the indexing map used to access that didimensionmension has a
2332b0c8546SMaheshRavishankar /// AffineConstantExpr of value 0. Given the `type` of an result/operand of a
2342b0c8546SMaheshRavishankar /// Linalg op, and its `indexMap` the utility function returns:
2352b0c8546SMaheshRavishankar /// - the new type with dimensions of size 1 removed.
2362b0c8546SMaheshRavishankar /// - modified index map that can be used to access the replaced result/operand
2372b0c8546SMaheshRavishankar /// - the reassociation that converts from the original tensor type to the
2382b0c8546SMaheshRavishankar ///   modified tensor type.
23944485fcdSTres Popp static llvm::Optional<UnitExtentReplacementInfo>
replaceUnitExtents(GenericOp genericOp,OpOperand * opOperand,MLIRContext * context)24044485fcdSTres Popp replaceUnitExtents(GenericOp genericOp, OpOperand *opOperand,
2412b0c8546SMaheshRavishankar                    MLIRContext *context) {
242c6985052STobias Gysi   AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
243c6985052STobias Gysi   ArrayRef<int64_t> shape = genericOp.getShape(opOperand);
244c6985052STobias Gysi   ArrayRef<AffineExpr> exprs = indexingMap.getResults();
245f6b4e081STobias Gysi   SmallVector<AffineExpr> reassociations;
246f6b4e081STobias Gysi   SmallVector<Attribute> reassociationMaps;
247f6b4e081STobias Gysi   SmallVector<AffineExpr> newIndexExprs;
248f6b4e081STobias Gysi   SmallVector<int64_t> newShape;
2492b0c8546SMaheshRavishankar 
250c6985052STobias Gysi   int64_t origRank = genericOp.getRank(opOperand);
2512b0c8546SMaheshRavishankar   AffineExpr zeroExpr = getAffineConstantExpr(0, context);
2522b0c8546SMaheshRavishankar   auto isUnitExtent = [&](int64_t dim) -> bool {
2532b0c8546SMaheshRavishankar     return shape[dim] == 1 && exprs[dim] == zeroExpr;
2542b0c8546SMaheshRavishankar   };
2552b0c8546SMaheshRavishankar 
25644485fcdSTres Popp   // Early return for memrefs with affine maps to represent that we will always
25744485fcdSTres Popp   // leave them unchanged.
25844485fcdSTres Popp   Type actualType = opOperand->get().getType();
25944485fcdSTres Popp   if (auto memref = actualType.dyn_cast<MemRefType>()) {
260e41ebbecSVladislav Vinogradov     if (!memref.getLayout().isIdentity())
26144485fcdSTres Popp       return llvm::None;
26244485fcdSTres Popp   }
26344485fcdSTres Popp 
264f6b4e081STobias Gysi   int64_t dim = 0;
2652b0c8546SMaheshRavishankar   // Fold dimensions that are unit-extent at the beginning of the tensor.
2662b0c8546SMaheshRavishankar   while (dim < origRank && isUnitExtent(dim))
2672b0c8546SMaheshRavishankar     reassociations.push_back(getAffineDimExpr(dim++, context));
2682b0c8546SMaheshRavishankar   while (dim < origRank) {
2692b0c8546SMaheshRavishankar     reassociations.push_back(getAffineDimExpr(dim, context));
2702b0c8546SMaheshRavishankar     newIndexExprs.push_back(exprs[dim]);
2712b0c8546SMaheshRavishankar     newShape.push_back(shape[dim]);
2722b0c8546SMaheshRavishankar     // Fold all following dimensions that are unit-extent.
2732b0c8546SMaheshRavishankar     while (dim + 1 < origRank && isUnitExtent(dim + 1)) {
2742b0c8546SMaheshRavishankar       ++dim;
2752b0c8546SMaheshRavishankar       reassociations.push_back(getAffineDimExpr(dim, context));
2762b0c8546SMaheshRavishankar     }
2772b0c8546SMaheshRavishankar     reassociationMaps.push_back(AffineMapAttr::get(AffineMap::get(
2782865d114SAlexander Belyaev         origRank, /*symbolCount = */ 0, reassociations, context)));
2792b0c8546SMaheshRavishankar     reassociations.clear();
2802b0c8546SMaheshRavishankar     ++dim;
2812b0c8546SMaheshRavishankar   }
28244485fcdSTres Popp 
283f6b4e081STobias Gysi   // Compute the tensor or scalar replacement type.
284046922e1STobias Gysi   Type elementType = getElementTypeOrSelf(opOperand->get());
2856c7be417STres Popp   Type replacementType;
2866c7be417STres Popp   if (elementType == opOperand->get().getType()) {
2876c7be417STres Popp     replacementType = elementType;
2886c7be417STres Popp   } else if (actualType.isa<RankedTensorType>()) {
2896c7be417STres Popp     replacementType = RankedTensorType::get(newShape, elementType);
2906c7be417STres Popp   } else if (actualType.isa<MemRefType>()) {
2916c7be417STres Popp     replacementType = MemRefType::get(newShape, elementType);
2926c7be417STres Popp   }
2936c7be417STres Popp   assert(replacementType && "unsupported shaped type");
294f6b4e081STobias Gysi   UnitExtentReplacementInfo info = {replacementType,
295f6b4e081STobias Gysi                                     AffineMap::get(indexingMap.getNumDims(),
296f6b4e081STobias Gysi                                                    indexingMap.getNumSymbols(),
2972b0c8546SMaheshRavishankar                                                    newIndexExprs, context),
298c2c83e97STres Popp                                     ArrayAttr::get(context, reassociationMaps)};
2992b0c8546SMaheshRavishankar   return info;
3002b0c8546SMaheshRavishankar }
3012b0c8546SMaheshRavishankar 
3022b0c8546SMaheshRavishankar namespace {
303ed229132SNicolas Vasilache 
3042865d114SAlexander Belyaev SmallVector<ReassociationExprs, 2>
convertAffineMapArrayToExprs(ArrayAttr affineMapArrayAttr)3052865d114SAlexander Belyaev convertAffineMapArrayToExprs(ArrayAttr affineMapArrayAttr) {
3062865d114SAlexander Belyaev   SmallVector<ReassociationExprs, 2> reassociationExprs;
3072865d114SAlexander Belyaev   for (auto attr : affineMapArrayAttr)
3082865d114SAlexander Belyaev     reassociationExprs.push_back(
3092865d114SAlexander Belyaev         llvm::to_vector<4>(attr.cast<AffineMapAttr>().getValue().getResults()));
3102865d114SAlexander Belyaev   return reassociationExprs;
3112865d114SAlexander Belyaev }
3122865d114SAlexander Belyaev 
3136c7be417STres Popp /// Pattern to replace tensor/buffer operands/results that are unit extents.
3146c7be417STres Popp struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> {
315f358c372STobias Gysi   using OpRewritePattern<GenericOp>::OpRewritePattern;
3166c7be417STres Popp 
3176c7be417STres Popp   // Return the original value if the type is unchanged, or reshape it. Return a
3186c7be417STres Popp   // nullptr if this is an unsupported type.
maybeExpand__anon1bf028040511::ReplaceUnitExtents3196c7be417STres Popp   Value maybeExpand(Value result, Type origResultType,
3206c7be417STres Popp                     ArrayAttr reassociationMap, Location loc,
3216c7be417STres Popp                     PatternRewriter &rewriter) const {
3226c7be417STres Popp     if (origResultType == result.getType())
3236c7be417STres Popp       return result;
3246c7be417STres Popp     if (origResultType.isa<RankedTensorType>()) {
325b618880eSAlexander Belyaev       return rewriter.create<tensor::ExpandShapeOp>(
3266c7be417STres Popp           loc, origResultType, result,
3276c7be417STres Popp           convertAffineMapArrayToExprs(reassociationMap));
3286c7be417STres Popp     }
3296c7be417STres Popp     if (origResultType.isa<MemRefType>()) {
33046ef86b5SAlexander Belyaev       return rewriter.create<memref::ExpandShapeOp>(
3316c7be417STres Popp           loc, origResultType, result,
3326c7be417STres Popp           convertAffineMapArrayToExprs(reassociationMap));
3336c7be417STres Popp     }
3346c7be417STres Popp     return nullptr;
3356c7be417STres Popp   };
3366c7be417STres Popp 
3376c7be417STres Popp   // Return the original value if the type is unchanged, or reshape it. Return a
3386c7be417STres Popp   // nullptr if this is an unsupported type.
maybeCollapse__anon1bf028040511::ReplaceUnitExtents3396c7be417STres Popp   Value maybeCollapse(Value operand, Type newInputOutputType,
3406c7be417STres Popp                       ArrayAttr reassociationMap, Location loc,
3416c7be417STres Popp                       PatternRewriter &rewriter) const {
3426c7be417STres Popp     auto operandType = operand.getType();
3436c7be417STres Popp     if (operandType == newInputOutputType)
3446c7be417STres Popp       return operand;
3456c7be417STres Popp     if (operandType.isa<MemRefType>()) {
34646ef86b5SAlexander Belyaev       return rewriter.create<memref::CollapseShapeOp>(
3476c7be417STres Popp           loc, newInputOutputType, operand,
3486c7be417STres Popp           convertAffineMapArrayToExprs(reassociationMap));
3496c7be417STres Popp     }
3506c7be417STres Popp     if (operandType.isa<RankedTensorType>()) {
351b618880eSAlexander Belyaev       return rewriter.create<tensor::CollapseShapeOp>(
3526c7be417STres Popp           loc, newInputOutputType, operand,
3536c7be417STres Popp           convertAffineMapArrayToExprs(reassociationMap));
3546c7be417STres Popp     }
3556c7be417STres Popp     return nullptr;
3566c7be417STres Popp   };
3576c7be417STres Popp 
matchAndRewrite__anon1bf028040511::ReplaceUnitExtents358f358c372STobias Gysi   LogicalResult matchAndRewrite(GenericOp genericOp,
3592b0c8546SMaheshRavishankar                                 PatternRewriter &rewriter) const override {
36008f0cb77Sthomasraoux     // Skip the pattern if the op has any tensor with special encoding.
36108f0cb77Sthomasraoux     if (llvm::any_of(genericOp->getOperandTypes(), [](Type type) {
36208f0cb77Sthomasraoux           auto tensorType = type.dyn_cast<RankedTensorType>();
36308f0cb77Sthomasraoux           return tensorType && tensorType.getEncoding() != nullptr;
36408f0cb77Sthomasraoux         }))
36508f0cb77Sthomasraoux       return failure();
3662b0c8546SMaheshRavishankar     MLIRContext *context = rewriter.getContext();
367f358c372STobias Gysi     Location loc = genericOp.getLoc();
3682b0c8546SMaheshRavishankar 
369f6b4e081STobias Gysi     SmallVector<AffineMap> newIndexingMaps;
370f6b4e081STobias Gysi     SmallVector<ArrayAttr> reassociationMaps;
371f6b4e081STobias Gysi     SmallVector<Type> newInputOutputTypes;
3722b0c8546SMaheshRavishankar     bool doCanonicalization = false;
373c6985052STobias Gysi     for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
37444485fcdSTres Popp       auto replacementInfo = replaceUnitExtents(genericOp, opOperand, context);
37544485fcdSTres Popp       if (replacementInfo) {
37644485fcdSTres Popp         reassociationMaps.push_back(replacementInfo->reassociation);
37744485fcdSTres Popp         newIndexingMaps.push_back(replacementInfo->indexMap);
37844485fcdSTres Popp         newInputOutputTypes.push_back(replacementInfo->type);
37944485fcdSTres Popp         doCanonicalization |=
38044485fcdSTres Popp             replacementInfo->type != opOperand->get().getType();
38144485fcdSTres Popp       } else {
38244485fcdSTres Popp         // If replaceUnitExtents cannot handle this case, maintain the same
38344485fcdSTres Popp         // type, indexing map, and create a set of mappings representing an
38444485fcdSTres Popp         // identity matrix.
38544485fcdSTres Popp         newInputOutputTypes.push_back(opOperand->get().getType());
38644485fcdSTres Popp         newIndexingMaps.push_back(genericOp.getTiedIndexingMap(opOperand));
38744485fcdSTres Popp         int64_t origRank = genericOp.getRank(opOperand);
38844485fcdSTres Popp         auto maps = llvm::to_vector<8>(llvm::map_range(
38944485fcdSTres Popp             llvm::seq<int64_t>(0, origRank), [&](int64_t dim) -> Attribute {
39044485fcdSTres Popp               return AffineMapAttr::get(
39144485fcdSTres Popp                   AffineMap::get(origRank, /*symbolCount = */ 0,
39244485fcdSTres Popp                                  getAffineDimExpr(dim, context), context));
39344485fcdSTres Popp             }));
39444485fcdSTres Popp         reassociationMaps.push_back(ArrayAttr::get(context, maps));
39544485fcdSTres Popp       }
3962b0c8546SMaheshRavishankar     }
3972b0c8546SMaheshRavishankar 
3982b0c8546SMaheshRavishankar     // If the indexing maps of the result operation are not invertible (i.e. not
3992b0c8546SMaheshRavishankar     // legal), abort.
4002b0c8546SMaheshRavishankar     if (!doCanonicalization ||
4012b0c8546SMaheshRavishankar         !inversePermutation(concatAffineMaps(newIndexingMaps)))
4022b0c8546SMaheshRavishankar       return failure();
4032b0c8546SMaheshRavishankar 
4042b0c8546SMaheshRavishankar     // If any operand type change, insert a reshape to convert from the original
4052b0c8546SMaheshRavishankar     // type to the new type.
406ed229132SNicolas Vasilache     // TODO: get rid of flattenedIdx which assumes operand order and contiguity.
407ed229132SNicolas Vasilache     unsigned flattenedIdx = 0;
408ed229132SNicolas Vasilache     auto insertReshapes = [&](ValueRange values) {
409ed229132SNicolas Vasilache       SmallVector<Value, 4> res;
410ed229132SNicolas Vasilache       res.reserve(values.size());
4116c7be417STres Popp       for (auto operand : values) {
4126c7be417STres Popp         auto reshapedValue =
4136c7be417STres Popp             maybeCollapse(operand, newInputOutputTypes[flattenedIdx],
4146c7be417STres Popp                           reassociationMaps[flattenedIdx], loc, rewriter);
4156c7be417STres Popp         assert(reshapedValue &&
4166c7be417STres Popp                "expected ranked MemRef or Tensor operand type");
4176c7be417STres Popp         res.push_back(reshapedValue);
418ed229132SNicolas Vasilache         ++flattenedIdx;
4192b0c8546SMaheshRavishankar       }
420ed229132SNicolas Vasilache       return res;
421ed229132SNicolas Vasilache     };
422ed229132SNicolas Vasilache 
423f358c372STobias Gysi     SmallVector<Value, 4> newInputs = insertReshapes(genericOp.inputs());
424f358c372STobias Gysi     SmallVector<Value, 4> newOutputs = insertReshapes(genericOp.outputs());
4252b0c8546SMaheshRavishankar 
426b7ae1d3dSnicolasvasilache     // If any result type changes, insert a reshape to convert from the original
4272b0c8546SMaheshRavishankar     // type to the new type.
4282b0c8546SMaheshRavishankar     SmallVector<Type, 4> resultTypes;
429f358c372STobias Gysi     resultTypes.reserve(genericOp.getNumResults());
430f358c372STobias Gysi     for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
431f358c372STobias Gysi       resultTypes.push_back(newInputOutputTypes[i + genericOp.getNumInputs()]);
432f358c372STobias Gysi     GenericOp replacementOp = rewriter.create<GenericOp>(
433b7ae1d3dSnicolasvasilache         loc, resultTypes, newInputs, newOutputs, newIndexingMaps,
434ed229132SNicolas Vasilache         llvm::to_vector<4>(
435f358c372STobias Gysi             genericOp.iterator_types().template getAsValueRange<StringAttr>()));
436f358c372STobias Gysi     rewriter.inlineRegionBefore(genericOp.region(), replacementOp.region(),
4372b0c8546SMaheshRavishankar                                 replacementOp.region().begin());
4382b0c8546SMaheshRavishankar 
4392b0c8546SMaheshRavishankar     // If any result tensor has a modified shape, then add reshape to recover
4402b0c8546SMaheshRavishankar     // the original shape.
4412b0c8546SMaheshRavishankar     SmallVector<Value, 4> resultReplacements;
442e4853be2SMehdi Amini     for (const auto &result : llvm::enumerate(replacementOp.getResults())) {
443b7ae1d3dSnicolasvasilache       unsigned index = result.index() + replacementOp.getNumInputs();
4446c7be417STres Popp       auto origResultType = genericOp.getResult(result.index()).getType();
4456c7be417STres Popp 
4466c7be417STres Popp       auto newResult = maybeExpand(result.value(), origResultType,
4476c7be417STres Popp                                    reassociationMaps[index], loc, rewriter);
4486c7be417STres Popp       assert(newResult &&
4496c7be417STres Popp              "unexpected output type other than ranked MemRef or Tensor");
4506c7be417STres Popp       resultReplacements.push_back(newResult);
4512b0c8546SMaheshRavishankar     }
452f358c372STobias Gysi     rewriter.replaceOp(genericOp, resultReplacements);
4532b0c8546SMaheshRavishankar     return success();
4542b0c8546SMaheshRavishankar   }
4552b0c8546SMaheshRavishankar };
456fd15e2b8SMaheshRavishankar } // namespace
457f0a2fe7fSMaheshRavishankar 
458fd15e2b8SMaheshRavishankar namespace {
459060208b4SMatthias Springer /// Convert `extract_slice` operations to rank-reduced versions.
460df5c981bSNicolas Vasilache struct RankReducedExtractSliceOp
461060208b4SMatthias Springer     : public OpRewritePattern<tensor::ExtractSliceOp> {
462060208b4SMatthias Springer   using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
463f0a2fe7fSMaheshRavishankar 
matchAndRewrite__anon1bf028040911::RankReducedExtractSliceOp464060208b4SMatthias Springer   LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
465f0a2fe7fSMaheshRavishankar                                 PatternRewriter &rewriter) const override {
466060208b4SMatthias Springer     RankedTensorType resultType = sliceOp.getType();
467060208b4SMatthias Springer     SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
468060208b4SMatthias Springer     SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
469060208b4SMatthias Springer     SmallVector<OpFoldResult> strides = sliceOp.getMixedStrides();
470fd15e2b8SMaheshRavishankar     auto reassociation = getReassociationMapForFoldingUnitDims(sizes);
471fd15e2b8SMaheshRavishankar     if (!reassociation ||
472fd15e2b8SMaheshRavishankar         reassociation->size() == static_cast<size_t>(resultType.getRank()))
473f0a2fe7fSMaheshRavishankar       return failure();
474741f8f2bSNicolas Vasilache     auto rankReducedType =
475741f8f2bSNicolas Vasilache         tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
476741f8f2bSNicolas Vasilache             reassociation->size(), sliceOp.getSourceType(), offsets, sizes,
477741f8f2bSNicolas Vasilache             strides)
478fd15e2b8SMaheshRavishankar             .cast<RankedTensorType>();
479f0a2fe7fSMaheshRavishankar 
480060208b4SMatthias Springer     Location loc = sliceOp.getLoc();
481060208b4SMatthias Springer     Value newSlice = rewriter.create<tensor::ExtractSliceOp>(
48204235d07SJacques Pienaar         loc, rankReducedType, sliceOp.getSource(), offsets, sizes, strides);
483b618880eSAlexander Belyaev     rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
484b618880eSAlexander Belyaev         sliceOp, resultType, newSlice, *reassociation);
485f0a2fe7fSMaheshRavishankar     return success();
486f0a2fe7fSMaheshRavishankar   }
487f0a2fe7fSMaheshRavishankar };
488f0a2fe7fSMaheshRavishankar 
489060208b4SMatthias Springer /// Convert `insert_slice` operations to rank-reduced versions.
490df5c981bSNicolas Vasilache /// This patterns works with both InsertSliceOp and ParallelInsertSliceOp.
491df5c981bSNicolas Vasilache template <typename InsertOpTy>
492df5c981bSNicolas Vasilache struct RankReducedInsertSliceOp : public OpRewritePattern<InsertOpTy> {
493df5c981bSNicolas Vasilache   using OpRewritePattern<InsertOpTy>::OpRewritePattern;
494fd15e2b8SMaheshRavishankar 
matchAndRewrite__anon1bf028040911::RankReducedInsertSliceOp495df5c981bSNicolas Vasilache   LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
496fd15e2b8SMaheshRavishankar                                 PatternRewriter &rewriter) const override {
497df5c981bSNicolas Vasilache     RankedTensorType sourceType = insertSliceOp.getSourceType();
498df5c981bSNicolas Vasilache     SmallVector<OpFoldResult> offsets = insertSliceOp.getMixedOffsets();
499df5c981bSNicolas Vasilache     SmallVector<OpFoldResult> sizes = insertSliceOp.getMixedSizes();
500df5c981bSNicolas Vasilache     SmallVector<OpFoldResult> strides = insertSliceOp.getMixedStrides();
501fd15e2b8SMaheshRavishankar     auto reassociation = getReassociationMapForFoldingUnitDims(sizes);
502fd15e2b8SMaheshRavishankar     if (!reassociation ||
503fd15e2b8SMaheshRavishankar         reassociation->size() == static_cast<size_t>(sourceType.getRank()))
504fd15e2b8SMaheshRavishankar       return failure();
505df5c981bSNicolas Vasilache     Location loc = insertSliceOp.getLoc();
506df5c981bSNicolas Vasilache     tensor::CollapseShapeOp reshapedSource;
507df5c981bSNicolas Vasilache     {
508df5c981bSNicolas Vasilache       OpBuilder::InsertionGuard g(rewriter);
509df5c981bSNicolas Vasilache       // The only difference between InsertSliceOp and ParallelInsertSliceOp is
510df5c981bSNicolas Vasilache       // the the insertion point is just before the ParallelCombiningOp in the
511df5c981bSNicolas Vasilache       // parallel case.
512df5c981bSNicolas Vasilache       if (std::is_same<InsertOpTy, tensor::ParallelInsertSliceOp>::value)
513df5c981bSNicolas Vasilache         rewriter.setInsertionPoint(insertSliceOp->getParentOp());
514df5c981bSNicolas Vasilache       reshapedSource = rewriter.create<tensor::CollapseShapeOp>(
515df5c981bSNicolas Vasilache           loc, insertSliceOp.getSource(), *reassociation);
516df5c981bSNicolas Vasilache     }
517df5c981bSNicolas Vasilache     rewriter.replaceOpWithNewOp<InsertOpTy>(
518df5c981bSNicolas Vasilache         insertSliceOp, reshapedSource, insertSliceOp.getDest(),
519df5c981bSNicolas Vasilache         insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
520df5c981bSNicolas Vasilache         insertSliceOp.getMixedStrides());
521fd15e2b8SMaheshRavishankar     return success();
522fd15e2b8SMaheshRavishankar   }
523fd15e2b8SMaheshRavishankar };
524b62f9f44SMaheshRavishankar } // namespace
525b62f9f44SMaheshRavishankar 
5262b0c8546SMaheshRavishankar /// Patterns that are used to canonicalize the use of unit-extent dims for
5272b0c8546SMaheshRavishankar /// broadcasting.
populateFoldUnitExtentDimsPatterns(RewritePatternSet & patterns)528ea069aebSMaheshRavishankar void mlir::linalg::populateFoldUnitExtentDimsPatterns(
529dc4e913bSChris Lattner     RewritePatternSet &patterns) {
5303a506b31SChris Lattner   auto *context = patterns.getContext();
531df5c981bSNicolas Vasilache   patterns.add<FoldUnitDimLoops, ReplaceUnitExtents, RankReducedExtractSliceOp,
532df5c981bSNicolas Vasilache                RankReducedInsertSliceOp<tensor::InsertSliceOp>,
533df5c981bSNicolas Vasilache                RankReducedInsertSliceOp<tensor::ParallelInsertSliceOp>>(
534060208b4SMatthias Springer       context);
535b618880eSAlexander Belyaev   linalg::FillOp::getCanonicalizationPatterns(patterns, context);
536b618880eSAlexander Belyaev   linalg::InitTensorOp::getCanonicalizationPatterns(patterns, context);
537b618880eSAlexander Belyaev   tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
538b618880eSAlexander Belyaev   tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
5392b0c8546SMaheshRavishankar }
5402b0c8546SMaheshRavishankar 
5412b0c8546SMaheshRavishankar namespace {
5422b0c8546SMaheshRavishankar /// Pass that removes unit-extent dims within generic ops.
5432b0c8546SMaheshRavishankar struct LinalgFoldUnitExtentDimsPass
5442b0c8546SMaheshRavishankar     : public LinalgFoldUnitExtentDimsBase<LinalgFoldUnitExtentDimsPass> {
runOnOperation__anon1bf028040a11::LinalgFoldUnitExtentDimsPass545c10995a8SStella Laurenzo   void runOnOperation() override {
546c10995a8SStella Laurenzo     Operation *op = getOperation();
547c10995a8SStella Laurenzo     MLIRContext *context = op->getContext();
548dc4e913bSChris Lattner     RewritePatternSet patterns(context);
5492b0c8546SMaheshRavishankar     if (foldOneTripLoopsOnly)
550f358c372STobias Gysi       patterns.add<FoldUnitDimLoops>(context);
5512b0c8546SMaheshRavishankar     else
552ea069aebSMaheshRavishankar       populateFoldUnitExtentDimsPatterns(patterns);
553c10995a8SStella Laurenzo     (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
5542b0c8546SMaheshRavishankar   }
5552b0c8546SMaheshRavishankar };
5562b0c8546SMaheshRavishankar } // namespace
5572b0c8546SMaheshRavishankar 
createLinalgFoldUnitExtentDimsPass()558c10995a8SStella Laurenzo std::unique_ptr<Pass> mlir::createLinalgFoldUnitExtentDimsPass() {
5592b0c8546SMaheshRavishankar   return std::make_unique<LinalgFoldUnitExtentDimsPass>();
5602b0c8546SMaheshRavishankar }
561