//===- Vectorization.cpp - Implementation of linalg Vectorization ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect Vectorization transformations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>

using namespace mlir;
using namespace mlir::linalg;

#define DEBUG_TYPE "linalg-vectorization"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X)

/// Try to vectorize `convOp` as a convolution.
static FailureOr<Operation *> vectorizeConvolution(OpBuilder &b,
                                                   LinalgOp convOp);

/// Return the unique instance of OpType in `block` if it is indeed unique.
/// Return null if no instance or more than one instance exists.
template <typename OpType>
static OpType getSingleOpOfType(Block &block) {
  OpType res;
  block.walk([&](OpType op) {
    if (res) {
      res = nullptr;
      return WalkResult::interrupt();
    }
    res = op;
    return WalkResult::advance();
  });
  return res;
}

/// Given an indexing `map` coming from a LinalgOp indexing, restricted to a
/// projectedPermutation, compress the unused dimensions to serve as a
/// permutation_map for a vector transfer operation.
/// For example, given a linalg op such as:
///
/// ```
///   %0 = linalg.generic {
///        indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, d2)>,
///        indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3)>
///      }
///      ins(%0 : tensor<2x3x4xf32>)
///      outs(%1 : tensor<5x6xf32>)
/// ```
///
/// the iteration domain size of the linalg op is 3x5x4x6x2. The first affine
/// map is reindexed to `affine_map<(d0, d1, d2) -> (d2, d0, d1)>`, the second
/// affine map is reindexed to `affine_map<(d0, d1) -> (d0, d1)>`.
static AffineMap reindexIndexingMap(AffineMap map) {
  assert(map.isProjectedPermutation(/*allowZeroInResults=*/true) &&
         "expected projected permutation");
  auto res = compressUnusedDims(map);
  assert(res.getNumDims() == res.getNumResults() &&
         "expected reindexed map with same number of dims and results");
  return res;
}

/// Helper data structure to represent the result of vectorization.
/// In certain specific cases, like terminators, we do not want to propagate or
/// replace results automatically.
enum VectorizationStatus {
  /// Op failed to vectorize.
  Failure = 0,
  /// Op vectorized and custom function took care of replacement logic
  NoReplace,
  /// Op vectorized into a new Op whose results will replace original Op's
  /// results.
  NewOp
  // TODO: support values if Op vectorized to Many-Ops whose results we need to
  // aggregate for replacement.
};
struct VectorizationResult {
  /// Return status from vectorizing the current op.
  enum VectorizationStatus status = VectorizationStatus::Failure;
  /// New vectorized operation to replace the current op.
  /// Replacement behavior is specified by `status`.
  Operation *newOp;
};

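/// Return the vector::CombiningKind that matches `combinerOp` when it is one
/// of the supported elementwise combiners (integer/float add and mul, and, or,
/// xor, signed integer and float min/max). Return llvm::None otherwise.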
llvm::Optional<vector::CombiningKind>
mlir::linalg::getCombinerOpKind(Operation *combinerOp) {
  using ::mlir::vector::CombiningKind;

  if (!combinerOp)
    return llvm::None;
  return llvm::TypeSwitch<Operation *, llvm::Optional<CombiningKind>>(
             combinerOp)
      .Case<arith::AddIOp, arith::AddFOp>(
          [&](auto op) { return CombiningKind::ADD; })
      .Case<arith::AndIOp>([&](auto op) { return CombiningKind::AND; })
      .Case<arith::MaxSIOp>([&](auto op) { return CombiningKind::MAXSI; })
      .Case<arith::MaxFOp>([&](auto op) { return CombiningKind::MAXF; })
      .Case<arith::MinSIOp>([&](auto op) { return CombiningKind::MINSI; })
      .Case<arith::MinFOp>([&](auto op) { return CombiningKind::MINF; })
      .Case<arith::MulIOp, arith::MulFOp>(
          [&](auto op) { return CombiningKind::MUL; })
      .Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; })
      .Case<arith::XOrIOp>([&](auto op) { return CombiningKind::XOR; })
      .Default([&](auto op) { return llvm::None; });
}

/// Check whether `outputOperand` is a reduction with a single combiner
/// operation. Return the combiner operation of the reduction. Return
/// nullptr otherwise. Multiple reduction operations would impose an
/// ordering between reduction dimensions and are currently unsupported in
/// Linalg. This limitation is motivated by the fact that e.g. min(max(X)) !=
/// max(min(X)).
// TODO: use in LinalgOp verification, there is a circular dependency atm.
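/// For example, a region body of the form
/// `%0 = arith.addf %in, %out : f32; linalg.yield %0 : f32`, where `%out` is
/// the output block argument, matches with the single combiner arith.addf.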
static Operation *matchLinalgReduction(OpOperand *outputOperand) {
  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
  unsigned outputPos =
      outputOperand->getOperandNumber() - linalgOp.getNumInputs();
  // Only single combiner operations are supported for now.
  SmallVector<Operation *, 4> combinerOps;
  if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
      combinerOps.size() != 1)
    return nullptr;

  // Return the combiner operation.
  return combinerOps[0];
}

/// Broadcast `value` to a vector of `shape` if possible. Return `value`
/// otherwise.
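/// For example, an f32 value or a vector<8xf32> value can be broadcast to
/// shape [4, 8] (yielding a vector<4x8xf32>), whereas a vector<4xf32> value
/// cannot and is returned unchanged.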
static Value broadcastIfNeeded(OpBuilder &b, Value value,
                               ArrayRef<int64_t> shape) {
  // If no shape to broadcast to, just return `value`.
  if (shape.empty())
    return value;
  VectorType targetVectorType =
      VectorType::get(shape, getElementTypeOrSelf(value));
  if (vector::isBroadcastableTo(value.getType(), targetVectorType) !=
      vector::BroadcastableToResult::Success)
    return value;
  Location loc = b.getInsertionPoint()->getLoc();
  return b.createOrFold<vector::BroadcastOp>(loc, targetVectorType, value);
}

/// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
/// assumes that `reductionOp` has two operands and one of them is the reduction
/// initial value.
static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
                                      Value valueToReduce, Value acc,
                                      const SmallVector<bool> &reductionMask) {
  auto maybeKind = getCombinerOpKind(reduceOp);
  assert(maybeKind && "Failed precondition: could not get reduction kind");
  return b.create<vector::MultiDimReductionOp>(
      reduceOp->getLoc(), valueToReduce, acc, reductionMask, *maybeKind);
}

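/// Return a mask with one entry per iterator of `linalgOp`, set to true for
/// reduction iterators and false otherwise. For example, iterator types
/// ["parallel", "reduction", "parallel"] yield the mask [false, true, false].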
static SmallVector<bool> getReductionMask(LinalgOp linalgOp) {
  unsigned idx = 0;
  SmallVector<bool> reductionMask(linalgOp.iterator_types().size(), false);
  for (auto attr : linalgOp.iterator_types()) {
    if (isReductionIterator(attr))
      reductionMask[idx] = true;
    ++idx;
  }
  return reductionMask;
}

/// Build a vector.transfer_write of `value` into `outputOperand` at indices set
/// to all `0`, where `outputOperand` is an output operand of the LinalgOp
/// currently being vectorized. If the output has rank 0, broadcast `value` to a
/// 0-d vector and build a 0-d vector.transfer_write.
/// Return the produced value or null if no value is produced.
static Value buildVectorWrite(OpBuilder &b, Value value,
                              OpOperand *outputOperand) {
  Operation *write;
  Location loc = value.getLoc();
  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
  ArrayRef<int64_t> shape = linalgOp.getShape(outputOperand);
  auto vectorType = VectorType::get(
      shape, getElementTypeOrSelf(outputOperand->get().getType()));
  if (vectorType.getRank() > 0) {
    // 0-d case is still special: do not invert the reindexing map.
    AffineMap map =
        reindexIndexingMap(linalgOp.getTiedIndexingMap(outputOperand));
    SmallVector<int64_t> transposeShape =
        applyPermutationMap(inversePermutation(map), vectorType.getShape());
    assert(!transposeShape.empty() && "unexpected empty transpose shape");
    vectorType = VectorType::get(transposeShape, vectorType.getElementType());
    SmallVector<Value> indices(linalgOp.getRank(outputOperand),
                               b.create<arith::ConstantIndexOp>(loc, 0));
    value = broadcastIfNeeded(b, value, vectorType.getShape());
    write = b.create<vector::TransferWriteOp>(loc, value, outputOperand->get(),
                                              indices, map);
  } else {
    if (!value.getType().isa<VectorType>())
      value = b.create<vector::BroadcastOp>(loc, vectorType, value);
    assert(value.getType() == vectorType && "incorrect type");
    write = b.create<vector::TransferWriteOp>(loc, value, outputOperand->get(),
                                              ValueRange{});
  }
  LDBG("vectorized op: " << *write);
  if (!write->getResults().empty())
    return write->getResult(0);
  return Value();
}

// Custom vectorization function type. Produce a vector form of Operation*
// assuming all its vectorized operands are already in the BlockAndValueMapping.
// Return nullptr if the Operation cannot be vectorized.
using CustomVectorizationHook = std::function<VectorizationResult(
    Operation *, const BlockAndValueMapping &)>;

/// Helper function to vectorize the terminator of a `linalgOp`. New result
/// vector values are appended to `newResults`. Return
/// VectorizationStatus::NoReplace to signal the vectorization algorithm that it
/// should not try to map produced operations and instead return the results
/// using the `newResults` vector, making them available to the
/// vectorization algorithm for RAUW. This function is meant to be used as a
/// CustomVectorizationHook.
static VectorizationResult
vectorizeLinalgYield(OpBuilder &b, Operation *op,
                     const BlockAndValueMapping &bvm, LinalgOp linalgOp,
                     SmallVectorImpl<Value> &newResults) {
  auto yieldOp = dyn_cast<linalg::YieldOp>(op);
  if (!yieldOp)
    return VectorizationResult{VectorizationStatus::Failure, nullptr};
  for (const auto &outputs : llvm::enumerate(yieldOp.values())) {
    // TODO: Scan for an opportunity for reuse.
    // TODO: use a map.
    Value vectorValue = bvm.lookup(outputs.value());
    Value newResult = buildVectorWrite(
        b, vectorValue, linalgOp.getOutputOperand(outputs.index()));
    if (newResult)
      newResults.push_back(newResult);
  }
  return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
}

/// Helper function to vectorize the index operations of a `linalgOp`. Return
/// VectorizationStatus::NewOp to signal the vectorization algorithm that it
/// should map the produced operations. This function is meant to be used as a
/// CustomVectorizationHook.
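/// For example (illustrative), `linalg.index 0` in an op with static loop
/// sizes 4x8 is vectorized to the constant vector [0, 1, 2, 3] of type
/// vector<4xindex>, broadcast to vector<8x4xindex>, and transposed to
/// vector<4x8xindex> so that the value at position (i, j) is i.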
static VectorizationResult vectorizeLinalgIndex(OpBuilder &b, Operation *op,
                                                LinalgOp linalgOp) {
  IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
  if (!indexOp)
    return VectorizationResult{VectorizationStatus::Failure, nullptr};
  auto loc = indexOp.getLoc();
  // Compute the static loop sizes of the index op.
  auto targetShape = linalgOp.computeStaticLoopSizes();
  // Compute a one-dimensional index vector for the index op dimension.
  SmallVector<int64_t> constantSeq =
      llvm::to_vector<16>(llvm::seq<int64_t>(0, targetShape[indexOp.dim()]));
  auto constantOp =
      b.create<arith::ConstantOp>(loc, b.getIndexVectorAttr(constantSeq));
  // Return the one-dimensional index vector if it lives in the trailing
  // dimension of the iteration space since the vectorization algorithm in this
  // case can handle the broadcast.
  if (indexOp.dim() == targetShape.size() - 1)
    return VectorizationResult{VectorizationStatus::NewOp, constantOp};
  // Otherwise permute the targetShape to move the index dimension last,
  // broadcast the one-dimensional index vector to the permuted shape, and
  // finally transpose the broadcasted index vector to undo the permutation.
  std::swap(targetShape[indexOp.dim()], targetShape.back());
  auto broadCastOp = b.create<vector::BroadcastOp>(
      loc, VectorType::get(targetShape, b.getIndexType()), constantOp);
  SmallVector<int64_t> transposition =
      llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
  std::swap(transposition.back(), transposition[indexOp.dim()]);
  auto transposeOp =
      b.create<vector::TransposeOp>(loc, broadCastOp, transposition);
  return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
}

/// Emit reduction operations if the shape of the value to reduce differs from
/// the result shape.
static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
                                 Value reduceValue, Value initialValue,
                                 const BlockAndValueMapping &bvm) {
  Value reduceVec = bvm.lookup(reduceValue);
  Value outputVec = bvm.lookup(initialValue);
  auto reduceType = reduceVec.getType().dyn_cast<VectorType>();
  auto outputType = outputVec.getType().dyn_cast<VectorType>();
  // Reduce only if needed as the value may already have been reduced for
  // contraction vectorization.
  if (!reduceType ||
      (outputType && reduceType.getShape() == outputType.getShape()))
    return nullptr;
  SmallVector<bool> reductionMask = getReductionMask(linalgOp);
  return buildMultiDimReduce(b, op, reduceVec, outputVec, reductionMask);
}

/// Generic vectorization for a single operation `op`, given already vectorized
/// operands carried by `bvm`. Vectorization occurs as follows:
///   1. Try to apply any of the `customVectorizationHooks` and return its
///   result on success.
///   2. Clone any constant in the current scope without vectorization: each
///   consumer of the constant will later determine the shape to which the
///   constant needs to be broadcast to.
///   3. Fail on any remaining non `ElementwiseMappable` op. It is the purpose
///   of the `customVectorizationHooks` to cover such cases.
///   4. Clone `op` in vector form to a vector of shape prescribed by the first
///   operand of maximal rank. Other operands have smaller rank and are
///   broadcast accordingly. It is assumed this broadcast is always legal,
///   otherwise, it means one of the `customVectorizationHooks` is incorrect.
///
/// This function assumes all operands of `op` have been vectorized and are in
/// the `bvm` mapping. As a consequence, this function is meant to be called on
/// a topologically-sorted list of ops.
/// This function does not update `bvm` but returns a VectorizationStatus that
/// instructs the caller what `bvm` update needs to occur.
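///
/// For example (illustrative), an `arith.addf %a, %b : f32` whose operands
/// were already vectorized to vector<4x8xf32> values in `bvm` is rewritten to
/// `arith.addf %va, %vb : vector<4x8xf32>` by the generic elementwise path.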
static VectorizationResult
vectorizeOneOp(OpBuilder &b, LinalgOp linalgOp, Operation *op,
               const BlockAndValueMapping &bvm,
               ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
  LDBG("vectorize op " << *op);

  // 1. Try to apply any CustomVectorizationHook.
  if (!customVectorizationHooks.empty()) {
    for (auto &customFunc : customVectorizationHooks) {
      VectorizationResult result = customFunc(op, bvm);
      if (result.status == VectorizationStatus::Failure)
        continue;
      return result;
    }
  }

  // 2. Constant ops don't get vectorized but rather broadcasted at their users.
  // Clone so that the constant is not confined to the linalgOp block.
  if (isa<arith::ConstantOp, func::ConstantOp>(op))
    return VectorizationResult{VectorizationStatus::NewOp, b.clone(*op)};

  // 3. Only ElementwiseMappable are allowed in the generic vectorization.
  if (!OpTrait::hasElementwiseMappableTraits(op))
    return VectorizationResult{VectorizationStatus::Failure, nullptr};

  // 4. Check if the operation is a reduction.
  SmallVector<std::pair<Value, Value>> reductionOperands;
  for (Value operand : op->getOperands()) {
    auto arg = operand.dyn_cast<BlockArgument>();
    if (!arg || arg.getArgNumber() < linalgOp.getNumInputs())
      continue;
    SmallVector<Operation *> reductionOps;
    Value reduceValue = matchReduction(
        linalgOp.getRegionOutputArgs(),
        arg.getArgNumber() - linalgOp.getNumInputs(), reductionOps);
    if (!reduceValue)
      continue;
    reductionOperands.push_back(std::make_pair(reduceValue, operand));
  }
  if (!reductionOperands.empty()) {
    assert(reductionOperands.size() == 1);
    Operation *reduceOp =
        reduceIfNeeded(b, linalgOp, op, reductionOperands[0].first,
                       reductionOperands[0].second, bvm);
    if (reduceOp)
      return VectorizationResult{VectorizationStatus::NewOp, reduceOp};
  }

  // 5. Generic vectorization path for ElementwiseMappable ops.
  // a. first get the first max ranked shape.
  SmallVector<int64_t, 4> firstMaxRankedShape;
  for (Value operand : op->getOperands()) {
    auto vt = bvm.lookup(operand).getType().dyn_cast<VectorType>();
    if (vt && firstMaxRankedShape.size() < vt.getShape().size())
      firstMaxRankedShape.assign(vt.getShape().begin(), vt.getShape().end());
  }
  // b. broadcast each op if needed.
  auto vectorizedOperands = llvm::map_range(op->getOperands(), [&](Value v) {
    return firstMaxRankedShape.empty()
               ? bvm.lookup(v)
               : broadcastIfNeeded(b, bvm.lookup(v), firstMaxRankedShape);
  });
  // c. for elementwise, the result is the vector with the firstMaxRankedShape
  auto returnTypes = llvm::map_range(op->getResultTypes(), [&](Type t) {
    return firstMaxRankedShape.empty()
               ? t
               : VectorType::get(firstMaxRankedShape, t);
  });

  // Build and return the new op.
  return VectorizationResult{
      VectorizationStatus::NewOp,
      b.create(op->getLoc(), op->getName().getIdentifier(),
               llvm::to_vector<4>(vectorizedOperands),
               llvm::to_vector<4>(returnTypes), op->getAttrs())};
}

/// Generic vectorization function that rewrites the body of a `linalgOp` into
/// vector form. Generic vectorization proceeds as follows:
///   1. Verify the `linalgOp` has one non-empty region.
///   2. Values defined above the region are mapped to themselves and will be
///   broadcasted on a per-need basis by their consumers.
///   3. Each region argument is vectorized into a vector.transfer_read (or 0-d
///   load).
///   TODO: Reuse opportunities for RAR dependencies.
///   4a. Register CustomVectorizationHook for YieldOp to capture the results.
///   4b. Register CustomVectorizationHook for IndexOp to access the iteration
///   indices.
///   5. Iteratively call vectorizeOneOp on the region operations.
///
/// Eager broadcasting is performed to the maximal common vector size implied
/// by the `linalgOp` iteration space. This eager broadcasting is introduced in
/// the permutation_map of the vector.transfer_read operations. The eager
/// broadcasting makes it trivial to determine where broadcasts, transposes and
/// reductions should occur, without any bookkeeping. The tradeoff is that, in
/// the absence of good canonicalizations, the amount of work increases.
/// This is not deemed a problem as we expect canonicalizations and foldings to
/// aggressively clean up the useless work.
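///
/// For example (illustrative), an elementwise addition expressed as a
/// linalg.generic on two tensor<4x8xf32> inputs is vectorized roughly into:
///
/// ```
///   %a = vector.transfer_read %arg0[...] : tensor<4x8xf32>, vector<4x8xf32>
///   %b = vector.transfer_read %arg1[...] : tensor<4x8xf32>, vector<4x8xf32>
///   %c = arith.addf %a, %b : vector<4x8xf32>
///   %r = vector.transfer_write %c, %arg2[...]
///       : vector<4x8xf32>, tensor<4x8xf32>
/// ```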
static LogicalResult
vectorizeAsLinalgGeneric(OpBuilder &b, LinalgOp linalgOp,
                         SmallVectorImpl<Value> &newResults) {
  Block *block = linalgOp.getBlock();

  // 2. Values defined above the region can only be broadcast for now. Make them
  // map to themselves.
  BlockAndValueMapping bvm;
  SetVector<Value> valuesSet;
  mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet);
  bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());

  if (linalgOp.getNumOutputs() == 0)
    return failure();

  // TODO: the common vector shape is equal to the static loop sizes only when
  // all indexing maps are projected permutations. For convs and stencils the
  // logic will need to evolve.
  SmallVector<int64_t> commonVectorShape = linalgOp.computeStaticLoopSizes();

  // 3. Turn all BBArgs into vector.transfer_read / load.
  Location loc = linalgOp.getLoc();
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
    BlockArgument bbarg = block->getArgument(opOperand->getOperandNumber());
    if (linalgOp.isScalar(opOperand)) {
      bvm.map(bbarg, opOperand->get());
      continue;
    }
    VectorType readType;
    AffineMap map;
    // TODO: can we keep this simplification?
    // if (linalgOp.getShape(opOperand).empty()) {
    //   readType = VectorType::get({}, bbarg.getType());
    // } else {
    if (opOperand->getOperandNumber() < linalgOp.getNumInputs()) {
      map = inverseAndBroadcastProjectedPermutation(
          linalgOp.getTiedIndexingMap(opOperand));
      readType = VectorType::get(commonVectorShape,
                                 getElementTypeOrSelf(opOperand->get()));
    } else {
      map = inversePermutation(
          reindexIndexingMap(linalgOp.getTiedIndexingMap(opOperand)));
      readType = VectorType::get(map.compose(linalgOp.getShape(opOperand)),
                                 getElementTypeOrSelf(opOperand->get()));
    }
    // }

    auto shape = linalgOp.getShape(opOperand);
    SmallVector<Value> indices(shape.size(), zero);
    Value readValue = b.create<vector::TransferReadOp>(
        loc, readType, opOperand->get(), indices, map);
    // Not all ops support 0-d vectors, extract the scalar for now.
    // TODO: remove this.
    if (readValue.getType().cast<VectorType>().getRank() == 0)
      readValue = b.create<vector::ExtractElementOp>(loc, readValue);

    LDBG("new vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue);
    bvm.map(bbarg, readValue);
    bvm.map(opOperand->get(), readValue);
  }

  SmallVector<CustomVectorizationHook> hooks;
  // 4a. Register CustomVectorizationHook for yieldOp.
  CustomVectorizationHook vectorizeYield =
      [&](Operation *op,
          const BlockAndValueMapping &bvm) -> VectorizationResult {
    return vectorizeLinalgYield(b, op, bvm, linalgOp, newResults);
  };
  hooks.push_back(vectorizeYield);

  // 4b. Register CustomVectorizationHook for indexOp.
  CustomVectorizationHook vectorizeIndex =
      [&](Operation *op,
          const BlockAndValueMapping &bvm) -> VectorizationResult {
    return vectorizeLinalgIndex(b, op, linalgOp);
  };
  hooks.push_back(vectorizeIndex);

  // 5. Iteratively call `vectorizeOneOp` to each op in the slice.
  for (Operation &op : block->getOperations()) {
    VectorizationResult result = vectorizeOneOp(b, linalgOp, &op, bvm, hooks);
    if (result.status == VectorizationStatus::Failure) {
      LDBG("failed to vectorize: " << op);
      return failure();
    }
    if (result.status == VectorizationStatus::NewOp) {
      LDBG("new vector op: " << *result.newOp;);
      bvm.map(op.getResults(), result.newOp->getResults());
    }
  }

  return success();
}

// TODO: probably need some extra checks for reduction followed by consumer
// ops that may not commute (e.g. linear reduction + non-linear instructions).
static LogicalResult reductionPreconditions(LinalgOp op) {
  if (llvm::none_of(op.iterator_types(), isReductionIterator)) {
    LDBG("reduction precondition failed: no reduction iterator");
    return failure();
  }
  for (OpOperand *opOperand : op.getOutputOperands()) {
    Operation *reduceOp = matchLinalgReduction(opOperand);
    if (!reduceOp || !getCombinerOpKind(reduceOp)) {
      LDBG("reduction precondition failed: reduction detection failed");
      return failure();
    }
  }
  return success();
}

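/// Check the vectorization preconditions that remain once all shapes are known
/// to be static: operand and result element types in the body must be valid
/// vector element types, and the op must be elementwise, a convolution, or
/// have projected-permutation indexing maps with supported reductions.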
static LogicalResult vectorizeStaticLinalgOpPrecondition(linalg::LinalgOp op) {
  // All types in the body should be a supported element type for VectorType.
  for (Operation &innerOp : op->getRegion(0).front()) {
    if (llvm::any_of(innerOp.getOperandTypes(), [](Type type) {
          return !VectorType::isValidElementType(type);
        })) {
      return failure();
    }
    if (llvm::any_of(innerOp.getResultTypes(), [](Type type) {
          return !VectorType::isValidElementType(type);
        })) {
      return failure();
    }
  }
  if (isElementwise(op))
    return success();
  // TODO: isaConvolutionOpInterface that can also infer from generic features.
  // But we will still need stride/dilation attributes that will be annoying to
  // reverse-engineer...
  if (isa<ConvolutionOpInterface>(op.getOperation()))
    return success();
  // TODO: the common vector shape is equal to the static loop sizes only when
  // all indexing maps are projected permutations. For convs and stencils the
  // logic will need to evolve.
  if (!allIndexingsAreProjectedPermutation(op)) {
    LDBG("precondition failed: not projected permutations");
    return failure();
  }
  if (failed(reductionPreconditions(op))) {
    LDBG("precondition failed: reduction preconditions");
    return failure();
  }
  return success();
}

static LogicalResult vectorizeLinalgOpPrecondition(LinalgOp linalgOp) {
  // All types must be static shape to go to vector.
  if (linalgOp.hasDynamicShape()) {
    LDBG("precondition failed: dynamic shape");
    return failure();
  }
  return vectorizeStaticLinalgOpPrecondition(linalgOp);
}

LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter,
                                      LinalgOp linalgOp) {
  if (failed(vectorizeLinalgOpPrecondition(linalgOp)))
    return failure();

  SmallVector<Value> results;
  // TODO: isaConvolutionOpInterface that can also infer from generic
  // features. Will require stride/dilation attributes inference.
  FailureOr<Operation *> convOr = vectorizeConvolution(rewriter, linalgOp);
  if (succeeded(convOr)) {
    llvm::append_range(results, (*convOr)->getResults());
  } else {
    if (failed(vectorizeLinalgOpPrecondition(linalgOp)))
      return failure();
    LDBG("Vectorize generic by broadcasting to a common shape: " << linalgOp);
    if (failed(vectorizeAsLinalgGeneric(rewriter, linalgOp, results)))
      return failure();
  }

  if (!results.empty())
    rewriter.replaceOp(linalgOp, results);
  else
    rewriter.eraseOp(linalgOp);

  return success();
}

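/// Rewrite a statically shaped memref.copy into a full-size
/// vector.transfer_read of the source followed by a vector.transfer_write into
/// the target.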
LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
                                          memref::CopyOp copyOp) {

  auto srcType = copyOp.getSource().getType().cast<MemRefType>();
  auto dstType = copyOp.getTarget().getType().cast<MemRefType>();
  if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
    return failure();

  auto readType =
      VectorType::get(srcType.getShape(), getElementTypeOrSelf(srcType));
  auto writeType =
      VectorType::get(dstType.getShape(), getElementTypeOrSelf(dstType));

  Location loc = copyOp->getLoc();
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  SmallVector<Value> indices(srcType.getRank(), zero);

  Value readValue = rewriter.create<vector::TransferReadOp>(
      loc, readType, copyOp.getSource(), indices,
      rewriter.getMultiDimIdentityMap(srcType.getRank()));
  if (readValue.getType().cast<VectorType>().getRank() == 0) {
    readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
    readValue = rewriter.create<vector::BroadcastOp>(loc, writeType, readValue);
  }
  Operation *writeValue = rewriter.create<vector::TransferWriteOp>(
      loc, readValue, copyOp.getTarget(), indices,
      rewriter.getMultiDimIdentityMap(srcType.getRank()));
  rewriter.replaceOp(copyOp, writeValue->getResults());
  return success();
}

//----------------------------------------------------------------------------//
// Misc. vectorization patterns.
//----------------------------------------------------------------------------//

/// Helper function that retrieves the value of an IntegerAttr.
static int64_t getIntFromAttr(Attribute attr) {
  return attr.cast<IntegerAttr>().getInt();
}

/// Given an ArrayRef of OpFoldResults, return a vector of Values.
/// IntegerAttrs are converted to ConstantIndexOps. Other attribute types are
/// not supported.
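/// For example, the mixed values {%v, 4} become {%v, arith.constant 4 : index}.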
static SmallVector<Value> ofrToIndexValues(OpBuilder &builder, Location loc,
                                           ArrayRef<OpFoldResult> ofrs) {
  SmallVector<Value> result;
  for (auto o : ofrs) {
    if (auto val = o.template dyn_cast<Value>()) {
      result.push_back(val);
    } else {
      result.push_back(builder.create<arith::ConstantIndexOp>(
          loc, getIntFromAttr(o.template get<Attribute>())));
    }
  }
  return result;
}

/// Rewrite a tensor::PadOp into a sequence of InitTensorOp, FillOp and
/// InsertSliceOp. For now, only constant padding values are supported.
/// If there is enough static type information, TransferReadOps and
/// TransferWriteOps may be generated instead of InsertSliceOps.
struct GenericPadOpVectorizationPattern : public GeneralizePadOpPattern {
  GenericPadOpVectorizationPattern(MLIRContext *context,
                                   PatternBenefit benefit = 1)
      : GeneralizePadOpPattern(context, tryVectorizeCopy, benefit) {}
  /// Vectorize the copying of a tensor::PadOp's source. This is possible if
  /// each dimension size is statically known in the source type or the result
  /// type (or both).
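  /// For example (illustrative), padding a tensor<2x3xf32> source to a
  /// tensor<4x5xf32> result becomes an in-bounds vector.transfer_read of
  /// vector<2x3xf32> from the source followed by a vector.transfer_write of
  /// that vector into `dest` at the low-pad offsets.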
  static LogicalResult tryVectorizeCopy(PatternRewriter &rewriter,
                                        tensor::PadOp padOp, Value dest) {
    auto sourceType = padOp.getSourceType();
    auto resultType = padOp.getResultType();

    // Copy cannot be vectorized if pad value is non-constant and source shape
    // is dynamic. In case of a dynamic source shape, padding must be appended
    // by TransferReadOp, but TransferReadOp supports only constant padding.
    auto padValue = padOp.getConstantPaddingValue();
    if (!padValue) {
      if (!sourceType.hasStaticShape())
        return failure();
      // Create dummy padding value.
      auto elemType = sourceType.getElementType();
      padValue = rewriter.create<arith::ConstantOp>(
          padOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
    }

    SmallVector<int64_t> vecShape;
    SmallVector<bool> readInBounds;
    SmallVector<bool> writeInBounds;
    for (unsigned i = 0; i < sourceType.getRank(); ++i) {
      if (!sourceType.isDynamicDim(i)) {
        vecShape.push_back(sourceType.getDimSize(i));
        // Source shape is statically known: Neither read nor write are
        // out-of-bounds.
        readInBounds.push_back(true);
        writeInBounds.push_back(true);
      } else if (!resultType.isDynamicDim(i)) {
        // Source shape is not statically known, but result shape is.
        // Vectorize with size of result shape. This may be larger than the
        // source size.
        vecShape.push_back(resultType.getDimSize(i));
        // Read may be out-of-bounds because the result size could be larger
        // than the source size.
        readInBounds.push_back(false);
        // Write is out-of-bounds if low padding > 0.
        writeInBounds.push_back(
            getConstantIntValue(padOp.getMixedLowPad()[i]) ==
            static_cast<int64_t>(0));
      } else {
        // Neither source nor result dim of padOp is static. Cannot vectorize
        // the copy.
        return failure();
      }
    }
    auto vecType = VectorType::get(vecShape, sourceType.getElementType());

    // Generate TransferReadOp.
    SmallVector<Value> readIndices(
        vecType.getRank(),
        rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
    auto read = rewriter.create<vector::TransferReadOp>(
        padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue,
        ArrayRef<bool>{readInBounds});

    // If `dest` is a FillOp and the TransferWriteOp would overwrite the
    // entire tensor, write directly to the FillOp's operand.
    if (llvm::equal(vecShape, resultType.getShape()) &&
        llvm::all_of(writeInBounds, [](bool b) { return b; }))
      if (auto fill = dest.getDefiningOp<FillOp>())
        dest = fill.output();

    // Generate TransferWriteOp.
    auto writeIndices =
        ofrToIndexValues(rewriter, padOp.getLoc(), padOp.getMixedLowPad());
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        padOp, read, dest, writeIndices, ArrayRef<bool>{writeInBounds});

    return success();
  }
};

/// Base pattern for rewriting tensor::PadOps whose result is consumed by a
/// given operation type OpTy.
template <typename OpTy>
struct VectorizePadOpUserPattern : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const final {
    bool changed = false;
    // Insert users in vector, because some users may be replaced/removed.
    for (auto *user : llvm::to_vector<4>(padOp->getUsers()))
      if (auto op = dyn_cast<OpTy>(user))
        changed |= rewriteUser(rewriter, padOp, op).succeeded();
    return success(changed);
  }

protected:
  virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
                                    tensor::PadOp padOp, OpTy op) const = 0;
};

/// Rewrite use of tensor::PadOp result in TransferReadOp. E.g.:
/// ```
/// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
/// %r = vector.transfer_read %0[%c0, %c0], %cst
///     {in_bounds = [true, true]} : tensor<17x5xf32>, vector<17x5xf32>
/// ```
/// is rewritten to:
/// ```
/// %r = vector.transfer_read %src[%c0, %c0], %padding
///     {in_bounds = [true, true]}
///     : tensor<?x?xf32>, vector<17x5xf32>
/// ```
/// Note: By restricting this pattern to in-bounds TransferReadOps, we can be
/// sure that the original padding value %cst was never used.
///
/// This rewrite is possible if:
/// - `xferOp` has no out-of-bounds dims or mask.
/// - Low padding is static 0.
/// - Single, scalar padding value.
struct PadOpVectorizationWithTransferReadPattern
    : public VectorizePadOpUserPattern<vector::TransferReadOp> {
  using VectorizePadOpUserPattern<
      vector::TransferReadOp>::VectorizePadOpUserPattern;

  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
                            vector::TransferReadOp xferOp) const override {
    // Low padding must be static 0.
    if (!padOp.hasZeroLowPad())
      return failure();
    // Pad value must be a constant.
    auto padValue = padOp.getConstantPaddingValue();
    if (!padValue)
      return failure();
    // Padding value of existing `xferOp` is unused.
    if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
      return failure();

    rewriter.updateRootInPlace(xferOp, [&]() {
      SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
      xferOp->setAttr(xferOp.getInBoundsAttrName(),
                      rewriter.getBoolArrayAttr(inBounds));
      xferOp.getSourceMutable().assign(padOp.getSource());
      xferOp.getPaddingMutable().assign(padValue);
    });

    return success();
  }
};

833fd0c6f53SAlexander Belyaev /// Rewrite use of tensor::PadOp result in TransferWriteOp.
8349a7d111fSNicolas Vasilache /// This pattern rewrites TransferWriteOps that write to a padded tensor
8359a7d111fSNicolas Vasilache /// value, where the same amount of padding is immediately removed again after
8369a7d111fSNicolas Vasilache /// the write. In such cases, the TransferWriteOp can write to the non-padded
8379a7d111fSNicolas Vasilache /// tensor value and apply out-of-bounds masking. E.g.:
838562f9e99SMatthias Springer /// ```
839060208b4SMatthias Springer /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
840060208b4SMatthias Springer /// : tensor<...> to tensor<?x?xf32>
8411ad9b266Slorenzo chelini /// %1 = tensor.pad %0 ... : tensor<?x?xf32> to tensor<17x5xf32>
842562f9e99SMatthias Springer /// %2 = vector.transfer_write %vec, %1[...]
843562f9e99SMatthias Springer /// : vector<17x5xf32>, tensor<17x5xf32>
844060208b4SMatthias Springer /// %r = tensor.extract_slice %2[0, 0] [%s0, %s1] [1, 1]
845562f9e99SMatthias Springer /// : tensor<17x5xf32> to tensor<?x?xf32>
846562f9e99SMatthias Springer /// ```
847562f9e99SMatthias Springer /// is rewritten to:
848562f9e99SMatthias Springer /// ```
849060208b4SMatthias Springer /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
850060208b4SMatthias Springer /// : tensor<...> to tensor<?x?xf32>
8519a7d111fSNicolas Vasilache /// %r = vector.transfer_write %vec, %0[...] : vector<17x5xf32>,
8529a7d111fSNicolas Vasilache /// tensor<?x?xf32>
853562f9e99SMatthias Springer /// ```
854060208b4SMatthias Springer /// Note: It is important that the ExtractSliceOp %r resizes the result of the
8559a7d111fSNicolas Vasilache /// TransferWriteOp to the same size as the input of the tensor::PadOp (or an
8569a7d111fSNicolas Vasilache /// even smaller size). Otherwise, %r's new (dynamic) dimensions would differ
8579a7d111fSNicolas Vasilache /// from %r's old dimensions.
858562f9e99SMatthias Springer ///
859562f9e99SMatthias Springer /// This rewrite is possible if:
860562f9e99SMatthias Springer /// - Low padding is static 0.
861060208b4SMatthias Springer /// - `xferOp` has exactly one use, which is an ExtractSliceOp. This
8629a7d111fSNicolas Vasilache /// ExtractSliceOp trims the same amount of padding that was added
8639a7d111fSNicolas Vasilache /// beforehand.
864562f9e99SMatthias Springer /// - Single, scalar padding value.
865fd0c6f53SAlexander Belyaev struct PadOpVectorizationWithTransferWritePattern
866fd0c6f53SAlexander Belyaev : public VectorizePadOpUserPattern<vector::TransferWriteOp> {
867fd0c6f53SAlexander Belyaev using VectorizePadOpUserPattern<
868fd0c6f53SAlexander Belyaev vector::TransferWriteOp>::VectorizePadOpUserPattern;
869562f9e99SMatthias Springer
870fd0c6f53SAlexander Belyaev LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
871562f9e99SMatthias Springer vector::TransferWriteOp xferOp) const override {
872c537a943SNicolas Vasilache // TODO: support 0-d corner case.
873c537a943SNicolas Vasilache if (xferOp.getTransferRank() == 0)
874c537a943SNicolas Vasilache return failure();
875c537a943SNicolas Vasilache
876562f9e99SMatthias Springer // Low padding must be static 0.
8778f1650cbSNicolas Vasilache if (!padOp.hasZeroLowPad())
8788f1650cbSNicolas Vasilache return failure();
879562f9e99SMatthias Springer // Pad value must be a constant.
880562f9e99SMatthias Springer auto padValue = padOp.getConstantPaddingValue();
8818f1650cbSNicolas Vasilache if (!padValue)
8828f1650cbSNicolas Vasilache return failure();
883060208b4SMatthias Springer // TransferWriteOp result must be directly consumed by an ExtractSliceOp.
8848f1650cbSNicolas Vasilache if (!xferOp->hasOneUse())
8858f1650cbSNicolas Vasilache return failure();
886060208b4SMatthias Springer auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
8878f1650cbSNicolas Vasilache if (!trimPadding)
8888f1650cbSNicolas Vasilache return failure();
889562f9e99SMatthias Springer // Only static zero offsets supported when trimming padding.
8908f1650cbSNicolas Vasilache if (!trimPadding.hasZeroOffset())
8918f1650cbSNicolas Vasilache return failure();
892562f9e99SMatthias Springer // trimPadding must remove the amount of padding that was added earlier.
89304235d07SJacques Pienaar if (!hasSameTensorSize(padOp.getSource(), trimPadding))
8948f1650cbSNicolas Vasilache return failure();
895562f9e99SMatthias Springer
8963f8f2923STobias Gysi // Insert the new TransferWriteOp at position of the old TransferWriteOp.
8973f8f2923STobias Gysi rewriter.setInsertionPoint(xferOp);
8983f8f2923STobias Gysi
899562f9e99SMatthias Springer SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
900562f9e99SMatthias Springer auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
90104235d07SJacques Pienaar xferOp, padOp.getSource().getType(), xferOp.getVector(),
90204235d07SJacques Pienaar padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),
90304235d07SJacques Pienaar xferOp.getMask(), rewriter.getBoolArrayAttr(inBounds));
904562f9e99SMatthias Springer rewriter.replaceOp(trimPadding, newXferOp->getResult(0));
905562f9e99SMatthias Springer
906562f9e99SMatthias Springer return success();
907562f9e99SMatthias Springer }
908562f9e99SMatthias Springer
909562f9e99SMatthias Springer /// Check if `beforePadding` and `afterTrimming` have the same tensor size,
910562f9e99SMatthias Springer /// i.e., same dimensions.
911562f9e99SMatthias Springer ///
912562f9e99SMatthias Springer /// Dimensions may be static, dynamic or mix of both. In case of dynamic
913562f9e99SMatthias Springer /// dimensions, this function tries to infer the (static) tensor size by
914562f9e99SMatthias Springer /// looking at the defining op and utilizing op-specific knowledge.
915562f9e99SMatthias Springer ///
916562f9e99SMatthias Springer /// This is a conservative analysis. In case equal tensor sizes cannot be
917562f9e99SMatthias Springer /// proven statically, this analysis returns `false` even though the tensor
918562f9e99SMatthias Springer /// sizes may turn out to be equal at runtime.
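  /// As an illustrative sketch (value names are hypothetical): with
  ///   %0 = tensor.extract_slice %t[0, 0] [%sz, 8] [1, 1]
  /// as `beforePadding` and
  ///   %r = tensor.extract_slice %2[0, 0] [%sz, 8] [1, 1]
  /// as `afterTrimming`, the shared dynamic size %sz (Case 1 below) lets the
  /// analysis conclude that both tensors have the same size.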
919060208b4SMatthias Springer bool hasSameTensorSize(Value beforePadding,
920060208b4SMatthias Springer tensor::ExtractSliceOp afterTrimming) const {
921fd0c6f53SAlexander Belyaev // If the input to tensor::PadOp is a CastOp, try with both the CastOp
9229a7d111fSNicolas Vasilache // result and the CastOp operand.
923562f9e99SMatthias Springer if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
92404235d07SJacques Pienaar if (hasSameTensorSize(castOp.getSource(), afterTrimming))
9258f1650cbSNicolas Vasilache return true;
926562f9e99SMatthias Springer
927562f9e99SMatthias Springer auto t1 = beforePadding.getType().dyn_cast<RankedTensorType>();
928562f9e99SMatthias Springer auto t2 = afterTrimming.getType().dyn_cast<RankedTensorType>();
929562f9e99SMatthias Springer // Only RankedTensorType supported.
9308f1650cbSNicolas Vasilache if (!t1 || !t2)
9318f1650cbSNicolas Vasilache return false;
932562f9e99SMatthias Springer // Rank of both values must be the same.
9338f1650cbSNicolas Vasilache if (t1.getRank() != t2.getRank())
9348f1650cbSNicolas Vasilache return false;
935562f9e99SMatthias Springer
936562f9e99SMatthias Springer // All static dimensions must be the same. Mixed cases (e.g., dimension
937562f9e99SMatthias Springer // static in `t1` but dynamic in `t2`) are not supported.
938562f9e99SMatthias Springer for (unsigned i = 0; i < t1.getRank(); ++i) {
939562f9e99SMatthias Springer if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
940562f9e99SMatthias Springer return false;
941562f9e99SMatthias Springer if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
942562f9e99SMatthias Springer return false;
943562f9e99SMatthias Springer }
944562f9e99SMatthias Springer
945562f9e99SMatthias Springer // Nothing more to check if all dimensions are static.
9468f1650cbSNicolas Vasilache if (t1.getNumDynamicDims() == 0)
9478f1650cbSNicolas Vasilache return true;
948562f9e99SMatthias Springer
9499a7d111fSNicolas Vasilache // All dynamic sizes must be the same. The only supported case at the
9509a7d111fSNicolas Vasilache // moment is when `beforePadding` is an ExtractSliceOp (or a cast
9519a7d111fSNicolas Vasilache // thereof).
952562f9e99SMatthias Springer
953060208b4SMatthias Springer // Apart from CastOp, only ExtractSliceOp is supported.
954060208b4SMatthias Springer auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>();
955060208b4SMatthias Springer if (!beforeSlice)
956060208b4SMatthias Springer return false;
957562f9e99SMatthias Springer
958060208b4SMatthias Springer assert(static_cast<size_t>(t1.getRank()) ==
959060208b4SMatthias Springer beforeSlice.getMixedSizes().size());
9608f1650cbSNicolas Vasilache assert(static_cast<size_t>(t2.getRank()) ==
9618f1650cbSNicolas Vasilache afterTrimming.getMixedSizes().size());
962562f9e99SMatthias Springer
963562f9e99SMatthias Springer for (unsigned i = 0; i < t1.getRank(); ++i) {
964562f9e99SMatthias Springer // Skip static dimensions.
9658f1650cbSNicolas Vasilache if (!t1.isDynamicDim(i))
9668f1650cbSNicolas Vasilache continue;
967060208b4SMatthias Springer auto size1 = beforeSlice.getMixedSizes()[i];
968562f9e99SMatthias Springer auto size2 = afterTrimming.getMixedSizes()[i];
969562f9e99SMatthias Springer
970562f9e99SMatthias Springer // Case 1: Same value or same constant int.
9718f1650cbSNicolas Vasilache if (isEqualConstantIntOrValue(size1, size2))
9728f1650cbSNicolas Vasilache continue;
973562f9e99SMatthias Springer
974562f9e99SMatthias Springer // Other cases: Take a deeper look at defining ops of values.
975562f9e99SMatthias Springer auto v1 = size1.dyn_cast<Value>();
976562f9e99SMatthias Springer auto v2 = size2.dyn_cast<Value>();
9778f1650cbSNicolas Vasilache if (!v1 || !v2)
9788f1650cbSNicolas Vasilache return false;
979562f9e99SMatthias Springer
980562f9e99SMatthias Springer // Case 2: Both values are identical AffineMinOps. (Should not happen if
981562f9e99SMatthias Springer // CSE is run.)
982562f9e99SMatthias Springer auto minOp1 = v1.getDefiningOp<AffineMinOp>();
983562f9e99SMatthias Springer auto minOp2 = v2.getDefiningOp<AffineMinOp>();
9848f1650cbSNicolas Vasilache if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
9858f1650cbSNicolas Vasilache minOp1.operands() == minOp2.operands())
9868f1650cbSNicolas Vasilache continue;
987562f9e99SMatthias Springer
988562f9e99SMatthias Springer // Add additional cases as needed.
989562f9e99SMatthias Springer }
990562f9e99SMatthias Springer
991562f9e99SMatthias Springer // All tests passed.
992562f9e99SMatthias Springer return true;
993562f9e99SMatthias Springer }
994562f9e99SMatthias Springer };
995562f9e99SMatthias Springer
996fd0c6f53SAlexander Belyaev /// Rewrite use of tensor::PadOp result in InsertSliceOp. E.g.:
997b1fd8a13SMatthias Springer /// ```
9981ad9b266Slorenzo chelini /// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
999060208b4SMatthias Springer /// %r = tensor.insert_slice %0
1000060208b4SMatthias Springer /// into %dest[%a, %b, 0, 0] [1, 1, 17, 5] [1, 1, 1, 1]
1001b1fd8a13SMatthias Springer /// : tensor<17x5xf32> into tensor<?x?x17x5xf32>
1002b1fd8a13SMatthias Springer /// ```
1003b1fd8a13SMatthias Springer /// is rewritten to:
1004b1fd8a13SMatthias Springer /// ```
1005b1fd8a13SMatthias Springer /// %0 = vector.transfer_read %src[%c0, %c0], %padding
1006b1fd8a13SMatthias Springer /// : tensor<?x?xf32>, vector<17x5xf32>
1007b1fd8a13SMatthias Springer /// %r = vector.transfer_write %0, %dest[%a, %b, %c0, %c0]
1008b1fd8a13SMatthias Springer /// {in_bounds = [true, true]} : vector<17x5xf32>, tensor<?x?x17x5xf32>
1009b1fd8a13SMatthias Springer /// ```
1010b1fd8a13SMatthias Springer ///
1011b1fd8a13SMatthias Springer /// This rewrite is possible if:
1012b1fd8a13SMatthias Springer /// - Low padding is static 0.
1013b1fd8a13SMatthias Springer /// - `padOp` result shape is static.
1014b1fd8a13SMatthias Springer /// - The entire padded tensor is inserted.
1015b1fd8a13SMatthias Springer /// (Implies that sizes of `insertOp` are all static.)
1016b1fd8a13SMatthias Springer /// - Only unit strides in `insertOp`.
1017b1fd8a13SMatthias Springer /// - Single, scalar padding value.
10186859f8edSgysit /// - `padOp` result not used as destination.
1019fd0c6f53SAlexander Belyaev struct PadOpVectorizationWithInsertSlicePattern
1020fd0c6f53SAlexander Belyaev : public VectorizePadOpUserPattern<tensor::InsertSliceOp> {
1021fd0c6f53SAlexander Belyaev using VectorizePadOpUserPattern<
1022fd0c6f53SAlexander Belyaev tensor::InsertSliceOp>::VectorizePadOpUserPattern;
1023b1fd8a13SMatthias Springer
1024fd0c6f53SAlexander Belyaev LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
1025060208b4SMatthias Springer tensor::InsertSliceOp insertOp) const override {
1026b1fd8a13SMatthias Springer // Low padding must be static 0.
10278f1650cbSNicolas Vasilache if (!padOp.hasZeroLowPad())
10288f1650cbSNicolas Vasilache return failure();
1029b1fd8a13SMatthias Springer // Only unit stride supported.
10308f1650cbSNicolas Vasilache if (!insertOp.hasUnitStride())
10318f1650cbSNicolas Vasilache return failure();
1032b1fd8a13SMatthias Springer // Pad value must be a constant.
1033b1fd8a13SMatthias Springer auto padValue = padOp.getConstantPaddingValue();
1034b1fd8a13SMatthias Springer if (!padValue)
1035b1fd8a13SMatthias Springer return failure();
1036b1fd8a13SMatthias Springer // Dynamic shapes not supported.
103704235d07SJacques Pienaar if (!padOp.getResult().getType().cast<ShapedType>().hasStaticShape())
1038b1fd8a13SMatthias Springer return failure();
10396859f8edSgysit // Pad result not used as destination.
104004235d07SJacques Pienaar if (insertOp.getDest() == padOp.getResult())
10416859f8edSgysit return failure();
1042b1fd8a13SMatthias Springer
1043b1fd8a13SMatthias Springer auto vecType = VectorType::get(padOp.getType().getShape(),
1044b1fd8a13SMatthias Springer padOp.getType().getElementType());
1045b1fd8a13SMatthias Springer unsigned vecRank = vecType.getRank();
1046b1fd8a13SMatthias Springer unsigned tensorRank = insertOp.getType().getRank();
1047b1fd8a13SMatthias Springer
1048b1fd8a13SMatthias Springer // Check if sizes match: Insert the entire tensor into most minor dims.
1049b1fd8a13SMatthias Springer // (No permutations allowed.)
1050b1fd8a13SMatthias Springer SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
1051b1fd8a13SMatthias Springer expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
1052b1fd8a13SMatthias Springer if (!llvm::all_of(
10530813700dSMatthias Springer llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
10540813700dSMatthias Springer return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
10550813700dSMatthias Springer }))
1056b1fd8a13SMatthias Springer return failure();
1057b1fd8a13SMatthias Springer
10583f8f2923STobias Gysi // Insert the TransferReadOp and TransferWriteOp at the position of the
10593f8f2923STobias Gysi // InsertSliceOp.
10603f8f2923STobias Gysi rewriter.setInsertionPoint(insertOp);
10613f8f2923STobias Gysi
10629a7d111fSNicolas Vasilache // Generate TransferReadOp: Read entire source tensor and add high
10639a7d111fSNicolas Vasilache // padding.
1064b1fd8a13SMatthias Springer SmallVector<Value> readIndices(
1065a54f4eaeSMogball vecRank, rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
1066b1fd8a13SMatthias Springer auto read = rewriter.create<vector::TransferReadOp>(
106704235d07SJacques Pienaar padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue);
1068b1fd8a13SMatthias Springer
1069060208b4SMatthias Springer // Generate TransferWriteOp: Write to InsertSliceOp's dest tensor at
1070060208b4SMatthias Springer // specified offsets. Write is fully in-bounds because an InsertSliceOp's
1071b1fd8a13SMatthias Springer // source must fit into the destination at the specified offsets.
1072b1fd8a13SMatthias Springer auto writeIndices =
1073b1fd8a13SMatthias Springer ofrToIndexValues(rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
1074b1fd8a13SMatthias Springer SmallVector<bool> inBounds(vecRank, true);
1075b1fd8a13SMatthias Springer rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
107604235d07SJacques Pienaar insertOp, read, insertOp.getDest(), writeIndices,
1077c537a943SNicolas Vasilache ArrayRef<bool>{inBounds});
1078b1fd8a13SMatthias Springer
1079b1fd8a13SMatthias Springer return success();
1080b1fd8a13SMatthias Springer }
1081b1fd8a13SMatthias Springer };
1082b1fd8a13SMatthias Springer
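// Illustrative usage sketch (assumes the standard greedy rewrite driver and a
// hypothetical `funcOp`; not part of this file's logic):
//
//   RewritePatternSet patterns(funcOp.getContext());
//   linalg::populatePadOpVectorizationPatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));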
1083fd0c6f53SAlexander Belyaev void mlir::linalg::populatePadOpVectorizationPatterns(
1084e789efc9SMatthias Springer RewritePatternSet &patterns, PatternBenefit baseBenefit) {
1085fd0c6f53SAlexander Belyaev patterns.add<GenericPadOpVectorizationPattern>(patterns.getContext(),
10868f1650cbSNicolas Vasilache baseBenefit);
1087b1b82271SMatthias Springer // Try these specialized patterns before resorting to the generic one.
1088fd0c6f53SAlexander Belyaev patterns.add<PadOpVectorizationWithTransferReadPattern,
1089fd0c6f53SAlexander Belyaev PadOpVectorizationWithTransferWritePattern,
1090fd0c6f53SAlexander Belyaev PadOpVectorizationWithInsertSlicePattern>(
109198fff515SMatthias Springer patterns.getContext(), baseBenefit.getBenefit() + 1);
1092e789efc9SMatthias Springer }
1093bb69de3fSNicolas Vasilache
1094f245b7adSNicolas Vasilache //----------------------------------------------------------------------------//
1095f245b7adSNicolas Vasilache // Forwarding patterns
1096f245b7adSNicolas Vasilache //----------------------------------------------------------------------------//
1097f245b7adSNicolas Vasilache
10989a7d111fSNicolas Vasilache /// Check whether there is any interleaved use of any `values` between
10999a7d111fSNicolas Vasilache /// `firstOp` and `secondOp`. Conservatively return `true` if any op or value
11009a7d111fSNicolas Vasilache /// is in a different block.
11011ee11432SNicolas Vasilache static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
11021ee11432SNicolas Vasilache ValueRange values) {
11031ee11432SNicolas Vasilache if (firstOp->getBlock() != secondOp->getBlock() ||
11041ee11432SNicolas Vasilache !firstOp->isBeforeInBlock(secondOp)) {
1105753a67b5SNicolas Vasilache LDBG("interleavedUses precondition failed, firstOp: "
11061ee11432SNicolas Vasilache << *firstOp << ", second op: " << *secondOp);
11071ee11432SNicolas Vasilache return true;
11081ee11432SNicolas Vasilache }
11091ee11432SNicolas Vasilache for (auto v : values) {
11101ee11432SNicolas Vasilache for (auto &u : v.getUses()) {
11111ee11432SNicolas Vasilache Operation *owner = u.getOwner();
11121ee11432SNicolas Vasilache if (owner == firstOp || owner == secondOp)
11131ee11432SNicolas Vasilache continue;
11141ee11432SNicolas Vasilache // TODO: this is too conservative, use dominance info in the future.
11151ee11432SNicolas Vasilache if (owner->getBlock() == firstOp->getBlock() &&
11161ee11432SNicolas Vasilache (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
11171ee11432SNicolas Vasilache continue;
1118753a67b5SNicolas Vasilache LDBG(" found interleaved op " << *owner << ", firstOp: " << *firstOp
1119753a67b5SNicolas Vasilache << ", second op: " << *secondOp);
11201ee11432SNicolas Vasilache return true;
11211ee11432SNicolas Vasilache }
11221ee11432SNicolas Vasilache }
11231ee11432SNicolas Vasilache return false;
11241ee11432SNicolas Vasilache }
11251ee11432SNicolas Vasilache
11269a7d111fSNicolas Vasilache /// Return the unique subview use of `v` if it is indeed unique, null
11279a7d111fSNicolas Vasilache /// otherwise.
1128e2310704SJulian Gross static memref::SubViewOp getSubViewUseIfUnique(Value v) {
1129e2310704SJulian Gross memref::SubViewOp subViewOp;
11301ee11432SNicolas Vasilache for (auto &u : v.getUses()) {
1131e2310704SJulian Gross if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
11321ee11432SNicolas Vasilache if (subViewOp)
1133e2310704SJulian Gross return memref::SubViewOp();
11341ee11432SNicolas Vasilache subViewOp = newSubViewOp;
11351ee11432SNicolas Vasilache }
11361ee11432SNicolas Vasilache }
11371ee11432SNicolas Vasilache return subViewOp;
11381ee11432SNicolas Vasilache }
11391ee11432SNicolas Vasilache
11401ee11432SNicolas Vasilache /// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
11411ee11432SNicolas Vasilache /// when available.
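///
/// Illustrative sketch of the forwarding (value names and types are
/// hypothetical):
/// ```
///   %alloc = memref.alloc() : memref<32x32xf32>
///   %view = memref.subview %alloc[0, 0] [%s0, %s1] [1, 1]
///       : memref<32x32xf32> to memref<?x?xf32, ...>
///   linalg.fill ins(%cst : f32) outs(%alloc : memref<32x32xf32>)
///   memref.copy %in, %view : memref<?x?xf32> to memref<?x?xf32, ...>
///   %v = vector.transfer_read %alloc[%c0, %c0], %cst
///       : memref<32x32xf32>, vector<32x32xf32>
/// ```
/// is roughly forwarded to a single read from the copy source:
/// ```
///   %v = vector.transfer_read %in[%c0, %c0], %cst
///       : memref<?x?xf32>, vector<32x32xf32>
/// ```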
11421ee11432SNicolas Vasilache LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
11431ee11432SNicolas Vasilache vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
11441ee11432SNicolas Vasilache
1145c537a943SNicolas Vasilache // TODO: support mask.
11467c38fd60SJacques Pienaar if (xferOp.getMask())
1147c537a943SNicolas Vasilache return failure();
1148c537a943SNicolas Vasilache
11491ee11432SNicolas Vasilache // Transfer into `view`.
11507c38fd60SJacques Pienaar Value viewOrAlloc = xferOp.getSource();
1151e2310704SJulian Gross if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
1152e2310704SJulian Gross !viewOrAlloc.getDefiningOp<memref::AllocOp>())
11531ee11432SNicolas Vasilache return failure();
11541ee11432SNicolas Vasilache
1155753a67b5SNicolas Vasilache LDBG(viewOrAlloc);
11561ee11432SNicolas Vasilache
11571ee11432SNicolas Vasilache // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
1158e2310704SJulian Gross memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
11591ee11432SNicolas Vasilache if (!subViewOp)
11601ee11432SNicolas Vasilache return failure();
11611ee11432SNicolas Vasilache Value subView = subViewOp.getResult();
1162753a67b5SNicolas Vasilache LDBG("with subView " << subView);
11631ee11432SNicolas Vasilache
11641ee11432SNicolas Vasilache // Find the copy into `subView` without interleaved uses.
1165ebc81537SAlexander Belyaev memref::CopyOp copyOp;
11661ee11432SNicolas Vasilache for (auto &u : subView.getUses()) {
1167ebc81537SAlexander Belyaev if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
1168136d746eSJacques Pienaar assert(newCopyOp.getTarget().getType().isa<MemRefType>());
1169136d746eSJacques Pienaar if (newCopyOp.getTarget() != subView)
11701ee11432SNicolas Vasilache continue;
1171753a67b5SNicolas Vasilache LDBG("copy candidate " << *newCopyOp);
11721ee11432SNicolas Vasilache if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
11731ee11432SNicolas Vasilache continue;
11741ee11432SNicolas Vasilache copyOp = newCopyOp;
11751ee11432SNicolas Vasilache break;
11761ee11432SNicolas Vasilache }
11771ee11432SNicolas Vasilache }
11781ee11432SNicolas Vasilache if (!copyOp)
11791ee11432SNicolas Vasilache return failure();
1180753a67b5SNicolas Vasilache LDBG("with copy " << *copyOp);
11811ee11432SNicolas Vasilache
11829a7d111fSNicolas Vasilache // Find the fill into `viewOrAlloc` without interleaved uses before the
11839a7d111fSNicolas Vasilache // copy.
11841ee11432SNicolas Vasilache FillOp maybeFillOp;
11851ee11432SNicolas Vasilache for (auto &u : viewOrAlloc.getUses()) {
11861ee11432SNicolas Vasilache if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
1187912ebf60STobias Gysi assert(newFillOp.output().getType().isa<MemRefType>());
1188912ebf60STobias Gysi if (newFillOp.output() != viewOrAlloc)
11891ee11432SNicolas Vasilache continue;
1190753a67b5SNicolas Vasilache LDBG("fill candidate " << *newFillOp);
11911ee11432SNicolas Vasilache if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
11921ee11432SNicolas Vasilache continue;
11931ee11432SNicolas Vasilache maybeFillOp = newFillOp;
11941ee11432SNicolas Vasilache break;
11951ee11432SNicolas Vasilache }
11961ee11432SNicolas Vasilache }
11971ee11432SNicolas Vasilache // Ensure padding matches.
11987c38fd60SJacques Pienaar if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())
11991ee11432SNicolas Vasilache return failure();
12001ee11432SNicolas Vasilache if (maybeFillOp)
1201753a67b5SNicolas Vasilache LDBG("with maybeFillOp " << *maybeFillOp);
12021ee11432SNicolas Vasilache
1203ebc81537SAlexander Belyaev // `in` is the subview that memref.copy reads. Replace it.
1204136d746eSJacques Pienaar Value in = copyOp.getSource();
12051ee11432SNicolas Vasilache
1206ebc81537SAlexander Belyaev // memref.copy + linalg.fill can be used to create a padded local buffer.
1207512da70bSNicolas Vasilache // Any in-bounds information on the transfer is only valid for this padded
1208512da70bSNicolas Vasilache // buffer. When forwarding to vector.transfer_read, the `in_bounds`
1209512da70bSNicolas Vasilache // attribute must be reset conservatively.
12101ee11432SNicolas Vasilache Value res = rewriter.create<vector::TransferReadOp>(
12117c38fd60SJacques Pienaar xferOp.getLoc(), xferOp.getVectorType(), in, xferOp.getIndices(),
12127c38fd60SJacques Pienaar xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
1213c537a943SNicolas Vasilache // in_bounds is explicitly reset
1214c537a943SNicolas Vasilache /*inBoundsAttr=*/ArrayAttr());
12151ee11432SNicolas Vasilache
12161ee11432SNicolas Vasilache if (maybeFillOp)
12171ee11432SNicolas Vasilache rewriter.eraseOp(maybeFillOp);
12181ee11432SNicolas Vasilache rewriter.eraseOp(copyOp);
12191ee11432SNicolas Vasilache rewriter.replaceOp(xferOp, res);
12201ee11432SNicolas Vasilache
12211ee11432SNicolas Vasilache return success();
12221ee11432SNicolas Vasilache }
12231ee11432SNicolas Vasilache
12241ee11432SNicolas Vasilache /// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
12251ee11432SNicolas Vasilache /// when available.
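///
/// Illustrative sketch of the forwarding (value names and types are
/// hypothetical):
/// ```
///   %alloc = memref.alloc() : memref<32x32xf32>
///   %view = memref.subview %alloc[0, 0] [%s0, %s1] [1, 1]
///       : memref<32x32xf32> to memref<?x?xf32, ...>
///   vector.transfer_write %vec, %alloc[%c0, %c0]
///       : vector<32x32xf32>, memref<32x32xf32>
///   memref.copy %view, %out : memref<?x?xf32, ...> to memref<?x?xf32>
/// ```
/// is roughly forwarded to a single write into the copy target:
/// ```
///   vector.transfer_write %vec, %out[%c0, %c0]
///       : vector<32x32xf32>, memref<?x?xf32>
/// ```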
12261ee11432SNicolas Vasilache LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
12271ee11432SNicolas Vasilache vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
1228c537a943SNicolas Vasilache // TODO: support mask.
12297c38fd60SJacques Pienaar if (xferOp.getMask())
1230c537a943SNicolas Vasilache return failure();
1231c537a943SNicolas Vasilache
12321ee11432SNicolas Vasilache // Transfer into `viewOrAlloc`.
12337c38fd60SJacques Pienaar Value viewOrAlloc = xferOp.getSource();
1234e2310704SJulian Gross if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
1235e2310704SJulian Gross !viewOrAlloc.getDefiningOp<memref::AllocOp>())
12361ee11432SNicolas Vasilache return failure();
12371ee11432SNicolas Vasilache
12381ee11432SNicolas Vasilache // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
1239e2310704SJulian Gross memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
12401ee11432SNicolas Vasilache if (!subViewOp)
12411ee11432SNicolas Vasilache return failure();
12421ee11432SNicolas Vasilache Value subView = subViewOp.getResult();
12431ee11432SNicolas Vasilache
12441ee11432SNicolas Vasilache // Find the copy from `subView` without interleaved uses.
1245ebc81537SAlexander Belyaev memref::CopyOp copyOp;
12461ee11432SNicolas Vasilache for (auto &u : subViewOp.getResult().getUses()) {
1247ebc81537SAlexander Belyaev if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
1248136d746eSJacques Pienaar if (newCopyOp.getSource() != subView)
12491ee11432SNicolas Vasilache continue;
12501ee11432SNicolas Vasilache if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
12511ee11432SNicolas Vasilache continue;
12521ee11432SNicolas Vasilache copyOp = newCopyOp;
12531ee11432SNicolas Vasilache break;
12541ee11432SNicolas Vasilache }
12551ee11432SNicolas Vasilache }
12561ee11432SNicolas Vasilache if (!copyOp)
12571ee11432SNicolas Vasilache return failure();
12581ee11432SNicolas Vasilache
12591ee11432SNicolas Vasilache // `out` is the subview copied into that we replace.
1260136d746eSJacques Pienaar assert(copyOp.getTarget().getType().isa<MemRefType>());
1261136d746eSJacques Pienaar Value out = copyOp.getTarget();
12621ee11432SNicolas Vasilache
12631ee11432SNicolas Vasilache // Forward vector.transfer into copy.
1264ebc81537SAlexander Belyaev // memref.copy + linalg.fill can be used to create a padded local buffer.
1265512da70bSNicolas Vasilache // Any in-bounds information on the transfer is only valid for this padded
1266512da70bSNicolas Vasilache // buffer. When forwarding to vector.transfer_write, the `in_bounds`
1267512da70bSNicolas Vasilache // attribute must be reset conservatively.
12681ee11432SNicolas Vasilache rewriter.create<vector::TransferWriteOp>(
12697c38fd60SJacques Pienaar xferOp.getLoc(), xferOp.getVector(), out, xferOp.getIndices(),
12707c38fd60SJacques Pienaar xferOp.getPermutationMapAttr(), xferOp.getMask(),
1271c537a943SNicolas Vasilache // in_bounds is explicitly reset
1272c537a943SNicolas Vasilache /*inBoundsAttr=*/ArrayAttr());
12731ee11432SNicolas Vasilache
12741ee11432SNicolas Vasilache rewriter.eraseOp(copyOp);
12751ee11432SNicolas Vasilache rewriter.eraseOp(xferOp);
12761ee11432SNicolas Vasilache
12771ee11432SNicolas Vasilache return success();
12781ee11432SNicolas Vasilache }
12796bb7d247SNicolas Vasilache
12806bb7d247SNicolas Vasilache //===----------------------------------------------------------------------===//
12816bb7d247SNicolas Vasilache // Convolution vectorization patterns
12826bb7d247SNicolas Vasilache //===----------------------------------------------------------------------===//
1283392e16c2SNicolas Vasilache
1284392e16c2SNicolas Vasilache template <int N>
1285392e16c2SNicolas Vasilache static void bindShapeDims(ShapedType shapedType) {}
1286392e16c2SNicolas Vasilache
1287392e16c2SNicolas Vasilache template <int N, typename IntTy, typename... IntTy2>
1288392e16c2SNicolas Vasilache static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) {
1289392e16c2SNicolas Vasilache val = shapedType.getShape()[N];
1290392e16c2SNicolas Vasilache bindShapeDims<N + 1, IntTy2 &...>(shapedType, vals...);
1291392e16c2SNicolas Vasilache }
1292392e16c2SNicolas Vasilache
1293392e16c2SNicolas Vasilache /// Bind a pack of int& to the leading dimensions of shapedType.getShape().
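/// E.g. (illustrative), with a `memref<4x16x3xf32>`-typed `shapedType`:
///   int64_t n, w;
///   bindShapeDims(shapedType, n, w);
/// binds n = 4 and w = 16.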
1294392e16c2SNicolas Vasilache template <typename... IntTy>
1295392e16c2SNicolas Vasilache static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
1296392e16c2SNicolas Vasilache bindShapeDims<0>(shapedType, vals...);
1297392e16c2SNicolas Vasilache }
1298392e16c2SNicolas Vasilache
12996bb7d247SNicolas Vasilache namespace {
130099ff697bSNicolas Vasilache /// Generate a vector implementation for either:
13016bb7d247SNicolas Vasilache /// ```
13026bb7d247SNicolas Vasilache /// Op def: ( n, w, c, kw, f )
13036bb7d247SNicolas Vasilache /// Iters: ({Par(), Par(), Par(), Red(), Red()})
13046bb7d247SNicolas Vasilache /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
13056bb7d247SNicolas Vasilache /// ```
13067b09f157SNicolas Vasilache /// kw is unrolled, w is unrolled iff strideW > 1.
130799ff697bSNicolas Vasilache ///
130899ff697bSNicolas Vasilache /// or
130999ff697bSNicolas Vasilache ///
131099ff697bSNicolas Vasilache /// ```
131199ff697bSNicolas Vasilache /// Op def: ( n, w, c, kw )
131299ff697bSNicolas Vasilache /// Iters: ({Par(), Par(), Par(), Red()})
131399ff697bSNicolas Vasilache /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
131499ff697bSNicolas Vasilache /// ```
131599ff697bSNicolas Vasilache /// kw is unrolled, w is unrolled iff strideW > 1.
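///
/// As an illustrative example (one possible match among the linalg named
/// convolution ops), this generator targets ops such as:
/// ```
///   %res = linalg.conv_1d_nwc_wcf
///     {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
///     ins(%input, %filter : tensor<4x16x3xf32>, tensor<3x3x8xf32>)
///     outs(%init : tensor<4x14x8xf32>) -> tensor<4x14x8xf32>
/// ```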
131602b6fb21SMehdi Amini struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
131702b6fb21SMehdi Amini Conv1DNwcGenerator(OpBuilder &builder, LinalgOp linalgOp, int strideW,
13186bb7d247SNicolas Vasilache int dilationW)
1319671e30a1SMehdi Amini : StructuredGenerator<LinalgOp>(builder, linalgOp), strideW(strideW),
1320671e30a1SMehdi Amini dilationW(dilationW) {
13216bb7d247SNicolas Vasilache // Determine whether `linalgOp` can be handled by this generator.
13226bb7d247SNicolas Vasilache if (linalgOp.getNumInputs() != 2 || linalgOp.getNumOutputs() != 1)
13236bb7d247SNicolas Vasilache return;
13246bb7d247SNicolas Vasilache lhsShaped = linalgOp.inputs()[0];
13256bb7d247SNicolas Vasilache rhsShaped = linalgOp.inputs()[1];
13266bb7d247SNicolas Vasilache resShaped = linalgOp.outputs()[0];
13276bb7d247SNicolas Vasilache lhsShapedType = lhsShaped.getType().dyn_cast<ShapedType>();
13286bb7d247SNicolas Vasilache rhsShapedType = rhsShaped.getType().dyn_cast<ShapedType>();
13296bb7d247SNicolas Vasilache resShapedType = resShaped.getType().dyn_cast<ShapedType>();
13306bb7d247SNicolas Vasilache if (!lhsShapedType || !rhsShapedType || !resShapedType)
13316bb7d247SNicolas Vasilache return;
133299ff697bSNicolas Vasilache if (lhsShapedType.getRank() != 3 ||
133399ff697bSNicolas Vasilache (rhsShapedType.getRank() != 2 && rhsShapedType.getRank() != 3) ||
13346bb7d247SNicolas Vasilache resShapedType.getRank() != 3)
13356bb7d247SNicolas Vasilache return;
13366bb7d247SNicolas Vasilache
13376bb7d247SNicolas Vasilache // Check for reduction `add` preceded by `mul`.
13386bb7d247SNicolas Vasilache Operation *reduceOp = matchLinalgReduction(linalgOp.getOutputOperand(0));
13396bb7d247SNicolas Vasilache if (!reduceOp)
13406bb7d247SNicolas Vasilache return;
13416bb7d247SNicolas Vasilache llvm::Optional<vector::CombiningKind> maybeKind;
1342436d17a8SAlexander Belyaev maybeKind = getCombinerOpKind(reduceOp);
13436bb7d247SNicolas Vasilache if (!maybeKind || *maybeKind != vector::CombiningKind::ADD)
13446bb7d247SNicolas Vasilache return;
1345046ebeb6SThomas Raoux // Check for single `mul` predecessor. The `mul` operands must be block
1346046ebeb6SThomas Raoux // arguments or extension of block arguments.
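    // E.g. (illustrative) a region body of the form
    //   %m = arith.mulf %a, %b : f32   // possibly with arith.extf on %a/%b
    //   %r = arith.addf %acc, %m : f32
    // where %a, %b and %acc are block arguments, is accepted.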
1347046ebeb6SThomas Raoux Operation *mulOp = nullptr;
1348046ebeb6SThomas Raoux for (Value operand : reduceOp->getOperands()) {
1349046ebeb6SThomas Raoux if (operand.isa<BlockArgument>())
1350046ebeb6SThomas Raoux continue;
1351046ebeb6SThomas Raoux if (mulOp)
13526bb7d247SNicolas Vasilache return;
1353046ebeb6SThomas Raoux mulOp = operand.getDefiningOp();
1354046ebeb6SThomas Raoux if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
1355046ebeb6SThomas Raoux return;
1356046ebeb6SThomas Raoux }
1357046ebeb6SThomas Raoux if (!mulOp)
1358046ebeb6SThomas Raoux return;
1359046ebeb6SThomas Raoux for (Value operand : mulOp->getOperands()) {
1360046ebeb6SThomas Raoux if (Operation *def = operand.getDefiningOp()) {
1361046ebeb6SThomas Raoux if (!isa<arith::ExtFOp>(def))
1362046ebeb6SThomas Raoux return;
1363046ebeb6SThomas Raoux operand = def->getOperand(0);
1364046ebeb6SThomas Raoux }
1365046ebeb6SThomas Raoux if (!operand.isa<BlockArgument>())
1366046ebeb6SThomas Raoux return;
1367046ebeb6SThomas Raoux }
13686bb7d247SNicolas Vasilache // The op is now known to be valid.
13696bb7d247SNicolas Vasilache valid = true;
13706bb7d247SNicolas Vasilache }
13716bb7d247SNicolas Vasilache
13726bb7d247SNicolas Vasilache /// Generate a vector implementation for:
13736bb7d247SNicolas Vasilache /// ```
13746bb7d247SNicolas Vasilache /// Op def: ( n, w, c, kw, f )
13756bb7d247SNicolas Vasilache /// Iters: ({Par(), Par(), Par(), Red(), Red()})
13766bb7d247SNicolas Vasilache /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
13776bb7d247SNicolas Vasilache /// ```
1378203accf0SNicolas Vasilache /// kw is always unrolled.
13799a7d111fSNicolas Vasilache /// TODO: w (resp. kw) is unrolled when strideW (resp. dilationW) is
13809a7d111fSNicolas Vasilache /// > 1.
13816bb7d247SNicolas Vasilache FailureOr<Operation *> conv() {
13826bb7d247SNicolas Vasilache if (!valid)
13836bb7d247SNicolas Vasilache return failure();
13846bb7d247SNicolas Vasilache
1385392e16c2SNicolas Vasilache int64_t nSize, wSize, cSize, kwSize, fSize;
1386392e16c2SNicolas Vasilache // kernel{kw, c, f}
1387392e16c2SNicolas Vasilache bindShapeDims(rhsShapedType, kwSize, cSize, fSize);
1388392e16c2SNicolas Vasilache // out{n, w, f}
1389392e16c2SNicolas Vasilache bindShapeDims(resShapedType, nSize, wSize);
13906bb7d247SNicolas Vasilache
13916bb7d247SNicolas Vasilache vector::TransferWriteOp write;
13926bb7d247SNicolas Vasilache Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
13936bb7d247SNicolas Vasilache
13947b09f157SNicolas Vasilache // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
13959a7d111fSNicolas Vasilache // When strideW == 1, we can batch the contiguous loads and avoid
13969a7d111fSNicolas Vasilache // unrolling.
1397203accf0SNicolas Vasilache int64_t wSizeStep = strideW == 1 ? wSize : 1;
1398203accf0SNicolas Vasilache
13999c497174SNicolas Vasilache Type lhsEltType = lhsShapedType.getElementType();
14009c497174SNicolas Vasilache Type rhsEltType = rhsShapedType.getElementType();
14019c497174SNicolas Vasilache Type resEltType = resShapedType.getElementType();
14029c497174SNicolas Vasilache VectorType lhsType = VectorType::get(
1403f1c86b83SNicolas Vasilache {nSize,
1404f1c86b83SNicolas Vasilache // iw = (ow - 1) * sw + (kw - 1) * dw + 1
1405f1c86b83SNicolas Vasilache // (e.g. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
1406641fe707SNicolas Vasilache // Perform the proper inclusive -> exclusive -> inclusive.
1407f1c86b83SNicolas Vasilache ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
14089c497174SNicolas Vasilache cSize},
14099c497174SNicolas Vasilache lhsEltType);
14109c497174SNicolas Vasilache VectorType rhsType = VectorType::get({kwSize, cSize, fSize}, rhsEltType);
14119c497174SNicolas Vasilache VectorType resType = VectorType::get({nSize, wSize, fSize}, resEltType);
14127b09f157SNicolas Vasilache
14139a7d111fSNicolas Vasilache // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
14149a7d111fSNicolas Vasilache // 0].
14159c497174SNicolas Vasilache Value lhs = builder.create<vector::TransferReadOp>(
14169c497174SNicolas Vasilache loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
14179c497174SNicolas Vasilache // Read rhs slice of size {kw, c, f} @ [0, 0, 0].
14189c497174SNicolas Vasilache Value rhs = builder.create<vector::TransferReadOp>(
14199c497174SNicolas Vasilache loc, rhsType, rhsShaped, ValueRange{zero, zero, zero});
14209c497174SNicolas Vasilache // Read res slice of size {n, w, f} @ [0, 0, 0].
14219c497174SNicolas Vasilache Value res = builder.create<vector::TransferReadOp>(
14229c497174SNicolas Vasilache loc, resType, resShaped, ValueRange{zero, zero, zero});
14239c497174SNicolas Vasilache
14249c497174SNicolas Vasilache //===------------------------------------------------------------------===//
14259c497174SNicolas Vasilache // Begin vector-only rewrite part
14269c497174SNicolas Vasilache //===------------------------------------------------------------------===//
14276bb7d247SNicolas Vasilache // Unroll along kw and read slices of lhs and rhs.
14289c497174SNicolas Vasilache SmallVector<Value> lhsVals, rhsVals, resVals;
1429392e16c2SNicolas Vasilache // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0].
14306bb7d247SNicolas Vasilache for (int64_t kw = 0; kw < kwSize; ++kw) {
14319c497174SNicolas Vasilache for (int64_t w = 0; w < wSize; w += wSizeStep) {
14329c497174SNicolas Vasilache lhsVals.push_back(builder.create<vector::ExtractStridedSliceOp>(
14339c497174SNicolas Vasilache loc, lhs,
14349c497174SNicolas Vasilache /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
14359c497174SNicolas Vasilache /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
14369c497174SNicolas Vasilache /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
1437392e16c2SNicolas Vasilache }
1438392e16c2SNicolas Vasilache }
1439392e16c2SNicolas Vasilache // Extract rhs slice of size {c, f} @ [kw].
1440392e16c2SNicolas Vasilache for (int64_t kw = 0; kw < kwSize; ++kw) {
1441392e16c2SNicolas Vasilache rhsVals.push_back(builder.create<vector::ExtractOp>(
1442392e16c2SNicolas Vasilache loc, rhs, /*offsets=*/ArrayRef<int64_t>{kw}));
1443392e16c2SNicolas Vasilache }
14449c497174SNicolas Vasilache // Extract res slice: {n, wSizeStep, f} @ [0, w, 0].
1445392e16c2SNicolas Vasilache for (int64_t w = 0; w < wSize; w += wSizeStep) {
14469c497174SNicolas Vasilache resVals.push_back(builder.create<vector::ExtractStridedSliceOp>(
14479c497174SNicolas Vasilache loc, res,
14489c497174SNicolas Vasilache /*offsets=*/ArrayRef<int64_t>{0, w, 0},
14499c497174SNicolas Vasilache /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, fSize},
14509c497174SNicolas Vasilache /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
14517b09f157SNicolas Vasilache }
14529c497174SNicolas Vasilache
14539c497174SNicolas Vasilache auto linearIndex = [&](int64_t kw, int64_t w) {
14549c497174SNicolas Vasilache return kw * (wSize / wSizeStep) + w;
14559c497174SNicolas Vasilache };
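    // E.g. with kwSize = 3, wSize = 4 and wSizeStep = 1 (strideW > 1), the
    // lhs slice extracted for (kw = 1, w = 2) lives at index 1 * 4 + 2 = 6;
    // with wSizeStep == wSize (strideW == 1) the index degenerates to kw.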
14569c497174SNicolas Vasilache
14579c497174SNicolas Vasilache // Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f}
14587b09f157SNicolas Vasilache for (int64_t kw = 0; kw < kwSize; ++kw) {
14599c497174SNicolas Vasilache for (int64_t w = 0; w < wSize; w += wSizeStep) {
14609c497174SNicolas Vasilache resVals[w] = conv1dSliceAsContraction(
14619c497174SNicolas Vasilache builder, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw], resVals[w]);
14627b09f157SNicolas Vasilache }
14637b09f157SNicolas Vasilache }
14649c497174SNicolas Vasilache
1465203accf0SNicolas Vasilache // Write back res slice: {n, wSizeStep, f} @ [0, w, 0].
14669c497174SNicolas Vasilache // This does not depend on kw.
14679c497174SNicolas Vasilache for (int64_t w = 0; w < wSize; w += wSizeStep) {
14689c497174SNicolas Vasilache res = builder.create<vector::InsertStridedSliceOp>(
14699c497174SNicolas Vasilache loc, resVals[w], res,
14709c497174SNicolas Vasilache /*offsets=*/ArrayRef<int64_t>{0, w, 0},
14719c497174SNicolas Vasilache /*strides=*/ArrayRef<int64_t>{1, 1, 1});
14726bb7d247SNicolas Vasilache }
14739c497174SNicolas Vasilache //===------------------------------------------------------------------===//
14749c497174SNicolas Vasilache // End vector-only rewrite part
14759c497174SNicolas Vasilache //===------------------------------------------------------------------===//
14766bb7d247SNicolas Vasilache
14779c497174SNicolas Vasilache // Write back res slice of size {n, w, f} @ [0, 0, 0].
14789c497174SNicolas Vasilache return builder
14799c497174SNicolas Vasilache .create<vector::TransferWriteOp>(loc, res, resShaped,
14809c497174SNicolas Vasilache ValueRange{zero, zero, zero})
14819c497174SNicolas Vasilache .getOperation();
14826bb7d247SNicolas Vasilache }
14836bb7d247SNicolas Vasilache
14847b09f157SNicolas Vasilache // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
1485641fe707SNicolas Vasilache Value conv1dSliceAsContraction(OpBuilder &b, Location loc, Value lhs,
1486641fe707SNicolas Vasilache Value rhs, Value res) {
14877b09f157SNicolas Vasilache StringRef par = Par().strRef, red = Red().strRef;
14887b09f157SNicolas Vasilache AffineExpr n, w, f, c;
14897b09f157SNicolas Vasilache bindDims(ctx, n, w, f, c);
14907b09f157SNicolas Vasilache return builder.create<vector::ContractionOp>(
14917b09f157SNicolas Vasilache loc, lhs, rhs, res,
14927b09f157SNicolas Vasilache /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
14937b09f157SNicolas Vasilache /*iteratorTypes=*/ArrayRef<StringRef>{par, par, par, red});
14947b09f157SNicolas Vasilache }
14957b09f157SNicolas Vasilache
149699ff697bSNicolas Vasilache /// Generate a vector implementation for:
149799ff697bSNicolas Vasilache /// ```
149899ff697bSNicolas Vasilache /// Op def: ( n, w, c, kw)
149999ff697bSNicolas Vasilache /// Iters: ({Par(), Par(), Par(), Red()})
150099ff697bSNicolas Vasilache /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
150199ff697bSNicolas Vasilache /// ```
150299ff697bSNicolas Vasilache /// kw is always unrolled.
15039a7d111fSNicolas Vasilache /// TODO: w (resp. kw) is unrolled when strideW (resp. dilationW) is
15049a7d111fSNicolas Vasilache /// > 1.
1505392e16c2SNicolas Vasilache FailureOr<Operation *> depthwiseConv() {
150699ff697bSNicolas Vasilache if (!valid)
150799ff697bSNicolas Vasilache return failure();
150899ff697bSNicolas Vasilache
1509392e16c2SNicolas Vasilache int64_t nSize, wSize, cSize, kwSize;
1510392e16c2SNicolas Vasilache // kernel{kw, c}
1511392e16c2SNicolas Vasilache bindShapeDims(rhsShapedType, kwSize, cSize);
1512392e16c2SNicolas Vasilache // out{n, w, c}
1513392e16c2SNicolas Vasilache bindShapeDims(resShapedType, nSize, wSize);
151499ff697bSNicolas Vasilache
151599ff697bSNicolas Vasilache vector::TransferWriteOp write;
151699ff697bSNicolas Vasilache Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
151799ff697bSNicolas Vasilache
151899ff697bSNicolas Vasilache // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
15199a7d111fSNicolas Vasilache // When strideW == 1, we can batch the contiguous loads and avoid
15209a7d111fSNicolas Vasilache // unrolling.
152199ff697bSNicolas Vasilache int64_t wSizeStep = strideW == 1 ? wSize : 1;
152299ff697bSNicolas Vasilache
152399ff697bSNicolas Vasilache Type lhsEltType = lhsShapedType.getElementType();
152499ff697bSNicolas Vasilache Type rhsEltType = rhsShapedType.getElementType();
152599ff697bSNicolas Vasilache Type resEltType = resShapedType.getElementType();
152699ff697bSNicolas Vasilache VectorType lhsType = VectorType::get(
1527641fe707SNicolas Vasilache {nSize,
1528641fe707SNicolas Vasilache // iw = (ow - 1) * sw + (kw - 1) * dw + 1
1529641fe707SNicolas Vasilache // (e.g. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
1530641fe707SNicolas Vasilache ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
153199ff697bSNicolas Vasilache cSize},
153299ff697bSNicolas Vasilache lhsEltType);
153399ff697bSNicolas Vasilache VectorType rhsType = VectorType::get({kwSize, cSize}, rhsEltType);
153499ff697bSNicolas Vasilache VectorType resType = VectorType::get({nSize, wSize, cSize}, resEltType);
153599ff697bSNicolas Vasilache
15369a7d111fSNicolas Vasilache // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
15379a7d111fSNicolas Vasilache // 0].
153899ff697bSNicolas Vasilache Value lhs = builder.create<vector::TransferReadOp>(
153999ff697bSNicolas Vasilache loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
154099ff697bSNicolas Vasilache // Read rhs slice of size {kw, c} @ [0, 0].
154199ff697bSNicolas Vasilache Value rhs = builder.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
154299ff697bSNicolas Vasilache ValueRange{zero, zero});
154399ff697bSNicolas Vasilache // Read res slice of size {n, w, c} @ [0, 0, 0].
154499ff697bSNicolas Vasilache Value res = builder.create<vector::TransferReadOp>(
154599ff697bSNicolas Vasilache loc, resType, resShaped, ValueRange{zero, zero, zero});
154699ff697bSNicolas Vasilache
154799ff697bSNicolas Vasilache //===------------------------------------------------------------------===//
154899ff697bSNicolas Vasilache // Begin vector-only rewrite part
154999ff697bSNicolas Vasilache //===------------------------------------------------------------------===//
155099ff697bSNicolas Vasilache // Unroll along kw and read slices of lhs and rhs.
155199ff697bSNicolas Vasilache SmallVector<Value> lhsVals, rhsVals, resVals;
155299ff697bSNicolas Vasilache // Extract lhs slice of size {n, wSizeStep, c}
155399ff697bSNicolas Vasilache // @ [0, sw * w + dw * kw, 0].
1554392e16c2SNicolas Vasilache for (int64_t kw = 0; kw < kwSize; ++kw) {
1555392e16c2SNicolas Vasilache for (int64_t w = 0; w < wSize; w += wSizeStep) {
155699ff697bSNicolas Vasilache lhsVals.push_back(builder.create<vector::ExtractStridedSliceOp>(
155799ff697bSNicolas Vasilache loc, lhs,
155899ff697bSNicolas Vasilache /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
155999ff697bSNicolas Vasilache /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
156099ff697bSNicolas Vasilache /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
1561392e16c2SNicolas Vasilache }
1562392e16c2SNicolas Vasilache }
1563392e16c2SNicolas Vasilache // Extract rhs slice of size {c} @ [kw].
1564392e16c2SNicolas Vasilache for (int64_t kw = 0; kw < kwSize; ++kw) {
1565392e16c2SNicolas Vasilache rhsVals.push_back(builder.create<vector::ExtractOp>(
1566392e16c2SNicolas Vasilache loc, rhs, /*offsets=*/ArrayRef<int64_t>{kw}));
1567392e16c2SNicolas Vasilache }
156899ff697bSNicolas Vasilache // Extract res slice: {n, wSizeStep, c} @ [0, w, 0].
1569392e16c2SNicolas Vasilache for (int64_t w = 0; w < wSize; w += wSizeStep) {
157099ff697bSNicolas Vasilache resVals.push_back(builder.create<vector::ExtractStridedSliceOp>(
157199ff697bSNicolas Vasilache loc, res,
157299ff697bSNicolas Vasilache /*offsets=*/ArrayRef<int64_t>{0, w, 0},
157399ff697bSNicolas Vasilache /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
157499ff697bSNicolas Vasilache /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
157599ff697bSNicolas Vasilache }
157699ff697bSNicolas Vasilache
157799ff697bSNicolas Vasilache auto linearIndex = [&](int64_t kw, int64_t w) {
157899ff697bSNicolas Vasilache return kw * (wSize / wSizeStep) + w;
157999ff697bSNicolas Vasilache };
158099ff697bSNicolas Vasilache
158199ff697bSNicolas Vasilache // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
158299ff697bSNicolas Vasilache for (int64_t kw = 0; kw < kwSize; ++kw) {
158399ff697bSNicolas Vasilache for (int64_t w = 0; w < wSize; w += wSizeStep) {
1584392e16c2SNicolas Vasilache resVals[w] = depthwiseConv1dSliceAsFma(
158599ff697bSNicolas Vasilache builder, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw], resVals[w]);
158699ff697bSNicolas Vasilache }
158799ff697bSNicolas Vasilache }
158899ff697bSNicolas Vasilache
158999ff697bSNicolas Vasilache // Write back res slice: {n, wSizeStep, c} @ [0, w, 0].
159099ff697bSNicolas Vasilache // This does not depend on kw.
159199ff697bSNicolas Vasilache for (int64_t w = 0; w < wSize; w += wSizeStep) {
159299ff697bSNicolas Vasilache res = builder.create<vector::InsertStridedSliceOp>(
159399ff697bSNicolas Vasilache loc, resVals[w], res,
159499ff697bSNicolas Vasilache /*offsets=*/ArrayRef<int64_t>{0, w, 0},
159599ff697bSNicolas Vasilache /*strides=*/ArrayRef<int64_t>{1, 1, 1});
159699ff697bSNicolas Vasilache }
159799ff697bSNicolas Vasilache //===------------------------------------------------------------------===//
159899ff697bSNicolas Vasilache // End vector-only rewrite part
159999ff697bSNicolas Vasilache //===------------------------------------------------------------------===//
160099ff697bSNicolas Vasilache
160199ff697bSNicolas Vasilache // Write back res slice of size {n, w, c} @ [0, 0, 0].
160299ff697bSNicolas Vasilache return builder
160399ff697bSNicolas Vasilache .create<vector::TransferWriteOp>(loc, res, resShaped,
160499ff697bSNicolas Vasilache ValueRange{zero, zero, zero})
160599ff697bSNicolas Vasilache .getOperation();
160699ff697bSNicolas Vasilache }
160799ff697bSNicolas Vasilache
1608641fe707SNicolas Vasilache /// Lower lhs{n, w, c} * rhs{c} -> res{n, w, c} to fma.
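  /// Illustratively (the vector types are examples), this emits:
  /// ```
  ///   %b = vector.broadcast %rhs : vector<3xf32> to vector<4x14x3xf32>
  ///   %r = vector.fma %lhs, %b, %res : vector<4x14x3xf32>
  /// ```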
1609392e16c2SNicolas Vasilache Value depthwiseConv1dSliceAsFma(OpBuilder &b, Location loc, Value lhs,
161099ff697bSNicolas Vasilache Value rhs, Value res) {
1611641fe707SNicolas Vasilache Value bcast = builder.create<vector::BroadcastOp>(loc, res.getType(), rhs);
1612641fe707SNicolas Vasilache return b.create<vector::FMAOp>(loc, lhs, bcast, res);
161399ff697bSNicolas Vasilache }
161499ff697bSNicolas Vasilache
16156bb7d247SNicolas Vasilache /// Entry point that transposes into the common form:
16166bb7d247SNicolas Vasilache /// {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
16176bb7d247SNicolas Vasilache FailureOr<Operation *> generateConv() {
16186bb7d247SNicolas Vasilache AffineExpr n, w, f, kw, c;
16196bb7d247SNicolas Vasilache bindDims(ctx, n, w, f, kw, c);
16206bb7d247SNicolas Vasilache if (!iters({Par(), Par(), Par(), Red(), Red()}))
16216bb7d247SNicolas Vasilache return failure();
16226bb7d247SNicolas Vasilache
16236bb7d247SNicolas Vasilache // No transposition needed.
16246bb7d247SNicolas Vasilache if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
16256bb7d247SNicolas Vasilache /*rhsIndex*/ {kw, c, f},
16266bb7d247SNicolas Vasilache /*resIndex*/ {n, w, f}}))
16276bb7d247SNicolas Vasilache return conv();
16286bb7d247SNicolas Vasilache return failure();
16296bb7d247SNicolas Vasilache }
16306bb7d247SNicolas Vasilache
163199ff697bSNicolas Vasilache /// Entry point that transposes into the common form:
163299ff697bSNicolas Vasilache /// {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
163399ff697bSNicolas Vasilache FailureOr<Operation *> generateDilatedConv() {
163499ff697bSNicolas Vasilache AffineExpr n, w, c, kw;
163599ff697bSNicolas Vasilache bindDims(ctx, n, w, c, kw);
163699ff697bSNicolas Vasilache if (!iters({Par(), Par(), Par(), Red()}))
163799ff697bSNicolas Vasilache return failure();
163899ff697bSNicolas Vasilache
163999ff697bSNicolas Vasilache // No transposition needed.
164099ff697bSNicolas Vasilache if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
164199ff697bSNicolas Vasilache /*rhsIndex*/ {kw, c},
164299ff697bSNicolas Vasilache /*resIndex*/ {n, w, c}}))
1643392e16c2SNicolas Vasilache return depthwiseConv();
164499ff697bSNicolas Vasilache return failure();
164599ff697bSNicolas Vasilache }
164699ff697bSNicolas Vasilache
16476bb7d247SNicolas Vasilache private:
1648671e30a1SMehdi Amini bool valid = false;
16496bb7d247SNicolas Vasilache int strideW, dilationW;
16506bb7d247SNicolas Vasilache Value lhsShaped, rhsShaped, resShaped;
16516bb7d247SNicolas Vasilache ShapedType lhsShapedType, rhsShapedType, resShapedType;
16526bb7d247SNicolas Vasilache };
16536bb7d247SNicolas Vasilache } // namespace
16546bb7d247SNicolas Vasilache
1655efdd4c16SNicolas Vasilache /// Helper function to vectorize a LinalgOp with convolution semantics.
16566bb7d247SNicolas Vasilache // TODO: extend the generic vectorization to support windows and drop this.
1657efdd4c16SNicolas Vasilache static FailureOr<Operation *> vectorizeConvolution(OpBuilder &b, LinalgOp op) {
1658efdd4c16SNicolas Vasilache // The ConvolutionOpInterface gives us guarantees of existence for
1659efdd4c16SNicolas Vasilache // strides/dilations. However, we do not need to rely on those; we simply
1660efdd4c16SNicolas Vasilache // use them if present and otherwise fall back to the defaults, letting the
1661efdd4c16SNicolas Vasilache // generic convolution matcher in Conv1DNwcGenerator succeed or fail.
1662efdd4c16SNicolas Vasilache auto strides = op->getAttrOfType<DenseIntElementsAttr>("strides");
1663efdd4c16SNicolas Vasilache auto dilations = op->getAttrOfType<DenseIntElementsAttr>("dilations");
16646bb7d247SNicolas Vasilache auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
16656bb7d247SNicolas Vasilache auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
1666efdd4c16SNicolas Vasilache Conv1DNwcGenerator e(b, op, stride, dilation);
166799ff697bSNicolas Vasilache auto res = e.generateConv();
166899ff697bSNicolas Vasilache if (succeeded(res))
166999ff697bSNicolas Vasilache return res;
167099ff697bSNicolas Vasilache return e.generateDilatedConv();
16716bb7d247SNicolas Vasilache }
16726bb7d247SNicolas Vasilache
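/// Rewrite pattern wrapping `vectorizeConvolution`: on success, the original
/// LinalgOp is erased when the emitted vector ops produce no result (memref
/// semantics) or replaced by their single result (tensor semantics).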
1673efdd4c16SNicolas Vasilache struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
16746bb7d247SNicolas Vasilache using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
16756bb7d247SNicolas Vasilache
1676efdd4c16SNicolas Vasilache LogicalResult matchAndRewrite(LinalgOp op,
16776bb7d247SNicolas Vasilache PatternRewriter &rewriter) const override {
1678efdd4c16SNicolas Vasilache FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);
16796bb7d247SNicolas Vasilache if (failed(resultOrFail))
16806bb7d247SNicolas Vasilache return failure();
16816bb7d247SNicolas Vasilache Operation *newOp = *resultOrFail;
16826bb7d247SNicolas Vasilache if (newOp->getNumResults() == 0) {
1683efdd4c16SNicolas Vasilache rewriter.eraseOp(op.getOperation());
16846bb7d247SNicolas Vasilache return success();
16856bb7d247SNicolas Vasilache }
16866bb7d247SNicolas Vasilache assert(newOp->getNumResults() == 1 && "expected single result");
1687efdd4c16SNicolas Vasilache rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
16886bb7d247SNicolas Vasilache return success();
16896bb7d247SNicolas Vasilache }
16906bb7d247SNicolas Vasilache };
16916bb7d247SNicolas Vasilache
16926bb7d247SNicolas Vasilache void mlir::linalg::populateConvolutionVectorizationPatterns(
16936bb7d247SNicolas Vasilache RewritePatternSet &patterns, PatternBenefit benefit) {
16946bb7d247SNicolas Vasilache patterns.add<VectorizeConvolution>(patterns.getContext(), benefit);
16956bb7d247SNicolas Vasilache }
1696