//===- Interchange.cpp - Linalg interchange transformation ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg interchange transformation.
//
//===----------------------------------------------------------------------===//
12307cfdf5SNicolas Vasilache 
13*eda6f907SRiver Riddle #include "mlir/Dialect/Affine/IR/AffineOps.h"
14307cfdf5SNicolas Vasilache #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
15b7f2c108Sgysit #include "mlir/Dialect/Linalg/IR/Linalg.h"
16307cfdf5SNicolas Vasilache #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
17307cfdf5SNicolas Vasilache #include "mlir/Dialect/Linalg/Utils/Utils.h"
18f71f9958SDiego Caballero #include "mlir/Dialect/Utils/IndexingUtils.h"
19307cfdf5SNicolas Vasilache #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
2099ef9eebSMatthias Springer #include "mlir/Dialect/Vector/IR/VectorOps.h"
21307cfdf5SNicolas Vasilache #include "mlir/IR/AffineExpr.h"
22307cfdf5SNicolas Vasilache #include "mlir/IR/Matchers.h"
23307cfdf5SNicolas Vasilache #include "mlir/IR/PatternMatch.h"
24307cfdf5SNicolas Vasilache #include "mlir/Pass/Pass.h"
25307cfdf5SNicolas Vasilache #include "mlir/Support/LLVM.h"
269a7d111fSNicolas Vasilache #include "llvm/ADT/ScopeExit.h"
27307cfdf5SNicolas Vasilache #include "llvm/Support/Debug.h"
28307cfdf5SNicolas Vasilache #include "llvm/Support/raw_ostream.h"
29307cfdf5SNicolas Vasilache #include <type_traits>
30307cfdf5SNicolas Vasilache 
31307cfdf5SNicolas Vasilache #define DEBUG_TYPE "linalg-interchange"
32307cfdf5SNicolas Vasilache 
33307cfdf5SNicolas Vasilache using namespace mlir;
34307cfdf5SNicolas Vasilache using namespace mlir::linalg;
35307cfdf5SNicolas Vasilache 
369a7d111fSNicolas Vasilache static LogicalResult
interchangeGenericOpPrecondition(GenericOp genericOp,ArrayRef<unsigned> interchangeVector)379a7d111fSNicolas Vasilache interchangeGenericOpPrecondition(GenericOp genericOp,
389a7d111fSNicolas Vasilache                                  ArrayRef<unsigned> interchangeVector) {
39495e1d7eSTobias Gysi   // Interchange vector must be non-empty and match the number of loops.
40495e1d7eSTobias Gysi   if (interchangeVector.empty() ||
4106bb9cf3STobias Gysi       genericOp.getNumLoops() != interchangeVector.size())
42307cfdf5SNicolas Vasilache     return failure();
43307cfdf5SNicolas Vasilache   // Permutation map must be invertible.
4406bb9cf3STobias Gysi   if (!inversePermutation(AffineMap::getPermutationMap(interchangeVector,
4506bb9cf3STobias Gysi                                                        genericOp.getContext())))
46307cfdf5SNicolas Vasilache     return failure();
47307cfdf5SNicolas Vasilache   return success();
48307cfdf5SNicolas Vasilache }
49307cfdf5SNicolas Vasilache 
509a7d111fSNicolas Vasilache FailureOr<GenericOp>
interchangeGenericOp(RewriterBase & rewriter,GenericOp genericOp,ArrayRef<unsigned> interchangeVector)519a7d111fSNicolas Vasilache mlir::linalg::interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp,
52307cfdf5SNicolas Vasilache                                    ArrayRef<unsigned> interchangeVector) {
539a7d111fSNicolas Vasilache   if (failed(interchangeGenericOpPrecondition(genericOp, interchangeVector)))
549a7d111fSNicolas Vasilache     return rewriter.notifyMatchFailure(genericOp, "preconditions not met");
559a7d111fSNicolas Vasilache 
569a7d111fSNicolas Vasilache   // 1. Compute the inverse permutation map, it must be non-null since the
579a7d111fSNicolas Vasilache   // preconditions are satisfied.
5806bb9cf3STobias Gysi   MLIRContext *context = genericOp.getContext();
59495e1d7eSTobias Gysi   AffineMap permutationMap = inversePermutation(
60307cfdf5SNicolas Vasilache       AffineMap::getPermutationMap(interchangeVector, context));
619a7d111fSNicolas Vasilache   assert(permutationMap && "unexpected null map");
629a7d111fSNicolas Vasilache 
639a7d111fSNicolas Vasilache   // Start a guarded inplace update.
649a7d111fSNicolas Vasilache   rewriter.startRootUpdate(genericOp);
659a7d111fSNicolas Vasilache   auto guard =
669a7d111fSNicolas Vasilache       llvm::make_scope_exit([&]() { rewriter.finalizeRootUpdate(genericOp); });
67495e1d7eSTobias Gysi 
68495e1d7eSTobias Gysi   // 2. Compute the interchanged indexing maps.
699a7d111fSNicolas Vasilache   SmallVector<AffineMap> newIndexingMaps;
707c234ae5STobias Gysi   for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
717c234ae5STobias Gysi     AffineMap m = genericOp.getTiedIndexingMap(opOperand);
72307cfdf5SNicolas Vasilache     if (!permutationMap.isEmpty())
73307cfdf5SNicolas Vasilache       m = m.compose(permutationMap);
749a7d111fSNicolas Vasilache     newIndexingMaps.push_back(m);
75307cfdf5SNicolas Vasilache   }
7606bb9cf3STobias Gysi   genericOp->setAttr(getIndexingMapsAttrName(),
779a7d111fSNicolas Vasilache                      rewriter.getAffineMapArrayAttr(newIndexingMaps));
78495e1d7eSTobias Gysi 
79495e1d7eSTobias Gysi   // 3. Compute the interchanged iterator types.
8006bb9cf3STobias Gysi   ArrayRef<Attribute> itTypes = genericOp.iterator_types().getValue();
819a7d111fSNicolas Vasilache   SmallVector<Attribute> itTypesVector;
82495e1d7eSTobias Gysi   llvm::append_range(itTypesVector, itTypes);
839072f1b5STobias Gysi   SmallVector<int64_t> permutation(interchangeVector.begin(),
849072f1b5STobias Gysi                                    interchangeVector.end());
859072f1b5STobias Gysi   applyPermutationToVector(itTypesVector, permutation);
8606bb9cf3STobias Gysi   genericOp->setAttr(getIteratorTypesAttrName(),
87c2c83e97STres Popp                      ArrayAttr::get(context, itTypesVector));
88307cfdf5SNicolas Vasilache 
89495e1d7eSTobias Gysi   // 4. Transform the index operations by applying the permutation map.
9006bb9cf3STobias Gysi   if (genericOp.hasIndexSemantics()) {
91495e1d7eSTobias Gysi     OpBuilder::InsertionGuard guard(rewriter);
92495e1d7eSTobias Gysi     for (IndexOp indexOp :
93eaa52750STobias Gysi          llvm::make_early_inc_range(genericOp.getBody()->getOps<IndexOp>())) {
94495e1d7eSTobias Gysi       rewriter.setInsertionPoint(indexOp);
95495e1d7eSTobias Gysi       SmallVector<Value> allIndices;
9606bb9cf3STobias Gysi       allIndices.reserve(genericOp.getNumLoops());
9706bb9cf3STobias Gysi       llvm::transform(llvm::seq<uint64_t>(0, genericOp.getNumLoops()),
98f31531a3STobias Gysi                       std::back_inserter(allIndices), [&](uint64_t dim) {
99495e1d7eSTobias Gysi                         return rewriter.create<IndexOp>(indexOp->getLoc(), dim);
100495e1d7eSTobias Gysi                       });
101495e1d7eSTobias Gysi       rewriter.replaceOpWithNewOp<AffineApplyOp>(
102495e1d7eSTobias Gysi           indexOp, permutationMap.getSubMap(indexOp.dim()), allIndices);
103495e1d7eSTobias Gysi     }
104495e1d7eSTobias Gysi   }
1059a7d111fSNicolas Vasilache 
1069a7d111fSNicolas Vasilache   return genericOp;
107307cfdf5SNicolas Vasilache }
108