1307cfdf5SNicolas Vasilache //===- Interchange.cpp - Linalg interchange transformation ----------------===//
2307cfdf5SNicolas Vasilache //
3307cfdf5SNicolas Vasilache // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4307cfdf5SNicolas Vasilache // See https://llvm.org/LICENSE.txt for license information.
5307cfdf5SNicolas Vasilache // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6307cfdf5SNicolas Vasilache //
7307cfdf5SNicolas Vasilache //===----------------------------------------------------------------------===//
8307cfdf5SNicolas Vasilache //
9307cfdf5SNicolas Vasilache // This file implements the linalg interchange transformation.
10307cfdf5SNicolas Vasilache //
11307cfdf5SNicolas Vasilache //===----------------------------------------------------------------------===//
12307cfdf5SNicolas Vasilache
13*eda6f907SRiver Riddle #include "mlir/Dialect/Affine/IR/AffineOps.h"
14307cfdf5SNicolas Vasilache #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
15b7f2c108Sgysit #include "mlir/Dialect/Linalg/IR/Linalg.h"
16307cfdf5SNicolas Vasilache #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
17307cfdf5SNicolas Vasilache #include "mlir/Dialect/Linalg/Utils/Utils.h"
18f71f9958SDiego Caballero #include "mlir/Dialect/Utils/IndexingUtils.h"
19307cfdf5SNicolas Vasilache #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
2099ef9eebSMatthias Springer #include "mlir/Dialect/Vector/IR/VectorOps.h"
21307cfdf5SNicolas Vasilache #include "mlir/IR/AffineExpr.h"
22307cfdf5SNicolas Vasilache #include "mlir/IR/Matchers.h"
23307cfdf5SNicolas Vasilache #include "mlir/IR/PatternMatch.h"
24307cfdf5SNicolas Vasilache #include "mlir/Pass/Pass.h"
25307cfdf5SNicolas Vasilache #include "mlir/Support/LLVM.h"
269a7d111fSNicolas Vasilache #include "llvm/ADT/ScopeExit.h"
27307cfdf5SNicolas Vasilache #include "llvm/Support/Debug.h"
28307cfdf5SNicolas Vasilache #include "llvm/Support/raw_ostream.h"
29307cfdf5SNicolas Vasilache #include <type_traits>
30307cfdf5SNicolas Vasilache
31307cfdf5SNicolas Vasilache #define DEBUG_TYPE "linalg-interchange"
32307cfdf5SNicolas Vasilache
33307cfdf5SNicolas Vasilache using namespace mlir;
34307cfdf5SNicolas Vasilache using namespace mlir::linalg;
35307cfdf5SNicolas Vasilache
369a7d111fSNicolas Vasilache static LogicalResult
interchangeGenericOpPrecondition(GenericOp genericOp,ArrayRef<unsigned> interchangeVector)379a7d111fSNicolas Vasilache interchangeGenericOpPrecondition(GenericOp genericOp,
389a7d111fSNicolas Vasilache ArrayRef<unsigned> interchangeVector) {
39495e1d7eSTobias Gysi // Interchange vector must be non-empty and match the number of loops.
40495e1d7eSTobias Gysi if (interchangeVector.empty() ||
4106bb9cf3STobias Gysi genericOp.getNumLoops() != interchangeVector.size())
42307cfdf5SNicolas Vasilache return failure();
43307cfdf5SNicolas Vasilache // Permutation map must be invertible.
4406bb9cf3STobias Gysi if (!inversePermutation(AffineMap::getPermutationMap(interchangeVector,
4506bb9cf3STobias Gysi genericOp.getContext())))
46307cfdf5SNicolas Vasilache return failure();
47307cfdf5SNicolas Vasilache return success();
48307cfdf5SNicolas Vasilache }
49307cfdf5SNicolas Vasilache
509a7d111fSNicolas Vasilache FailureOr<GenericOp>
interchangeGenericOp(RewriterBase & rewriter,GenericOp genericOp,ArrayRef<unsigned> interchangeVector)519a7d111fSNicolas Vasilache mlir::linalg::interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp,
52307cfdf5SNicolas Vasilache ArrayRef<unsigned> interchangeVector) {
539a7d111fSNicolas Vasilache if (failed(interchangeGenericOpPrecondition(genericOp, interchangeVector)))
549a7d111fSNicolas Vasilache return rewriter.notifyMatchFailure(genericOp, "preconditions not met");
559a7d111fSNicolas Vasilache
569a7d111fSNicolas Vasilache // 1. Compute the inverse permutation map, it must be non-null since the
579a7d111fSNicolas Vasilache // preconditions are satisfied.
5806bb9cf3STobias Gysi MLIRContext *context = genericOp.getContext();
59495e1d7eSTobias Gysi AffineMap permutationMap = inversePermutation(
60307cfdf5SNicolas Vasilache AffineMap::getPermutationMap(interchangeVector, context));
619a7d111fSNicolas Vasilache assert(permutationMap && "unexpected null map");
629a7d111fSNicolas Vasilache
639a7d111fSNicolas Vasilache // Start a guarded inplace update.
649a7d111fSNicolas Vasilache rewriter.startRootUpdate(genericOp);
659a7d111fSNicolas Vasilache auto guard =
669a7d111fSNicolas Vasilache llvm::make_scope_exit([&]() { rewriter.finalizeRootUpdate(genericOp); });
67495e1d7eSTobias Gysi
68495e1d7eSTobias Gysi // 2. Compute the interchanged indexing maps.
699a7d111fSNicolas Vasilache SmallVector<AffineMap> newIndexingMaps;
707c234ae5STobias Gysi for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
717c234ae5STobias Gysi AffineMap m = genericOp.getTiedIndexingMap(opOperand);
72307cfdf5SNicolas Vasilache if (!permutationMap.isEmpty())
73307cfdf5SNicolas Vasilache m = m.compose(permutationMap);
749a7d111fSNicolas Vasilache newIndexingMaps.push_back(m);
75307cfdf5SNicolas Vasilache }
7606bb9cf3STobias Gysi genericOp->setAttr(getIndexingMapsAttrName(),
779a7d111fSNicolas Vasilache rewriter.getAffineMapArrayAttr(newIndexingMaps));
78495e1d7eSTobias Gysi
79495e1d7eSTobias Gysi // 3. Compute the interchanged iterator types.
8006bb9cf3STobias Gysi ArrayRef<Attribute> itTypes = genericOp.iterator_types().getValue();
819a7d111fSNicolas Vasilache SmallVector<Attribute> itTypesVector;
82495e1d7eSTobias Gysi llvm::append_range(itTypesVector, itTypes);
839072f1b5STobias Gysi SmallVector<int64_t> permutation(interchangeVector.begin(),
849072f1b5STobias Gysi interchangeVector.end());
859072f1b5STobias Gysi applyPermutationToVector(itTypesVector, permutation);
8606bb9cf3STobias Gysi genericOp->setAttr(getIteratorTypesAttrName(),
87c2c83e97STres Popp ArrayAttr::get(context, itTypesVector));
88307cfdf5SNicolas Vasilache
89495e1d7eSTobias Gysi // 4. Transform the index operations by applying the permutation map.
9006bb9cf3STobias Gysi if (genericOp.hasIndexSemantics()) {
91495e1d7eSTobias Gysi OpBuilder::InsertionGuard guard(rewriter);
92495e1d7eSTobias Gysi for (IndexOp indexOp :
93eaa52750STobias Gysi llvm::make_early_inc_range(genericOp.getBody()->getOps<IndexOp>())) {
94495e1d7eSTobias Gysi rewriter.setInsertionPoint(indexOp);
95495e1d7eSTobias Gysi SmallVector<Value> allIndices;
9606bb9cf3STobias Gysi allIndices.reserve(genericOp.getNumLoops());
9706bb9cf3STobias Gysi llvm::transform(llvm::seq<uint64_t>(0, genericOp.getNumLoops()),
98f31531a3STobias Gysi std::back_inserter(allIndices), [&](uint64_t dim) {
99495e1d7eSTobias Gysi return rewriter.create<IndexOp>(indexOp->getLoc(), dim);
100495e1d7eSTobias Gysi });
101495e1d7eSTobias Gysi rewriter.replaceOpWithNewOp<AffineApplyOp>(
102495e1d7eSTobias Gysi indexOp, permutationMap.getSubMap(indexOp.dim()), allIndices);
103495e1d7eSTobias Gysi }
104495e1d7eSTobias Gysi }
1059a7d111fSNicolas Vasilache
1069a7d111fSNicolas Vasilache return genericOp;
107307cfdf5SNicolas Vasilache }
108