199ef9eebSMatthias Springer //===- VectorTransforms.cpp - Conversion within the Vector dialect --------===//
299ef9eebSMatthias Springer //
399ef9eebSMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
499ef9eebSMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
599ef9eebSMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
699ef9eebSMatthias Springer //
799ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
899ef9eebSMatthias Springer //
999ef9eebSMatthias Springer // This file implements target-independent rewrites as 1->N patterns.
1099ef9eebSMatthias Springer //
1199ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
1299ef9eebSMatthias Springer
139b5a3d14SMatthias Springer #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
149b5a3d14SMatthias Springer
1599ef9eebSMatthias Springer #include <type_traits>
1699ef9eebSMatthias Springer
1799ef9eebSMatthias Springer #include "mlir/Dialect/Affine/IR/AffineOps.h"
1899ef9eebSMatthias Springer #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
19a75a46dbSJavier Setoain #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
2099ef9eebSMatthias Springer #include "mlir/Dialect/Linalg/IR/Linalg.h"
2199ef9eebSMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h"
228b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
23f71f9958SDiego Caballero #include "mlir/Dialect/Utils/IndexingUtils.h"
2499ef9eebSMatthias Springer #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
259b5a3d14SMatthias Springer #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
264db65e27SLei Zhang #include "mlir/IR/BuiltinTypes.h"
2799ef9eebSMatthias Springer #include "mlir/IR/ImplicitLocOpBuilder.h"
2899ef9eebSMatthias Springer #include "mlir/IR/Matchers.h"
2999ef9eebSMatthias Springer #include "mlir/IR/PatternMatch.h"
3099ef9eebSMatthias Springer #include "mlir/Interfaces/VectorInterfaces.h"
3199ef9eebSMatthias Springer
3299ef9eebSMatthias Springer #include "llvm/ADT/DenseSet.h"
3399ef9eebSMatthias Springer #include "llvm/ADT/MapVector.h"
3499ef9eebSMatthias Springer #include "llvm/ADT/STLExtras.h"
3599ef9eebSMatthias Springer #include "llvm/Support/CommandLine.h"
3699ef9eebSMatthias Springer #include "llvm/Support/Debug.h"
3799ef9eebSMatthias Springer #include "llvm/Support/raw_ostream.h"
3899ef9eebSMatthias Springer
3999ef9eebSMatthias Springer #define DEBUG_TYPE "vector-to-vector"
4099ef9eebSMatthias Springer
4199ef9eebSMatthias Springer using namespace mlir;
4299ef9eebSMatthias Springer using namespace mlir::vector;
4399ef9eebSMatthias Springer
4499ef9eebSMatthias Springer // Helper to find an index in an affine map.
getResultIndex(AffineMap map,int64_t index)4599ef9eebSMatthias Springer static Optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
4699ef9eebSMatthias Springer for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
4799ef9eebSMatthias Springer int64_t idx = map.getDimPosition(i);
4899ef9eebSMatthias Springer if (idx == index)
4999ef9eebSMatthias Springer return i;
5099ef9eebSMatthias Springer }
5199ef9eebSMatthias Springer return None;
5299ef9eebSMatthias Springer }
5399ef9eebSMatthias Springer
5499ef9eebSMatthias Springer // Helper to construct iterator types with one index removed.
adjustIter(ArrayAttr iteratorTypes,int64_t index)5599ef9eebSMatthias Springer static SmallVector<Attribute, 4> adjustIter(ArrayAttr iteratorTypes,
5699ef9eebSMatthias Springer int64_t index) {
5799ef9eebSMatthias Springer SmallVector<Attribute, 4> results;
5899ef9eebSMatthias Springer for (const auto &it : llvm::enumerate(iteratorTypes)) {
5999ef9eebSMatthias Springer int64_t idx = it.index();
6099ef9eebSMatthias Springer if (idx == index)
6199ef9eebSMatthias Springer continue;
6299ef9eebSMatthias Springer results.push_back(it.value());
6399ef9eebSMatthias Springer }
6499ef9eebSMatthias Springer return results;
6599ef9eebSMatthias Springer }
6699ef9eebSMatthias Springer
6799ef9eebSMatthias Springer // Helper to construct an affine map with one index removed.
adjustMap(AffineMap map,int64_t index,PatternRewriter & rewriter)6899ef9eebSMatthias Springer static AffineMap adjustMap(AffineMap map, int64_t index,
6999ef9eebSMatthias Springer PatternRewriter &rewriter) {
7099ef9eebSMatthias Springer auto *ctx = rewriter.getContext();
7199ef9eebSMatthias Springer SmallVector<AffineExpr, 4> results;
7299ef9eebSMatthias Springer for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
7399ef9eebSMatthias Springer int64_t idx = map.getDimPosition(i);
7499ef9eebSMatthias Springer if (idx == index)
7599ef9eebSMatthias Springer continue;
7699ef9eebSMatthias Springer // Re-insert remaining indices, but renamed when occurring
7799ef9eebSMatthias Springer // after the removed index.
7899ef9eebSMatthias Springer auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx);
7999ef9eebSMatthias Springer results.push_back(targetExpr);
8099ef9eebSMatthias Springer }
8199ef9eebSMatthias Springer return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
8299ef9eebSMatthias Springer }
8399ef9eebSMatthias Springer
8499ef9eebSMatthias Springer // Helper method to possibly drop a dimension in a load.
8599ef9eebSMatthias Springer // TODO
reshapeLoad(Location loc,Value val,VectorType type,int64_t index,int64_t pos,PatternRewriter & rewriter)8699ef9eebSMatthias Springer static Value reshapeLoad(Location loc, Value val, VectorType type,
8799ef9eebSMatthias Springer int64_t index, int64_t pos,
8899ef9eebSMatthias Springer PatternRewriter &rewriter) {
8999ef9eebSMatthias Springer if (index == -1)
9099ef9eebSMatthias Springer return val;
9199ef9eebSMatthias Springer Type lowType = VectorType::Builder(type).dropDim(0);
9299ef9eebSMatthias Springer // At extraction dimension?
9399ef9eebSMatthias Springer if (index == 0) {
9499ef9eebSMatthias Springer auto posAttr = rewriter.getI64ArrayAttr(pos);
9599ef9eebSMatthias Springer return rewriter.create<vector::ExtractOp>(loc, lowType, val, posAttr);
9699ef9eebSMatthias Springer }
9799ef9eebSMatthias Springer // Unroll leading dimensions.
9899ef9eebSMatthias Springer VectorType vType = lowType.cast<VectorType>();
9999ef9eebSMatthias Springer Type resType = VectorType::Builder(type).dropDim(index);
10099ef9eebSMatthias Springer auto resVectorType = resType.cast<VectorType>();
10199ef9eebSMatthias Springer Value result = rewriter.create<arith::ConstantOp>(
10299ef9eebSMatthias Springer loc, resVectorType, rewriter.getZeroAttr(resVectorType));
10399ef9eebSMatthias Springer for (int64_t d = 0, e = resVectorType.getDimSize(0); d < e; d++) {
10499ef9eebSMatthias Springer auto posAttr = rewriter.getI64ArrayAttr(d);
10599ef9eebSMatthias Springer Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, posAttr);
10699ef9eebSMatthias Springer Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
10799ef9eebSMatthias Springer result = rewriter.create<vector::InsertOp>(loc, resVectorType, load, result,
10899ef9eebSMatthias Springer posAttr);
10999ef9eebSMatthias Springer }
11099ef9eebSMatthias Springer return result;
11199ef9eebSMatthias Springer }
11299ef9eebSMatthias Springer
11399ef9eebSMatthias Springer // Helper method to possibly drop a dimension in a store.
11499ef9eebSMatthias Springer // TODO
reshapeStore(Location loc,Value val,Value result,VectorType type,int64_t index,int64_t pos,PatternRewriter & rewriter)11599ef9eebSMatthias Springer static Value reshapeStore(Location loc, Value val, Value result,
11699ef9eebSMatthias Springer VectorType type, int64_t index, int64_t pos,
11799ef9eebSMatthias Springer PatternRewriter &rewriter) {
11899ef9eebSMatthias Springer // Unmodified?
11999ef9eebSMatthias Springer if (index == -1)
12099ef9eebSMatthias Springer return val;
12199ef9eebSMatthias Springer // At insertion dimension?
12299ef9eebSMatthias Springer if (index == 0) {
12399ef9eebSMatthias Springer auto posAttr = rewriter.getI64ArrayAttr(pos);
12499ef9eebSMatthias Springer return rewriter.create<vector::InsertOp>(loc, type, val, result, posAttr);
12599ef9eebSMatthias Springer }
12699ef9eebSMatthias Springer // Unroll leading dimensions.
12799ef9eebSMatthias Springer Type lowType = VectorType::Builder(type).dropDim(0);
12899ef9eebSMatthias Springer VectorType vType = lowType.cast<VectorType>();
12999ef9eebSMatthias Springer Type insType = VectorType::Builder(vType).dropDim(0);
13099ef9eebSMatthias Springer for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
13199ef9eebSMatthias Springer auto posAttr = rewriter.getI64ArrayAttr(d);
13299ef9eebSMatthias Springer Value ext = rewriter.create<vector::ExtractOp>(loc, vType, result, posAttr);
13399ef9eebSMatthias Springer Value ins = rewriter.create<vector::ExtractOp>(loc, insType, val, posAttr);
13499ef9eebSMatthias Springer Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
13599ef9eebSMatthias Springer result = rewriter.create<vector::InsertOp>(loc, type, sto, result, posAttr);
13699ef9eebSMatthias Springer }
13799ef9eebSMatthias Springer return result;
13899ef9eebSMatthias Springer }
13999ef9eebSMatthias Springer
14099ef9eebSMatthias Springer template <typename IntType>
extractVector(ArrayAttr arrayAttr)14199ef9eebSMatthias Springer static SmallVector<IntType, 4> extractVector(ArrayAttr arrayAttr) {
14299ef9eebSMatthias Springer return llvm::to_vector<4>(llvm::map_range(
14399ef9eebSMatthias Springer arrayAttr.getAsRange<IntegerAttr>(),
14499ef9eebSMatthias Springer [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
14599ef9eebSMatthias Springer }
14699ef9eebSMatthias Springer
14789aaa2d0SThomas Raoux /// Helper to create arithmetic operation associated with a kind of contraction.
createContractArithOp(Location loc,Value x,Value y,Value acc,vector::CombiningKind kind,PatternRewriter & rewriter,bool isInt)14889aaa2d0SThomas Raoux static Optional<Value> createContractArithOp(Location loc, Value x, Value y,
14989aaa2d0SThomas Raoux Value acc,
15089aaa2d0SThomas Raoux vector::CombiningKind kind,
15189aaa2d0SThomas Raoux PatternRewriter &rewriter,
15289aaa2d0SThomas Raoux bool isInt) {
15389aaa2d0SThomas Raoux using vector::CombiningKind;
15489aaa2d0SThomas Raoux Value mul;
15589aaa2d0SThomas Raoux if (isInt) {
15689aaa2d0SThomas Raoux if (kind == CombiningKind::MINF || kind == CombiningKind::MAXF)
15789aaa2d0SThomas Raoux // Only valid for floating point types.
15889aaa2d0SThomas Raoux return Optional<Value>();
15989aaa2d0SThomas Raoux mul = rewriter.create<arith::MulIOp>(loc, x, y);
16089aaa2d0SThomas Raoux } else {
16189aaa2d0SThomas Raoux // Float case.
16289aaa2d0SThomas Raoux if (kind == CombiningKind::AND || kind == CombiningKind::MINUI ||
16389aaa2d0SThomas Raoux kind == CombiningKind::MINSI || kind == CombiningKind::MAXUI ||
16489aaa2d0SThomas Raoux kind == CombiningKind::MAXSI || kind == CombiningKind::OR ||
16589aaa2d0SThomas Raoux kind == CombiningKind::XOR)
16689aaa2d0SThomas Raoux // Only valid for integer types.
16789aaa2d0SThomas Raoux return Optional<Value>();
16889aaa2d0SThomas Raoux // Special case for fused multiply-add.
16989aaa2d0SThomas Raoux if (acc && acc.getType().isa<VectorType>() && kind == CombiningKind::ADD) {
17089aaa2d0SThomas Raoux return Optional<Value>(rewriter.create<vector::FMAOp>(loc, x, y, acc));
17189aaa2d0SThomas Raoux }
17289aaa2d0SThomas Raoux mul = rewriter.create<arith::MulFOp>(loc, x, y);
17389aaa2d0SThomas Raoux }
17489aaa2d0SThomas Raoux if (!acc)
17589aaa2d0SThomas Raoux return Optional<Value>(mul);
17689aaa2d0SThomas Raoux return makeArithReduction(rewriter, loc, kind, mul, acc);
17789aaa2d0SThomas Raoux }
17889aaa2d0SThomas Raoux
17989aaa2d0SThomas Raoux /// Return the positions of the reductions in the given map.
getReductionIndex(AffineMap map,ArrayAttr iteratorTypes)18089aaa2d0SThomas Raoux static SmallVector<int64_t> getReductionIndex(AffineMap map,
18189aaa2d0SThomas Raoux ArrayAttr iteratorTypes) {
18289aaa2d0SThomas Raoux SmallVector<int64_t> dimsIdx;
18389aaa2d0SThomas Raoux for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
18489aaa2d0SThomas Raoux if (isReductionIterator(iteratorTypes[map.getDimPosition(i)]))
18589aaa2d0SThomas Raoux dimsIdx.push_back(i);
18689aaa2d0SThomas Raoux }
18789aaa2d0SThomas Raoux return dimsIdx;
18889aaa2d0SThomas Raoux }
18989aaa2d0SThomas Raoux
19089aaa2d0SThomas Raoux /// Look for a given dimension in an affine map and return its position. Return
19189aaa2d0SThomas Raoux /// llvm::None if the dimension is not in the map results.
getDimPosition(AffineMap map,unsigned dim)19289aaa2d0SThomas Raoux static llvm::Optional<unsigned> getDimPosition(AffineMap map, unsigned dim) {
19389aaa2d0SThomas Raoux for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
19489aaa2d0SThomas Raoux if (map.getDimPosition(i) == dim)
19589aaa2d0SThomas Raoux return i;
19689aaa2d0SThomas Raoux }
19789aaa2d0SThomas Raoux return llvm::None;
19889aaa2d0SThomas Raoux }
19989aaa2d0SThomas Raoux
20099ef9eebSMatthias Springer namespace {
20199ef9eebSMatthias Springer
20299ef9eebSMatthias Springer /// ShapeCastOpFolder folds cancelling ShapeCastOps away.
20399ef9eebSMatthias Springer //
20499ef9eebSMatthias Springer // Example:
20599ef9eebSMatthias Springer //
20699ef9eebSMatthias Springer // The following MLIR with cancelling ShapeCastOps:
20799ef9eebSMatthias Springer //
20899ef9eebSMatthias Springer // %0 = source : vector<5x4x2xf32>
20999ef9eebSMatthias Springer // %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32>
21099ef9eebSMatthias Springer // %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32>
21199ef9eebSMatthias Springer // %3 = user %2 : vector<5x4x2xf32>
21299ef9eebSMatthias Springer //
21399ef9eebSMatthias Springer // Should canonicalize to the following:
21499ef9eebSMatthias Springer //
21599ef9eebSMatthias Springer // %0 = source : vector<5x4x2xf32>
21699ef9eebSMatthias Springer // %1 = user %0 : vector<5x4x2xf32>
21799ef9eebSMatthias Springer //
21899ef9eebSMatthias Springer struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
21999ef9eebSMatthias Springer using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
22099ef9eebSMatthias Springer
matchAndRewrite__anon5c5a5b800211::ShapeCastOpFolder22199ef9eebSMatthias Springer LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
22299ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
22399ef9eebSMatthias Springer // Check if 'shapeCastOp' has vector source/result type.
22499ef9eebSMatthias Springer auto sourceVectorType =
2257c38fd60SJacques Pienaar shapeCastOp.getSource().getType().dyn_cast_or_null<VectorType>();
22699ef9eebSMatthias Springer auto resultVectorType =
2277c38fd60SJacques Pienaar shapeCastOp.getResult().getType().dyn_cast_or_null<VectorType>();
22899ef9eebSMatthias Springer if (!sourceVectorType || !resultVectorType)
22999ef9eebSMatthias Springer return failure();
23099ef9eebSMatthias Springer
23199ef9eebSMatthias Springer // Check if shape cast op source operand is also a shape cast op.
23299ef9eebSMatthias Springer auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>(
2337c38fd60SJacques Pienaar shapeCastOp.getSource().getDefiningOp());
23499ef9eebSMatthias Springer if (!sourceShapeCastOp)
23599ef9eebSMatthias Springer return failure();
23699ef9eebSMatthias Springer auto operandSourceVectorType =
2377c38fd60SJacques Pienaar sourceShapeCastOp.getSource().getType().cast<VectorType>();
23899ef9eebSMatthias Springer auto operandResultVectorType = sourceShapeCastOp.getType();
23999ef9eebSMatthias Springer
24099ef9eebSMatthias Springer // Check if shape cast operations invert each other.
24199ef9eebSMatthias Springer if (operandSourceVectorType != resultVectorType ||
24299ef9eebSMatthias Springer operandResultVectorType != sourceVectorType)
24399ef9eebSMatthias Springer return failure();
24499ef9eebSMatthias Springer
2457c38fd60SJacques Pienaar rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.getSource());
24699ef9eebSMatthias Springer return success();
24799ef9eebSMatthias Springer }
24899ef9eebSMatthias Springer };
24999ef9eebSMatthias Springer
25099ef9eebSMatthias Springer /// Progressive lowering of BroadcastOp.
25199ef9eebSMatthias Springer class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
25299ef9eebSMatthias Springer public:
25399ef9eebSMatthias Springer using OpRewritePattern<vector::BroadcastOp>::OpRewritePattern;
25499ef9eebSMatthias Springer
matchAndRewrite(vector::BroadcastOp op,PatternRewriter & rewriter) const25599ef9eebSMatthias Springer LogicalResult matchAndRewrite(vector::BroadcastOp op,
25699ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
25799ef9eebSMatthias Springer auto loc = op.getLoc();
25899ef9eebSMatthias Springer VectorType dstType = op.getVectorType();
25999ef9eebSMatthias Springer VectorType srcType = op.getSourceType().dyn_cast<VectorType>();
26099ef9eebSMatthias Springer Type eltType = dstType.getElementType();
26199ef9eebSMatthias Springer
26299ef9eebSMatthias Springer // Scalar to any vector can use splat.
26399ef9eebSMatthias Springer if (!srcType) {
2647c38fd60SJacques Pienaar rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, op.getSource());
26599ef9eebSMatthias Springer return success();
26699ef9eebSMatthias Springer }
26799ef9eebSMatthias Springer
26899ef9eebSMatthias Springer // Determine rank of source and destination.
26999ef9eebSMatthias Springer int64_t srcRank = srcType.getRank();
27099ef9eebSMatthias Springer int64_t dstRank = dstType.getRank();
27199ef9eebSMatthias Springer
27299ef9eebSMatthias Springer // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
27399ef9eebSMatthias Springer if (srcRank <= 1 && dstRank == 1) {
27499ef9eebSMatthias Springer Value ext;
27599ef9eebSMatthias Springer if (srcRank == 0)
2767c38fd60SJacques Pienaar ext = rewriter.create<vector::ExtractElementOp>(loc, op.getSource());
27799ef9eebSMatthias Springer else
2787c38fd60SJacques Pienaar ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);
2796a8ba318SRiver Riddle rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
28099ef9eebSMatthias Springer return success();
28199ef9eebSMatthias Springer }
28299ef9eebSMatthias Springer
28399ef9eebSMatthias Springer // Duplicate this rank.
28499ef9eebSMatthias Springer // For example:
28599ef9eebSMatthias Springer // %x = broadcast %y : k-D to n-D, k < n
28699ef9eebSMatthias Springer // becomes:
28799ef9eebSMatthias Springer // %b = broadcast %y : k-D to (n-1)-D
28899ef9eebSMatthias Springer // %x = [%b,%b,%b,%b] : n-D
28999ef9eebSMatthias Springer // becomes:
29099ef9eebSMatthias Springer // %b = [%y,%y] : (n-1)-D
29199ef9eebSMatthias Springer // %x = [%b,%b,%b,%b] : n-D
29299ef9eebSMatthias Springer if (srcRank < dstRank) {
29399ef9eebSMatthias Springer // Duplication.
29499ef9eebSMatthias Springer VectorType resType =
29599ef9eebSMatthias Springer VectorType::get(dstType.getShape().drop_front(), eltType);
29699ef9eebSMatthias Springer Value bcst =
2977c38fd60SJacques Pienaar rewriter.create<vector::BroadcastOp>(loc, resType, op.getSource());
29899ef9eebSMatthias Springer Value result = rewriter.create<arith::ConstantOp>(
29999ef9eebSMatthias Springer loc, dstType, rewriter.getZeroAttr(dstType));
30099ef9eebSMatthias Springer for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
30199ef9eebSMatthias Springer result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
30299ef9eebSMatthias Springer rewriter.replaceOp(op, result);
30399ef9eebSMatthias Springer return success();
30499ef9eebSMatthias Springer }
30599ef9eebSMatthias Springer
30699ef9eebSMatthias Springer // Find non-matching dimension, if any.
30799ef9eebSMatthias Springer assert(srcRank == dstRank);
30899ef9eebSMatthias Springer int64_t m = -1;
30999ef9eebSMatthias Springer for (int64_t r = 0; r < dstRank; r++)
31099ef9eebSMatthias Springer if (srcType.getDimSize(r) != dstType.getDimSize(r)) {
31199ef9eebSMatthias Springer m = r;
31299ef9eebSMatthias Springer break;
31399ef9eebSMatthias Springer }
31499ef9eebSMatthias Springer
31599ef9eebSMatthias Springer // All trailing dimensions are the same. Simply pass through.
31699ef9eebSMatthias Springer if (m == -1) {
3177c38fd60SJacques Pienaar rewriter.replaceOp(op, op.getSource());
31899ef9eebSMatthias Springer return success();
31999ef9eebSMatthias Springer }
32099ef9eebSMatthias Springer
32199ef9eebSMatthias Springer // Any non-matching dimension forces a stretch along this rank.
32299ef9eebSMatthias Springer // For example:
32399ef9eebSMatthias Springer // %x = broadcast %y : vector<4x1x2xf32> to vector<4x2x2xf32>
32499ef9eebSMatthias Springer // becomes:
32599ef9eebSMatthias Springer // %a = broadcast %y[0] : vector<1x2xf32> to vector<2x2xf32>
32699ef9eebSMatthias Springer // %b = broadcast %y[1] : vector<1x2xf32> to vector<2x2xf32>
32799ef9eebSMatthias Springer // %c = broadcast %y[2] : vector<1x2xf32> to vector<2x2xf32>
32899ef9eebSMatthias Springer // %d = broadcast %y[3] : vector<1x2xf32> to vector<2x2xf32>
32999ef9eebSMatthias Springer // %x = [%a,%b,%c,%d]
33099ef9eebSMatthias Springer // becomes:
33199ef9eebSMatthias Springer // %u = broadcast %y[0][0] : vector<2xf32> to vector <2x2xf32>
33299ef9eebSMatthias Springer // %v = broadcast %y[1][0] : vector<2xf32> to vector <2x2xf32>
33399ef9eebSMatthias Springer // %a = [%u, %v]
33499ef9eebSMatthias Springer // ..
33599ef9eebSMatthias Springer // %x = [%a,%b,%c,%d]
33699ef9eebSMatthias Springer VectorType resType =
33799ef9eebSMatthias Springer VectorType::get(dstType.getShape().drop_front(), eltType);
33899ef9eebSMatthias Springer Value result = rewriter.create<arith::ConstantOp>(
33999ef9eebSMatthias Springer loc, dstType, rewriter.getZeroAttr(dstType));
34099ef9eebSMatthias Springer if (m == 0) {
34199ef9eebSMatthias Springer // Stetch at start.
3427c38fd60SJacques Pienaar Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);
34399ef9eebSMatthias Springer Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
34499ef9eebSMatthias Springer for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
34599ef9eebSMatthias Springer result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
34699ef9eebSMatthias Springer } else {
34799ef9eebSMatthias Springer // Stetch not at start.
34899ef9eebSMatthias Springer for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) {
3497c38fd60SJacques Pienaar Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), d);
35099ef9eebSMatthias Springer Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
35199ef9eebSMatthias Springer result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
35299ef9eebSMatthias Springer }
35399ef9eebSMatthias Springer }
35499ef9eebSMatthias Springer rewriter.replaceOp(op, result);
35599ef9eebSMatthias Springer return success();
35699ef9eebSMatthias Springer }
35799ef9eebSMatthias Springer };
35899ef9eebSMatthias Springer
359f71f9958SDiego Caballero /// Given a 'transpose' pattern, prune the rightmost dimensions that are not
360f71f9958SDiego Caballero /// transposed.
pruneNonTransposedDims(ArrayRef<int64_t> transpose,SmallVectorImpl<int64_t> & result)361f71f9958SDiego Caballero void pruneNonTransposedDims(ArrayRef<int64_t> transpose,
362f71f9958SDiego Caballero SmallVectorImpl<int64_t> &result) {
363917d95fcSDiego Caballero size_t numTransposedDims = transpose.size();
364917d95fcSDiego Caballero for (size_t transpDim : llvm::reverse(transpose)) {
365917d95fcSDiego Caballero if (transpDim != numTransposedDims - 1)
366917d95fcSDiego Caballero break;
367917d95fcSDiego Caballero numTransposedDims--;
368917d95fcSDiego Caballero }
369f71f9958SDiego Caballero
370f71f9958SDiego Caballero result.append(transpose.begin(), transpose.begin() + numTransposedDims);
371917d95fcSDiego Caballero }
372917d95fcSDiego Caballero
37399ef9eebSMatthias Springer /// Progressive lowering of TransposeOp.
37499ef9eebSMatthias Springer /// One:
37599ef9eebSMatthias Springer /// %x = vector.transpose %y, [1, 0]
37699ef9eebSMatthias Springer /// is replaced by:
37799ef9eebSMatthias Springer /// %z = arith.constant dense<0.000000e+00>
37899ef9eebSMatthias Springer /// %0 = vector.extract %y[0, 0]
37999ef9eebSMatthias Springer /// %1 = vector.insert %0, %z [0, 0]
38099ef9eebSMatthias Springer /// ..
38199ef9eebSMatthias Springer /// %x = vector.insert .., .. [.., ..]
38299ef9eebSMatthias Springer class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
38399ef9eebSMatthias Springer public:
38499ef9eebSMatthias Springer using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
38599ef9eebSMatthias Springer
TransposeOpLowering(vector::VectorTransformsOptions vectorTransformOptions,MLIRContext * context)38699ef9eebSMatthias Springer TransposeOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
38799ef9eebSMatthias Springer MLIRContext *context)
38899ef9eebSMatthias Springer : OpRewritePattern<vector::TransposeOp>(context),
38999ef9eebSMatthias Springer vectorTransformOptions(vectorTransformOptions) {}
39099ef9eebSMatthias Springer
matchAndRewrite(vector::TransposeOp op,PatternRewriter & rewriter) const39199ef9eebSMatthias Springer LogicalResult matchAndRewrite(vector::TransposeOp op,
39299ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
39399ef9eebSMatthias Springer auto loc = op.getLoc();
39499ef9eebSMatthias Springer
3957c38fd60SJacques Pienaar Value input = op.getVector();
396f71f9958SDiego Caballero VectorType inputType = op.getVectorType();
39799ef9eebSMatthias Springer VectorType resType = op.getResultType();
39899ef9eebSMatthias Springer
39999ef9eebSMatthias Springer // Set up convenience transposition table.
40099ef9eebSMatthias Springer SmallVector<int64_t, 4> transp;
4017c38fd60SJacques Pienaar for (auto attr : op.getTransp())
40299ef9eebSMatthias Springer transp.push_back(attr.cast<IntegerAttr>().getInt());
40399ef9eebSMatthias Springer
40499ef9eebSMatthias Springer if (vectorTransformOptions.vectorTransposeLowering ==
40599ef9eebSMatthias Springer vector::VectorTransposeLowering::Shuffle &&
40699ef9eebSMatthias Springer resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0)
40799ef9eebSMatthias Springer return rewriter.notifyMatchFailure(
40899ef9eebSMatthias Springer op, "Options specifies lowering to shuffle");
40999ef9eebSMatthias Springer
41099ef9eebSMatthias Springer // Handle a true 2-D matrix transpose differently when requested.
41199ef9eebSMatthias Springer if (vectorTransformOptions.vectorTransposeLowering ==
41299ef9eebSMatthias Springer vector::VectorTransposeLowering::Flat &&
41399ef9eebSMatthias Springer resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) {
41499ef9eebSMatthias Springer Type flattenedType =
41599ef9eebSMatthias Springer VectorType::get(resType.getNumElements(), resType.getElementType());
41699ef9eebSMatthias Springer auto matrix =
417f71f9958SDiego Caballero rewriter.create<vector::ShapeCastOp>(loc, flattenedType, input);
41899ef9eebSMatthias Springer auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
41999ef9eebSMatthias Springer auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
42099ef9eebSMatthias Springer Value trans = rewriter.create<vector::FlatTransposeOp>(
42199ef9eebSMatthias Springer loc, flattenedType, matrix, rows, columns);
42299ef9eebSMatthias Springer rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
42399ef9eebSMatthias Springer return success();
42499ef9eebSMatthias Springer }
42599ef9eebSMatthias Springer
426917d95fcSDiego Caballero // Generate unrolled extract/insert ops. We do not unroll the rightmost
427917d95fcSDiego Caballero // (i.e., highest-order) dimensions that are not transposed and leave them
428f71f9958SDiego Caballero // in vector form to improve performance. Therefore, we prune those
429f71f9958SDiego Caballero // dimensions from the shape/transpose data structures used to generate the
430f71f9958SDiego Caballero // extract/insert ops.
431f71f9958SDiego Caballero SmallVector<int64_t, 4> prunedTransp;
432f71f9958SDiego Caballero pruneNonTransposedDims(transp, prunedTransp);
433f71f9958SDiego Caballero size_t numPrunedDims = transp.size() - prunedTransp.size();
434f71f9958SDiego Caballero auto prunedInShape = inputType.getShape().drop_back(numPrunedDims);
435f71f9958SDiego Caballero SmallVector<int64_t, 4> ones(prunedInShape.size(), 1);
436f71f9958SDiego Caballero auto prunedInStrides = computeStrides(prunedInShape, ones);
437917d95fcSDiego Caballero
438f71f9958SDiego Caballero // Generates the extract/insert operations for every scalar/vector element
439f71f9958SDiego Caballero // of the leftmost transposed dimensions. We traverse every transpose
440f71f9958SDiego Caballero // element using a linearized index that we delinearize to generate the
441f71f9958SDiego Caballero // appropriate indices for the extract/insert operations.
44299ef9eebSMatthias Springer Value result = rewriter.create<arith::ConstantOp>(
44399ef9eebSMatthias Springer loc, resType, rewriter.getZeroAttr(resType));
444f71f9958SDiego Caballero int64_t numTransposedElements = ShapedType::getNumElements(prunedInShape);
445f71f9958SDiego Caballero
446f71f9958SDiego Caballero for (int64_t linearIdx = 0; linearIdx < numTransposedElements;
447f71f9958SDiego Caballero ++linearIdx) {
448f71f9958SDiego Caballero auto extractIdxs = delinearize(prunedInStrides, linearIdx);
449f71f9958SDiego Caballero SmallVector<int64_t, 4> insertIdxs(extractIdxs);
450f71f9958SDiego Caballero applyPermutationToVector(insertIdxs, prunedTransp);
451f71f9958SDiego Caballero Value extractOp =
452f71f9958SDiego Caballero rewriter.create<vector::ExtractOp>(loc, input, extractIdxs);
453f71f9958SDiego Caballero result =
454f71f9958SDiego Caballero rewriter.create<vector::InsertOp>(loc, extractOp, result, insertIdxs);
455f71f9958SDiego Caballero }
456f71f9958SDiego Caballero
457f71f9958SDiego Caballero rewriter.replaceOp(op, result);
45899ef9eebSMatthias Springer return success();
45999ef9eebSMatthias Springer }
46099ef9eebSMatthias Springer
46199ef9eebSMatthias Springer private:
46299ef9eebSMatthias Springer /// Options to control the vector patterns.
46399ef9eebSMatthias Springer vector::VectorTransformsOptions vectorTransformOptions;
46499ef9eebSMatthias Springer };
46599ef9eebSMatthias Springer
46699ef9eebSMatthias Springer /// Rewrite a 2-D vector.transpose as a sequence of:
46799ef9eebSMatthias Springer /// vector.shape_cast 2D -> 1D
46899ef9eebSMatthias Springer /// vector.shuffle
46999ef9eebSMatthias Springer /// vector.shape_cast 1D -> 2D
47099ef9eebSMatthias Springer class TransposeOp2DToShuffleLowering
47199ef9eebSMatthias Springer : public OpRewritePattern<vector::TransposeOp> {
47299ef9eebSMatthias Springer public:
47399ef9eebSMatthias Springer using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
47499ef9eebSMatthias Springer
/// Builds the pattern. The stored `vectorTransformOptions` decide whether it
/// fires: matchAndRewrite only proceeds for the Shuffle transpose lowering.
TransposeOp2DToShuffleLowering(
    vector::VectorTransformsOptions vectorTransformOptions,
    MLIRContext *context)
    : OpRewritePattern<vector::TransposeOp>(context),
      vectorTransformOptions(vectorTransformOptions) {}
48099ef9eebSMatthias Springer
matchAndRewrite(vector::TransposeOp op,PatternRewriter & rewriter) const48199ef9eebSMatthias Springer LogicalResult matchAndRewrite(vector::TransposeOp op,
48299ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
48399ef9eebSMatthias Springer auto loc = op.getLoc();
48499ef9eebSMatthias Springer
48599ef9eebSMatthias Springer VectorType srcType = op.getVectorType();
48699ef9eebSMatthias Springer if (srcType.getRank() != 2)
48799ef9eebSMatthias Springer return rewriter.notifyMatchFailure(op, "Not a 2D transpose");
48899ef9eebSMatthias Springer
48999ef9eebSMatthias Springer SmallVector<int64_t, 4> transp;
4907c38fd60SJacques Pienaar for (auto attr : op.getTransp())
49199ef9eebSMatthias Springer transp.push_back(attr.cast<IntegerAttr>().getInt());
49299ef9eebSMatthias Springer if (transp[0] != 1 && transp[1] != 0)
49399ef9eebSMatthias Springer return rewriter.notifyMatchFailure(op, "Not a 2D transpose permutation");
49499ef9eebSMatthias Springer
49599ef9eebSMatthias Springer if (vectorTransformOptions.vectorTransposeLowering !=
49699ef9eebSMatthias Springer VectorTransposeLowering::Shuffle)
49799ef9eebSMatthias Springer return rewriter.notifyMatchFailure(op, "Options do not ask for Shuffle");
49899ef9eebSMatthias Springer
49999ef9eebSMatthias Springer int64_t m = srcType.getShape().front(), n = srcType.getShape().back();
50099ef9eebSMatthias Springer Value casted = rewriter.create<vector::ShapeCastOp>(
5017c38fd60SJacques Pienaar loc, VectorType::get({m * n}, srcType.getElementType()),
5027c38fd60SJacques Pienaar op.getVector());
50399ef9eebSMatthias Springer SmallVector<int64_t> mask;
50499ef9eebSMatthias Springer mask.reserve(m * n);
50599ef9eebSMatthias Springer for (int64_t j = 0; j < n; ++j)
50699ef9eebSMatthias Springer for (int64_t i = 0; i < m; ++i)
50799ef9eebSMatthias Springer mask.push_back(i * n + j);
50899ef9eebSMatthias Springer
50999ef9eebSMatthias Springer Value shuffled =
51099ef9eebSMatthias Springer rewriter.create<vector::ShuffleOp>(loc, casted, casted, mask);
51199ef9eebSMatthias Springer rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getResultType(),
51299ef9eebSMatthias Springer shuffled);
51399ef9eebSMatthias Springer
51499ef9eebSMatthias Springer return success();
51599ef9eebSMatthias Springer }
51699ef9eebSMatthias Springer
51799ef9eebSMatthias Springer private:
51899ef9eebSMatthias Springer /// Options to control the vector patterns.
51999ef9eebSMatthias Springer vector::VectorTransformsOptions vectorTransformOptions;
52099ef9eebSMatthias Springer };
52199ef9eebSMatthias Springer
52299ef9eebSMatthias Springer /// Progressive lowering of OuterProductOp.
52399ef9eebSMatthias Springer /// One:
52499ef9eebSMatthias Springer /// %x = vector.outerproduct %lhs, %rhs, %acc
52599ef9eebSMatthias Springer /// is replaced by:
52699ef9eebSMatthias Springer /// %z = zero-result
52799ef9eebSMatthias Springer /// %0 = vector.extract %lhs[0]
52899ef9eebSMatthias Springer /// %1 = vector.broadcast %0
52999ef9eebSMatthias Springer /// %2 = vector.extract %acc[0]
53099ef9eebSMatthias Springer /// %3 = vector.fma %1, %rhs, %2
53199ef9eebSMatthias Springer /// %4 = vector.insert %3, %z[0]
53299ef9eebSMatthias Springer /// ..
53399ef9eebSMatthias Springer /// %x = vector.insert %.., %..[N-1]
53499ef9eebSMatthias Springer ///
53599ef9eebSMatthias Springer class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
53699ef9eebSMatthias Springer public:
53799ef9eebSMatthias Springer using OpRewritePattern<vector::OuterProductOp>::OpRewritePattern;
53899ef9eebSMatthias Springer
matchAndRewrite(vector::OuterProductOp op,PatternRewriter & rewriter) const53999ef9eebSMatthias Springer LogicalResult matchAndRewrite(vector::OuterProductOp op,
54099ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
54199ef9eebSMatthias Springer auto loc = op.getLoc();
54299ef9eebSMatthias Springer
54399ef9eebSMatthias Springer VectorType lhsType = op.getOperandVectorTypeLHS();
54499ef9eebSMatthias Springer VectorType rhsType = op.getOperandTypeRHS().dyn_cast<VectorType>();
54599ef9eebSMatthias Springer VectorType resType = op.getVectorType();
54699ef9eebSMatthias Springer Type eltType = resType.getElementType();
54799ef9eebSMatthias Springer bool isInt = eltType.isa<IntegerType, IndexType>();
5487c38fd60SJacques Pienaar Value acc = (op.getAcc().empty()) ? nullptr : op.getAcc()[0];
5497c38fd60SJacques Pienaar vector::CombiningKind kind = op.getKind();
55099ef9eebSMatthias Springer
55199ef9eebSMatthias Springer if (!rhsType) {
55299ef9eebSMatthias Springer // Special case: AXPY operation.
5537c38fd60SJacques Pienaar Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.getRhs());
55489aaa2d0SThomas Raoux Optional<Value> mult = createContractArithOp(loc, op.getLhs(), b, acc,
55589aaa2d0SThomas Raoux kind, rewriter, isInt);
556491d2701SKazu Hirata if (!mult.has_value())
55799ef9eebSMatthias Springer return failure();
558c27d8152SKazu Hirata rewriter.replaceOp(op, mult.value());
55999ef9eebSMatthias Springer return success();
56099ef9eebSMatthias Springer }
56199ef9eebSMatthias Springer
56299ef9eebSMatthias Springer Value result = rewriter.create<arith::ConstantOp>(
56399ef9eebSMatthias Springer loc, resType, rewriter.getZeroAttr(resType));
56499ef9eebSMatthias Springer for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
56599ef9eebSMatthias Springer auto pos = rewriter.getI64ArrayAttr(d);
5667c38fd60SJacques Pienaar Value x =
5677c38fd60SJacques Pienaar rewriter.create<vector::ExtractOp>(loc, eltType, op.getLhs(), pos);
56899ef9eebSMatthias Springer Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
56999ef9eebSMatthias Springer Value r = nullptr;
57099ef9eebSMatthias Springer if (acc)
57199ef9eebSMatthias Springer r = rewriter.create<vector::ExtractOp>(loc, rhsType, acc, pos);
5727c38fd60SJacques Pienaar Optional<Value> m =
57389aaa2d0SThomas Raoux createContractArithOp(loc, a, op.getRhs(), r, kind, rewriter, isInt);
574491d2701SKazu Hirata if (!m.has_value())
57599ef9eebSMatthias Springer return failure();
576c27d8152SKazu Hirata result = rewriter.create<vector::InsertOp>(loc, resType, m.value(),
57799ef9eebSMatthias Springer result, pos);
57899ef9eebSMatthias Springer }
57999ef9eebSMatthias Springer rewriter.replaceOp(op, result);
58099ef9eebSMatthias Springer return success();
58199ef9eebSMatthias Springer }
58289aaa2d0SThomas Raoux };
58389aaa2d0SThomas Raoux
58489aaa2d0SThomas Raoux /// Lower vector.contract with all size one reduction dimensions to
58589aaa2d0SThomas Raoux /// elementwise ops when possible.
58689aaa2d0SThomas Raoux struct ContractOpToElementwise
58789aaa2d0SThomas Raoux : public OpRewritePattern<vector::ContractionOp> {
58889aaa2d0SThomas Raoux using OpRewritePattern::OpRewritePattern;
58989aaa2d0SThomas Raoux using FilterConstraintType =
59089aaa2d0SThomas Raoux std::function<LogicalResult(vector::ContractionOp op)>;
defaultFilter__anon5c5a5b800211::ContractOpToElementwise59189aaa2d0SThomas Raoux static LogicalResult defaultFilter(vector::ContractionOp op) {
59289aaa2d0SThomas Raoux return success();
59389aaa2d0SThomas Raoux }
ContractOpToElementwise__anon5c5a5b800211::ContractOpToElementwise59489aaa2d0SThomas Raoux ContractOpToElementwise(
59589aaa2d0SThomas Raoux vector::VectorTransformsOptions vectorTransformOptions,
59689aaa2d0SThomas Raoux MLIRContext *context,
59789aaa2d0SThomas Raoux const FilterConstraintType &constraint = defaultFilter)
59889aaa2d0SThomas Raoux : OpRewritePattern<vector::ContractionOp>(context),
59989aaa2d0SThomas Raoux vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
60089aaa2d0SThomas Raoux
matchAndRewrite__anon5c5a5b800211::ContractOpToElementwise60189aaa2d0SThomas Raoux LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
60289aaa2d0SThomas Raoux PatternRewriter &rewriter) const override {
60389aaa2d0SThomas Raoux // TODO: implement masks
60489aaa2d0SThomas Raoux if (llvm::size(contractOp.getMasks()) != 0)
60589aaa2d0SThomas Raoux return failure();
60689aaa2d0SThomas Raoux
60789aaa2d0SThomas Raoux if (failed(filter(contractOp)))
60889aaa2d0SThomas Raoux return failure();
60989aaa2d0SThomas Raoux
61089aaa2d0SThomas Raoux if (vectorTransformOptions.vectorContractLowering !=
61189aaa2d0SThomas Raoux vector::VectorContractLowering::ParallelArith)
61289aaa2d0SThomas Raoux return failure();
61389aaa2d0SThomas Raoux ArrayRef<int64_t> lhsShape = contractOp.getLhsType().getShape();
61489aaa2d0SThomas Raoux ArrayRef<int64_t> rhsShape = contractOp.getRhsType().getShape();
615*d2c0572bSJacques Pienaar AffineMap lhsMap = contractOp.getIndexingMapsArray()[0];
616*d2c0572bSJacques Pienaar AffineMap rhsMap = contractOp.getIndexingMapsArray()[1];
61789aaa2d0SThomas Raoux SmallVector<int64_t> lhsReductionDims =
61889aaa2d0SThomas Raoux getReductionIndex(lhsMap, contractOp.getIteratorTypes());
61989aaa2d0SThomas Raoux SmallVector<int64_t> rhsReductionDims =
62089aaa2d0SThomas Raoux getReductionIndex(rhsMap, contractOp.getIteratorTypes());
62189aaa2d0SThomas Raoux // All the reduction dimensions must be a size 1.
62289aaa2d0SThomas Raoux for (int64_t dim : lhsReductionDims) {
62389aaa2d0SThomas Raoux if (lhsShape[dim] != 1)
62489aaa2d0SThomas Raoux return failure();
62589aaa2d0SThomas Raoux }
62689aaa2d0SThomas Raoux for (int64_t dim : rhsReductionDims) {
62789aaa2d0SThomas Raoux if (rhsShape[dim] != 1)
62889aaa2d0SThomas Raoux return failure();
62989aaa2d0SThomas Raoux }
630*d2c0572bSJacques Pienaar AffineMap accMap = contractOp.getIndexingMapsArray()[2];
63189aaa2d0SThomas Raoux unsigned numParallelDims = accMap.getNumResults();
63289aaa2d0SThomas Raoux unsigned numLhsDimToBroadcast =
63389aaa2d0SThomas Raoux numParallelDims - (lhsMap.getNumResults() - lhsReductionDims.size());
63489aaa2d0SThomas Raoux unsigned numRhsDimToBroadcast =
63589aaa2d0SThomas Raoux numParallelDims - (rhsMap.getNumResults() - rhsReductionDims.size());
63689aaa2d0SThomas Raoux SmallVector<int64_t> lhsDims;
63789aaa2d0SThomas Raoux SmallVector<int64_t> lhsTranspose;
63889aaa2d0SThomas Raoux SmallVector<int64_t> rhsDims;
63989aaa2d0SThomas Raoux SmallVector<int64_t> rhsTranspose;
64089aaa2d0SThomas Raoux for (int64_t dim : lhsReductionDims)
64189aaa2d0SThomas Raoux lhsTranspose.push_back(numLhsDimToBroadcast + dim);
64289aaa2d0SThomas Raoux for (int64_t dim : rhsReductionDims)
64389aaa2d0SThomas Raoux rhsTranspose.push_back(numRhsDimToBroadcast + dim);
64489aaa2d0SThomas Raoux // Loop through the parallel dimensions to calculate the dimensions to
64589aaa2d0SThomas Raoux // broadcast and to permute in order to extract only parallel dimensions.
64689aaa2d0SThomas Raoux for (unsigned i = 0; i < numParallelDims; i++) {
64789aaa2d0SThomas Raoux llvm::Optional<unsigned> lhsDim =
64889aaa2d0SThomas Raoux getDimPosition(lhsMap, accMap.getDimPosition(i));
64989aaa2d0SThomas Raoux if (lhsDim) {
65089aaa2d0SThomas Raoux lhsTranspose.push_back(numLhsDimToBroadcast + *lhsDim);
65189aaa2d0SThomas Raoux } else {
65289aaa2d0SThomas Raoux // If the parallel dimension doesn't exist we will have to broadcast it.
65389aaa2d0SThomas Raoux lhsDims.push_back(
65489aaa2d0SThomas Raoux contractOp.getResultType().cast<VectorType>().getDimSize(i));
65589aaa2d0SThomas Raoux lhsTranspose.push_back(lhsDims.size() - 1);
65689aaa2d0SThomas Raoux }
65789aaa2d0SThomas Raoux llvm::Optional<unsigned> rhsDim =
65889aaa2d0SThomas Raoux getDimPosition(rhsMap, accMap.getDimPosition(i));
65989aaa2d0SThomas Raoux if (rhsDim) {
66089aaa2d0SThomas Raoux rhsTranspose.push_back(numRhsDimToBroadcast + *rhsDim);
66189aaa2d0SThomas Raoux } else {
66289aaa2d0SThomas Raoux // If the parallel dimension doesn't exist we will have to broadcast it.
66389aaa2d0SThomas Raoux rhsDims.push_back(
66489aaa2d0SThomas Raoux contractOp.getResultType().cast<VectorType>().getDimSize(i));
66589aaa2d0SThomas Raoux rhsTranspose.push_back(rhsDims.size() - 1);
66689aaa2d0SThomas Raoux }
66789aaa2d0SThomas Raoux }
66889aaa2d0SThomas Raoux Value newLhs = contractOp.getLhs();
66989aaa2d0SThomas Raoux Value newRhs = contractOp.getRhs();
67089aaa2d0SThomas Raoux Location loc = contractOp.getLoc();
67189aaa2d0SThomas Raoux if (!lhsDims.empty()) {
67289aaa2d0SThomas Raoux lhsDims.append(lhsShape.begin(), lhsShape.end());
67389aaa2d0SThomas Raoux auto expandedType =
67489aaa2d0SThomas Raoux VectorType::get(lhsDims, contractOp.getLhsType().getElementType());
67589aaa2d0SThomas Raoux newLhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newLhs);
67689aaa2d0SThomas Raoux }
67789aaa2d0SThomas Raoux if (!rhsDims.empty()) {
67889aaa2d0SThomas Raoux rhsDims.append(rhsShape.begin(), rhsShape.end());
67989aaa2d0SThomas Raoux auto expandedType =
68089aaa2d0SThomas Raoux VectorType::get(rhsDims, contractOp.getRhsType().getElementType());
68189aaa2d0SThomas Raoux newRhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newRhs);
68289aaa2d0SThomas Raoux }
68389aaa2d0SThomas Raoux bool isInt = contractOp.getLhsType().getElementType().isIntOrIndex();
68489aaa2d0SThomas Raoux newLhs = rewriter.create<vector::TransposeOp>(loc, newLhs, lhsTranspose);
68589aaa2d0SThomas Raoux newRhs = rewriter.create<vector::TransposeOp>(loc, newRhs, rhsTranspose);
68689aaa2d0SThomas Raoux SmallVector<int64_t, 4> lhsOffsets(lhsReductionDims.size(), 0);
68789aaa2d0SThomas Raoux SmallVector<int64_t, 4> rhsOffsets(rhsReductionDims.size(), 0);
68889aaa2d0SThomas Raoux newLhs = rewriter.create<vector::ExtractOp>(
68989aaa2d0SThomas Raoux loc, newLhs, rewriter.getI64ArrayAttr(lhsOffsets));
69089aaa2d0SThomas Raoux newRhs = rewriter.create<vector::ExtractOp>(
69189aaa2d0SThomas Raoux loc, newRhs, rewriter.getI64ArrayAttr(rhsOffsets));
69289aaa2d0SThomas Raoux Optional<Value> result =
69389aaa2d0SThomas Raoux createContractArithOp(loc, newLhs, newRhs, contractOp.getAcc(),
69489aaa2d0SThomas Raoux contractOp.getKind(), rewriter, isInt);
69589aaa2d0SThomas Raoux rewriter.replaceOp(contractOp, {*result});
69689aaa2d0SThomas Raoux return success();
69789aaa2d0SThomas Raoux }
69899ef9eebSMatthias Springer
69999ef9eebSMatthias Springer private:
70089aaa2d0SThomas Raoux /// Options to control the vector patterns.
70189aaa2d0SThomas Raoux vector::VectorTransformsOptions vectorTransformOptions;
70289aaa2d0SThomas Raoux FilterConstraintType filter;
70399ef9eebSMatthias Springer };
70499ef9eebSMatthias Springer
70599ef9eebSMatthias Springer /// Progressive lowering of ConstantMaskOp.
70699ef9eebSMatthias Springer /// One:
70799ef9eebSMatthias Springer /// %x = vector.constant_mask [a,b]
70899ef9eebSMatthias Springer /// is replaced by:
70999ef9eebSMatthias Springer /// %z = zero-result
71099ef9eebSMatthias Springer /// %l = vector.constant_mask [b]
71199ef9eebSMatthias Springer /// %4 = vector.insert %l, %z[0]
71299ef9eebSMatthias Springer /// ..
71399ef9eebSMatthias Springer /// %x = vector.insert %l, %..[a-1]
71499ef9eebSMatthias Springer /// until a one-dimensional vector is reached. All these operations
71599ef9eebSMatthias Springer /// will be folded at LLVM IR level.
71699ef9eebSMatthias Springer class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
71799ef9eebSMatthias Springer public:
71899ef9eebSMatthias Springer using OpRewritePattern<vector::ConstantMaskOp>::OpRewritePattern;
71999ef9eebSMatthias Springer
matchAndRewrite(vector::ConstantMaskOp op,PatternRewriter & rewriter) const72099ef9eebSMatthias Springer LogicalResult matchAndRewrite(vector::ConstantMaskOp op,
72199ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
72299ef9eebSMatthias Springer auto loc = op.getLoc();
72399ef9eebSMatthias Springer auto dstType = op.getType();
72499ef9eebSMatthias Springer auto eltType = dstType.getElementType();
7257c38fd60SJacques Pienaar auto dimSizes = op.getMaskDimSizes();
72699ef9eebSMatthias Springer int64_t rank = dstType.getRank();
72799ef9eebSMatthias Springer
72899ef9eebSMatthias Springer if (rank == 0) {
72999ef9eebSMatthias Springer assert(dimSizes.size() == 1 &&
73099ef9eebSMatthias Springer "Expected exactly one dim size for a 0-D vector");
73199ef9eebSMatthias Springer bool value = dimSizes[0].cast<IntegerAttr>().getInt() == 1;
73299ef9eebSMatthias Springer rewriter.replaceOpWithNewOp<arith::ConstantOp>(
73399ef9eebSMatthias Springer op, dstType,
73499ef9eebSMatthias Springer DenseIntElementsAttr::get(
73599ef9eebSMatthias Springer VectorType::get(ArrayRef<int64_t>{}, rewriter.getI1Type()),
73699ef9eebSMatthias Springer ArrayRef<bool>{value}));
73799ef9eebSMatthias Springer return success();
73899ef9eebSMatthias Springer }
73999ef9eebSMatthias Springer
740a75a46dbSJavier Setoain // Scalable constant masks can only be lowered for the "none set" case.
741a75a46dbSJavier Setoain if (dstType.cast<VectorType>().isScalable()) {
742a75a46dbSJavier Setoain rewriter.replaceOpWithNewOp<arith::ConstantOp>(
743a75a46dbSJavier Setoain op, DenseElementsAttr::get(dstType, false));
744a75a46dbSJavier Setoain return success();
745a75a46dbSJavier Setoain }
746a75a46dbSJavier Setoain
74799ef9eebSMatthias Springer int64_t trueDim = std::min(dstType.getDimSize(0),
74899ef9eebSMatthias Springer dimSizes[0].cast<IntegerAttr>().getInt());
74999ef9eebSMatthias Springer
75099ef9eebSMatthias Springer if (rank == 1) {
75199ef9eebSMatthias Springer // Express constant 1-D case in explicit vector form:
75299ef9eebSMatthias Springer // [T,..,T,F,..,F].
75399ef9eebSMatthias Springer SmallVector<bool, 4> values(dstType.getDimSize(0));
75499ef9eebSMatthias Springer for (int64_t d = 0; d < trueDim; d++)
75599ef9eebSMatthias Springer values[d] = true;
75699ef9eebSMatthias Springer rewriter.replaceOpWithNewOp<arith::ConstantOp>(
75799ef9eebSMatthias Springer op, dstType, rewriter.getBoolVectorAttr(values));
75899ef9eebSMatthias Springer return success();
75999ef9eebSMatthias Springer }
76099ef9eebSMatthias Springer
76199ef9eebSMatthias Springer VectorType lowType =
76299ef9eebSMatthias Springer VectorType::get(dstType.getShape().drop_front(), eltType);
76399ef9eebSMatthias Springer SmallVector<int64_t, 4> newDimSizes;
76499ef9eebSMatthias Springer for (int64_t r = 1; r < rank; r++)
76599ef9eebSMatthias Springer newDimSizes.push_back(dimSizes[r].cast<IntegerAttr>().getInt());
76699ef9eebSMatthias Springer Value trueVal = rewriter.create<vector::ConstantMaskOp>(
76799ef9eebSMatthias Springer loc, lowType, rewriter.getI64ArrayAttr(newDimSizes));
76899ef9eebSMatthias Springer Value result = rewriter.create<arith::ConstantOp>(
76999ef9eebSMatthias Springer loc, dstType, rewriter.getZeroAttr(dstType));
77099ef9eebSMatthias Springer for (int64_t d = 0; d < trueDim; d++) {
77199ef9eebSMatthias Springer auto pos = rewriter.getI64ArrayAttr(d);
77299ef9eebSMatthias Springer result =
77399ef9eebSMatthias Springer rewriter.create<vector::InsertOp>(loc, dstType, trueVal, result, pos);
77499ef9eebSMatthias Springer }
77599ef9eebSMatthias Springer rewriter.replaceOp(op, result);
77699ef9eebSMatthias Springer return success();
77799ef9eebSMatthias Springer }
77899ef9eebSMatthias Springer };
77999ef9eebSMatthias Springer
78099ef9eebSMatthias Springer /// Progressive lowering of CreateMaskOp.
78199ef9eebSMatthias Springer /// One:
78299ef9eebSMatthias Springer /// %x = vector.create_mask %a, ... : vector<dx...>
78399ef9eebSMatthias Springer /// is replaced by:
78499ef9eebSMatthias Springer /// %l = vector.create_mask ... : vector<...> ; one lower rank
78599ef9eebSMatthias Springer /// %0 = arith.cmpi "slt", %ci, %a |
78699ef9eebSMatthias Springer /// %1 = select %0, %l, %zeroes |
78799ef9eebSMatthias Springer /// %r = vector.insert %1, %pr [i] | d-times
78899ef9eebSMatthias Springer /// %x = ....
78999ef9eebSMatthias Springer /// until a one-dimensional vector is reached.
79099ef9eebSMatthias Springer class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
79199ef9eebSMatthias Springer public:
79299ef9eebSMatthias Springer using OpRewritePattern<vector::CreateMaskOp>::OpRewritePattern;
79399ef9eebSMatthias Springer
matchAndRewrite(vector::CreateMaskOp op,PatternRewriter & rewriter) const79499ef9eebSMatthias Springer LogicalResult matchAndRewrite(vector::CreateMaskOp op,
79599ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
79699ef9eebSMatthias Springer auto dstType = op.getResult().getType().cast<VectorType>();
79799ef9eebSMatthias Springer int64_t rank = dstType.getRank();
79899ef9eebSMatthias Springer if (rank <= 1)
79999ef9eebSMatthias Springer return rewriter.notifyMatchFailure(
80099ef9eebSMatthias Springer op, "0-D and 1-D vectors are handled separately");
80199ef9eebSMatthias Springer
80299ef9eebSMatthias Springer auto loc = op.getLoc();
80399ef9eebSMatthias Springer auto eltType = dstType.getElementType();
80499ef9eebSMatthias Springer int64_t dim = dstType.getDimSize(0);
80599ef9eebSMatthias Springer Value idx = op.getOperand(0);
80699ef9eebSMatthias Springer
80799ef9eebSMatthias Springer VectorType lowType =
80899ef9eebSMatthias Springer VectorType::get(dstType.getShape().drop_front(), eltType);
80999ef9eebSMatthias Springer Value trueVal = rewriter.create<vector::CreateMaskOp>(
81099ef9eebSMatthias Springer loc, lowType, op.getOperands().drop_front());
81199ef9eebSMatthias Springer Value falseVal = rewriter.create<arith::ConstantOp>(
81299ef9eebSMatthias Springer loc, lowType, rewriter.getZeroAttr(lowType));
81399ef9eebSMatthias Springer Value result = rewriter.create<arith::ConstantOp>(
81499ef9eebSMatthias Springer loc, dstType, rewriter.getZeroAttr(dstType));
81599ef9eebSMatthias Springer for (int64_t d = 0; d < dim; d++) {
81699ef9eebSMatthias Springer Value bnd =
81799ef9eebSMatthias Springer rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(d));
81899ef9eebSMatthias Springer Value val = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
81999ef9eebSMatthias Springer bnd, idx);
820dec8af70SRiver Riddle Value sel = rewriter.create<arith::SelectOp>(loc, val, trueVal, falseVal);
82199ef9eebSMatthias Springer auto pos = rewriter.getI64ArrayAttr(d);
82299ef9eebSMatthias Springer result =
82399ef9eebSMatthias Springer rewriter.create<vector::InsertOp>(loc, dstType, sel, result, pos);
82499ef9eebSMatthias Springer }
82599ef9eebSMatthias Springer rewriter.replaceOp(op, result);
82699ef9eebSMatthias Springer return success();
82799ef9eebSMatthias Springer }
82899ef9eebSMatthias Springer };
82999ef9eebSMatthias Springer
83099ef9eebSMatthias Springer /// ShapeOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
83199ef9eebSMatthias Springer /// vectors progressively on the way to target llvm.matrix intrinsics.
83299ef9eebSMatthias Springer /// This iterates over the most major dimension of the 2-D vector and performs
83399ef9eebSMatthias Springer /// rewrites into:
83499ef9eebSMatthias Springer /// vector.extract from 2-D + vector.insert_strided_slice offset into 1-D
83599ef9eebSMatthias Springer class ShapeCastOp2DDownCastRewritePattern
83699ef9eebSMatthias Springer : public OpRewritePattern<vector::ShapeCastOp> {
83799ef9eebSMatthias Springer public:
83899ef9eebSMatthias Springer using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
83999ef9eebSMatthias Springer
matchAndRewrite(vector::ShapeCastOp op,PatternRewriter & rewriter) const84099ef9eebSMatthias Springer LogicalResult matchAndRewrite(vector::ShapeCastOp op,
84199ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
84299ef9eebSMatthias Springer auto sourceVectorType = op.getSourceVectorType();
84399ef9eebSMatthias Springer auto resultVectorType = op.getResultVectorType();
84499ef9eebSMatthias Springer if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1)
84599ef9eebSMatthias Springer return failure();
84699ef9eebSMatthias Springer
84799ef9eebSMatthias Springer auto loc = op.getLoc();
84899ef9eebSMatthias Springer Value desc = rewriter.create<arith::ConstantOp>(
84999ef9eebSMatthias Springer loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
85099ef9eebSMatthias Springer unsigned mostMinorVectorSize = sourceVectorType.getShape()[1];
85199ef9eebSMatthias Springer for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) {
8527c38fd60SJacques Pienaar Value vec = rewriter.create<vector::ExtractOp>(loc, op.getSource(), i);
85399ef9eebSMatthias Springer desc = rewriter.create<vector::InsertStridedSliceOp>(
85499ef9eebSMatthias Springer loc, vec, desc,
85599ef9eebSMatthias Springer /*offsets=*/i * mostMinorVectorSize, /*strides=*/1);
85699ef9eebSMatthias Springer }
85799ef9eebSMatthias Springer rewriter.replaceOp(op, desc);
85899ef9eebSMatthias Springer return success();
85999ef9eebSMatthias Springer }
86099ef9eebSMatthias Springer };
86199ef9eebSMatthias Springer
86299ef9eebSMatthias Springer /// ShapeOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D
86399ef9eebSMatthias Springer /// vectors progressively.
86499ef9eebSMatthias Springer /// This iterates over the most major dimension of the 2-D vector and performs
86599ef9eebSMatthias Springer /// rewrites into:
86699ef9eebSMatthias Springer /// vector.extract_strided_slice from 1-D + vector.insert into 2-D
86799ef9eebSMatthias Springer /// Note that 1-D extract_strided_slice are lowered to efficient vector.shuffle.
86899ef9eebSMatthias Springer class ShapeCastOp2DUpCastRewritePattern
86999ef9eebSMatthias Springer : public OpRewritePattern<vector::ShapeCastOp> {
87099ef9eebSMatthias Springer public:
87199ef9eebSMatthias Springer using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
87299ef9eebSMatthias Springer
matchAndRewrite(vector::ShapeCastOp op,PatternRewriter & rewriter) const87399ef9eebSMatthias Springer LogicalResult matchAndRewrite(vector::ShapeCastOp op,
87499ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
87599ef9eebSMatthias Springer auto sourceVectorType = op.getSourceVectorType();
87699ef9eebSMatthias Springer auto resultVectorType = op.getResultVectorType();
87799ef9eebSMatthias Springer if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
87899ef9eebSMatthias Springer return failure();
87999ef9eebSMatthias Springer
88099ef9eebSMatthias Springer auto loc = op.getLoc();
88199ef9eebSMatthias Springer Value desc = rewriter.create<arith::ConstantOp>(
88299ef9eebSMatthias Springer loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
88399ef9eebSMatthias Springer unsigned mostMinorVectorSize = resultVectorType.getShape()[1];
88499ef9eebSMatthias Springer for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) {
88599ef9eebSMatthias Springer Value vec = rewriter.create<vector::ExtractStridedSliceOp>(
8867c38fd60SJacques Pienaar loc, op.getSource(), /*offsets=*/i * mostMinorVectorSize,
88799ef9eebSMatthias Springer /*sizes=*/mostMinorVectorSize,
88899ef9eebSMatthias Springer /*strides=*/1);
88999ef9eebSMatthias Springer desc = rewriter.create<vector::InsertOp>(loc, vec, desc, i);
89099ef9eebSMatthias Springer }
89199ef9eebSMatthias Springer rewriter.replaceOp(op, desc);
89299ef9eebSMatthias Springer return success();
89399ef9eebSMatthias Springer }
89499ef9eebSMatthias Springer };
89599ef9eebSMatthias Springer
89699ef9eebSMatthias Springer // We typically should not lower general shape cast operations into data
89799ef9eebSMatthias Springer // movement instructions, since the assumption is that these casts are
89899ef9eebSMatthias Springer // optimized away during progressive lowering. For completeness, however,
89999ef9eebSMatthias Springer // we fall back to a reference implementation that moves all elements
90099ef9eebSMatthias Springer // into the right place if we get here.
90199ef9eebSMatthias Springer class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
90299ef9eebSMatthias Springer public:
90399ef9eebSMatthias Springer using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
90499ef9eebSMatthias Springer
matchAndRewrite(vector::ShapeCastOp op,PatternRewriter & rewriter) const90599ef9eebSMatthias Springer LogicalResult matchAndRewrite(vector::ShapeCastOp op,
90699ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
90799ef9eebSMatthias Springer Location loc = op.getLoc();
90899ef9eebSMatthias Springer auto sourceVectorType = op.getSourceVectorType();
90999ef9eebSMatthias Springer auto resultVectorType = op.getResultVectorType();
91099ef9eebSMatthias Springer
91199ef9eebSMatthias Springer // Special case 2D/1D lowerings with better implementations.
91299ef9eebSMatthias Springer // TODO: make is ND/1D to allow generic ND->1D->MD.
91399ef9eebSMatthias Springer int64_t srcRank = sourceVectorType.getRank();
91499ef9eebSMatthias Springer int64_t resRank = resultVectorType.getRank();
91599ef9eebSMatthias Springer if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2))
91699ef9eebSMatthias Springer return failure();
91799ef9eebSMatthias Springer
91899ef9eebSMatthias Springer // Generic ShapeCast lowering path goes all the way down to unrolled scalar
91999ef9eebSMatthias Springer // extract/insert chains.
92099ef9eebSMatthias Springer // TODO: consider evolving the semantics to only allow 1D source or dest and
92199ef9eebSMatthias Springer // drop this potentially very expensive lowering.
92299ef9eebSMatthias Springer // Compute number of elements involved in the reshape.
92399ef9eebSMatthias Springer int64_t numElts = 1;
92499ef9eebSMatthias Springer for (int64_t r = 0; r < srcRank; r++)
92599ef9eebSMatthias Springer numElts *= sourceVectorType.getDimSize(r);
92699ef9eebSMatthias Springer // Replace with data movement operations:
92799ef9eebSMatthias Springer // x[0,0,0] = y[0,0]
92899ef9eebSMatthias Springer // x[0,0,1] = y[0,1]
92999ef9eebSMatthias Springer // x[0,1,0] = y[0,2]
93099ef9eebSMatthias Springer // etc., incrementing the two index vectors "row-major"
93199ef9eebSMatthias Springer // within the source and result shape.
93299ef9eebSMatthias Springer SmallVector<int64_t, 4> srcIdx(srcRank);
93399ef9eebSMatthias Springer SmallVector<int64_t, 4> resIdx(resRank);
93499ef9eebSMatthias Springer Value result = rewriter.create<arith::ConstantOp>(
93599ef9eebSMatthias Springer loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
93699ef9eebSMatthias Springer for (int64_t i = 0; i < numElts; i++) {
93799ef9eebSMatthias Springer if (i != 0) {
93899ef9eebSMatthias Springer incIdx(srcIdx, sourceVectorType, srcRank - 1);
93999ef9eebSMatthias Springer incIdx(resIdx, resultVectorType, resRank - 1);
94099ef9eebSMatthias Springer }
9417c38fd60SJacques Pienaar Value e = rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
94299ef9eebSMatthias Springer result = rewriter.create<vector::InsertOp>(loc, e, result, resIdx);
94399ef9eebSMatthias Springer }
94499ef9eebSMatthias Springer rewriter.replaceOp(op, result);
94599ef9eebSMatthias Springer return success();
94699ef9eebSMatthias Springer }
94799ef9eebSMatthias Springer
94899ef9eebSMatthias Springer private:
incIdx(SmallVector<int64_t,4> & idx,VectorType tp,int64_t r)94999ef9eebSMatthias Springer static void incIdx(SmallVector<int64_t, 4> &idx, VectorType tp, int64_t r) {
95099ef9eebSMatthias Springer assert(0 <= r && r < tp.getRank());
95199ef9eebSMatthias Springer if (++idx[r] == tp.getDimSize(r)) {
95299ef9eebSMatthias Springer idx[r] = 0;
95399ef9eebSMatthias Springer incIdx(idx, tp, r - 1);
95499ef9eebSMatthias Springer }
95599ef9eebSMatthias Springer }
95699ef9eebSMatthias Springer };
95799ef9eebSMatthias Springer
95899ef9eebSMatthias Springer /// Convert MulIOp/MulFOp + MultiDimReductionOp<add> into ContractionOp.
95999ef9eebSMatthias Springer /// Ex:
96099ef9eebSMatthias Springer /// ```
96199ef9eebSMatthias Springer /// %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
96299ef9eebSMatthias Springer /// %1 = vector.multi_reduction add, %0 [1]
96399ef9eebSMatthias Springer /// : vector<8x32x16xf32> to vector<8x16xf32>
96499ef9eebSMatthias Springer /// ```
96599ef9eebSMatthias Springer /// Gets converted to:
96699ef9eebSMatthias Springer /// ```
96799ef9eebSMatthias Springer /// %1 = vector.contract {indexing_maps = [
96899ef9eebSMatthias Springer /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
96999ef9eebSMatthias Springer /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
97099ef9eebSMatthias Springer /// affine_map<(d0, d1, d2) -> (d0, d1)>],
97199ef9eebSMatthias Springer /// iterator_types = ["parallel", "parallel", "reduction"],
97299ef9eebSMatthias Springer /// kind = add} %0, %arg1, %cst_f0
97399ef9eebSMatthias Springer /// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
97499ef9eebSMatthias Springer /// ```
97599ef9eebSMatthias Springer struct MultiReduceToContract
97699ef9eebSMatthias Springer : public OpRewritePattern<vector::MultiDimReductionOp> {
97799ef9eebSMatthias Springer using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
97899ef9eebSMatthias Springer
matchAndRewrite__anon5c5a5b800211::MultiReduceToContract97999ef9eebSMatthias Springer LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
98099ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
9817c38fd60SJacques Pienaar if (reduceOp.getKind() != vector::CombiningKind::ADD)
98299ef9eebSMatthias Springer return failure();
9837c38fd60SJacques Pienaar Operation *mulOp = reduceOp.getSource().getDefiningOp();
98499ef9eebSMatthias Springer if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
98599ef9eebSMatthias Springer return failure();
98699ef9eebSMatthias Springer SmallVector<bool> reductionMask = reduceOp.getReductionMask();
98799ef9eebSMatthias Springer auto srcMap = rewriter.getMultiDimIdentityMap(reductionMask.size());
98899ef9eebSMatthias Springer SmallVector<AffineExpr> exprs;
98999ef9eebSMatthias Springer SmallVector<StringRef> iteratorTypes;
99099ef9eebSMatthias Springer for (const auto &isReduceDim : llvm::enumerate(reductionMask)) {
99199ef9eebSMatthias Springer if (!isReduceDim.value()) {
99299ef9eebSMatthias Springer iteratorTypes.push_back(getParallelIteratorTypeName());
99399ef9eebSMatthias Springer exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index()));
99499ef9eebSMatthias Springer } else {
99599ef9eebSMatthias Springer iteratorTypes.push_back(getReductionIteratorTypeName());
99699ef9eebSMatthias Springer }
99799ef9eebSMatthias Springer }
99899ef9eebSMatthias Springer auto dstMap = AffineMap::get(/*dimCount=*/reductionMask.size(),
99999ef9eebSMatthias Springer /*symCount=*/0, exprs, reduceOp.getContext());
100099ef9eebSMatthias Springer rewriter.replaceOpWithNewOp<mlir::vector::ContractionOp>(
1001051b36baSThomas Raoux reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), reduceOp.getAcc(),
100299ef9eebSMatthias Springer rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}),
100399ef9eebSMatthias Springer rewriter.getStrArrayAttr(iteratorTypes));
100499ef9eebSMatthias Springer return success();
100599ef9eebSMatthias Springer }
100699ef9eebSMatthias Springer };
100799ef9eebSMatthias Springer
100899ef9eebSMatthias Springer /// Merge TransposeOp into ContractionOp user.
100999ef9eebSMatthias Springer /// Ex:
101099ef9eebSMatthias Springer /// ```
101199ef9eebSMatthias Springer /// %0 = vector.transpose %arg0, [2, 0, 1]
101299ef9eebSMatthias Springer /// : vector<32x16x8xf32> to vector<8x32x16xf32>
101399ef9eebSMatthias Springer /// %1 = vector.contract {indexing_maps = [
101499ef9eebSMatthias Springer /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
101599ef9eebSMatthias Springer /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
101699ef9eebSMatthias Springer /// affine_map<(d0, d1, d2) -> (d0, d1)>],
101799ef9eebSMatthias Springer /// iterator_types = ["parallel", "parallel", "reduction"],
101899ef9eebSMatthias Springer /// kind = add} %0, %arg1, %cst_f0
101999ef9eebSMatthias Springer /// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
102099ef9eebSMatthias Springer /// ```
102199ef9eebSMatthias Springer /// Gets converted to:
102299ef9eebSMatthias Springer /// ```
102399ef9eebSMatthias Springer /// %1 = vector.contract {indexing_maps = [
102499ef9eebSMatthias Springer /// affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
102599ef9eebSMatthias Springer /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
102699ef9eebSMatthias Springer /// affine_map<(d0, d1, d2) -> (d0, d1)>],
102799ef9eebSMatthias Springer /// iterator_types = ["parallel", "parallel", "reduction"],
102899ef9eebSMatthias Springer /// kind = add} %arg0, %arg1, %cst_f0
102999ef9eebSMatthias Springer /// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
103099ef9eebSMatthias Springer /// ```
103199ef9eebSMatthias Springer struct CombineContractTranspose
103299ef9eebSMatthias Springer : public OpRewritePattern<vector::ContractionOp> {
103399ef9eebSMatthias Springer using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
103499ef9eebSMatthias Springer
matchAndRewrite__anon5c5a5b800211::CombineContractTranspose103599ef9eebSMatthias Springer LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
103699ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
103799ef9eebSMatthias Springer SmallVector<AffineMap, 4> maps =
1038*d2c0572bSJacques Pienaar llvm::to_vector<4>(contractOp.getIndexingMapsArray());
10397c38fd60SJacques Pienaar Value lhs = contractOp.getLhs();
10407c38fd60SJacques Pienaar Value rhs = contractOp.getRhs();
104199ef9eebSMatthias Springer size_t index = 0;
104299ef9eebSMatthias Springer bool changed = false;
104399ef9eebSMatthias Springer for (Value *operand : {&lhs, &rhs}) {
104499ef9eebSMatthias Springer AffineMap &map = maps[index++];
104599ef9eebSMatthias Springer auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
104699ef9eebSMatthias Springer if (!transposeOp)
104799ef9eebSMatthias Springer continue;
104899ef9eebSMatthias Springer SmallVector<int64_t> perm;
104999ef9eebSMatthias Springer transposeOp.getTransp(perm);
105099ef9eebSMatthias Springer AffineMap permutationMap = AffineMap::getPermutationMap(
10517c38fd60SJacques Pienaar extractVector<unsigned>(transposeOp.getTransp()),
105299ef9eebSMatthias Springer contractOp.getContext());
105399ef9eebSMatthias Springer map = inversePermutation(permutationMap).compose(map);
10547c38fd60SJacques Pienaar *operand = transposeOp.getVector();
105599ef9eebSMatthias Springer changed = true;
105699ef9eebSMatthias Springer }
105799ef9eebSMatthias Springer if (!changed)
105899ef9eebSMatthias Springer return failure();
105999ef9eebSMatthias Springer rewriter.replaceOpWithNewOp<vector::ContractionOp>(
10607c38fd60SJacques Pienaar contractOp, lhs, rhs, contractOp.getAcc(),
10617c38fd60SJacques Pienaar rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
106299ef9eebSMatthias Springer return success();
106399ef9eebSMatthias Springer }
106499ef9eebSMatthias Springer };
106599ef9eebSMatthias Springer
106699ef9eebSMatthias Springer /// Merge BroadcastOp into ContractionOp user.
106799ef9eebSMatthias Springer /// Ex:
106899ef9eebSMatthias Springer /// ```
106999ef9eebSMatthias Springer /// %0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32>
107099ef9eebSMatthias Springer /// %1 = vector.contract {indexing_maps = [
107199ef9eebSMatthias Springer /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
107299ef9eebSMatthias Springer /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
107399ef9eebSMatthias Springer /// affine_map<(d0, d1, d2) -> (d0, d1)>],
107499ef9eebSMatthias Springer /// iterator_types = ["parallel", "parallel", "reduction"],
107599ef9eebSMatthias Springer /// kind = add} %0, %arg1, %cst_f0
107699ef9eebSMatthias Springer /// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
107799ef9eebSMatthias Springer /// ```
107899ef9eebSMatthias Springer /// Gets converted to:
107999ef9eebSMatthias Springer /// ```
108099ef9eebSMatthias Springer /// %1 = vector.contract {indexing_maps = [
108199ef9eebSMatthias Springer /// affine_map<(d0, d1, d2) -> (d1, d2)>,
108299ef9eebSMatthias Springer /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
108399ef9eebSMatthias Springer /// affine_map<(d0, d1, d2) -> (d0, d1)>],
108499ef9eebSMatthias Springer /// iterator_types = ["parallel", "parallel", "reduction"],
108599ef9eebSMatthias Springer /// kind = add} %arg0, %arg1, %cst_f0
108699ef9eebSMatthias Springer /// : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
108799ef9eebSMatthias Springer /// ```
108899ef9eebSMatthias Springer struct CombineContractBroadcast
108999ef9eebSMatthias Springer : public OpRewritePattern<vector::ContractionOp> {
109099ef9eebSMatthias Springer using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
109199ef9eebSMatthias Springer
matchAndRewrite__anon5c5a5b800211::CombineContractBroadcast109299ef9eebSMatthias Springer LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
109399ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
109499ef9eebSMatthias Springer SmallVector<AffineMap, 4> maps =
1095*d2c0572bSJacques Pienaar llvm::to_vector<4>(contractOp.getIndexingMapsArray());
10967c38fd60SJacques Pienaar Value lhs = contractOp.getLhs();
10977c38fd60SJacques Pienaar Value rhs = contractOp.getRhs();
109899ef9eebSMatthias Springer size_t index = 0;
109999ef9eebSMatthias Springer bool changed = false;
110099ef9eebSMatthias Springer for (Value *operand : {&lhs, &rhs}) {
110199ef9eebSMatthias Springer AffineMap &map = maps[index++];
110299ef9eebSMatthias Springer auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
110399ef9eebSMatthias Springer if (!broadcast)
110499ef9eebSMatthias Springer continue;
110599ef9eebSMatthias Springer // contractionOp can only take vector as operands.
110699ef9eebSMatthias Springer auto srcType = broadcast.getSourceType().dyn_cast<VectorType>();
110799ef9eebSMatthias Springer if (!srcType || srcType.getRank() == broadcast.getVectorType().getRank())
110899ef9eebSMatthias Springer continue;
110999ef9eebSMatthias Springer int64_t rankDiff =
111099ef9eebSMatthias Springer broadcast.getVectorType().getRank() - srcType.getRank();
111199ef9eebSMatthias Springer bool innerDimBroadcast = false;
111299ef9eebSMatthias Springer SmallVector<AffineExpr> originalDims;
111399ef9eebSMatthias Springer for (const auto &dim : llvm::enumerate(srcType.getShape())) {
111499ef9eebSMatthias Springer if (dim.value() !=
111599ef9eebSMatthias Springer broadcast.getVectorType().getDimSize(rankDiff + dim.index())) {
111699ef9eebSMatthias Springer innerDimBroadcast = true;
111799ef9eebSMatthias Springer break;
111899ef9eebSMatthias Springer }
111999ef9eebSMatthias Springer originalDims.push_back(
112099ef9eebSMatthias Springer rewriter.getAffineDimExpr(dim.index() + rankDiff));
112199ef9eebSMatthias Springer }
112299ef9eebSMatthias Springer // Contract doesn't support inner dimension broadcast. Once this is
112399ef9eebSMatthias Springer // relaxed we can remove this case.
112499ef9eebSMatthias Springer if (innerDimBroadcast)
112599ef9eebSMatthias Springer continue;
1126694ad3eaSBenoit Jacob
1127694ad3eaSBenoit Jacob // It would be incorrect to fold a broadcast onto a reduction dimension
1128694ad3eaSBenoit Jacob // of non-unit size.
1129694ad3eaSBenoit Jacob bool nonUnitDimReductionBroadcast = false;
1130694ad3eaSBenoit Jacob for (int64_t i = 0; i < rankDiff; ++i) {
1131694ad3eaSBenoit Jacob if (broadcast.getVectorType().getDimSize(i) != 1 &&
1132694ad3eaSBenoit Jacob isReductionIterator(contractOp.getIteratorTypes()
1133694ad3eaSBenoit Jacob .getValue()[map.getDimPosition(i)])) {
1134694ad3eaSBenoit Jacob nonUnitDimReductionBroadcast = true;
1135694ad3eaSBenoit Jacob break;
1136694ad3eaSBenoit Jacob }
1137694ad3eaSBenoit Jacob }
1138694ad3eaSBenoit Jacob if (nonUnitDimReductionBroadcast)
1139694ad3eaSBenoit Jacob continue;
1140694ad3eaSBenoit Jacob
114199ef9eebSMatthias Springer AffineMap broadcastMap =
114299ef9eebSMatthias Springer AffineMap::get(broadcast.getVectorType().getRank(), 0, originalDims,
114399ef9eebSMatthias Springer contractOp.getContext());
114499ef9eebSMatthias Springer map = broadcastMap.compose(map);
11457c38fd60SJacques Pienaar *operand = broadcast.getSource();
114699ef9eebSMatthias Springer changed = true;
114799ef9eebSMatthias Springer }
1148694ad3eaSBenoit Jacob
114999ef9eebSMatthias Springer if (!changed)
115099ef9eebSMatthias Springer return failure();
1151694ad3eaSBenoit Jacob
1152694ad3eaSBenoit Jacob // Determine which dims are usused, now that the maps have been composed
1153694ad3eaSBenoit Jacob // with the broadcast maps.
1154c3839c0bSBenoit Jacob llvm::SmallBitVector unusedDimsBitVector = getUnusedDimsBitVector(maps);
1155694ad3eaSBenoit Jacob // Compress unused dims.
1156694ad3eaSBenoit Jacob for (auto &m : maps)
1157c3839c0bSBenoit Jacob m = compressDims(m, unusedDimsBitVector);
1158694ad3eaSBenoit Jacob // Compute the combined iterators.
1159694ad3eaSBenoit Jacob SmallVector<Attribute, 4> iterators;
1160c3839c0bSBenoit Jacob for (unsigned i = 0; i < unusedDimsBitVector.size(); ++i) {
1161c3839c0bSBenoit Jacob if (!unusedDimsBitVector.test(i))
1162694ad3eaSBenoit Jacob iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
1163694ad3eaSBenoit Jacob }
1164f0c3fd18SBenoit Jacob // Check that compressing unused dims isn't removing all reduction dimension
1165f0c3fd18SBenoit Jacob // pairs. For example, if the vector.contract had only one reduction
1166694ad3eaSBenoit Jacob // iterator and that was a unit-dimension created by a broadcast,
1167694ad3eaSBenoit Jacob // then we should bail here, otherwise we would create a contract without
1168f0c3fd18SBenoit Jacob // a reduction dimension pair.
1169f0c3fd18SBenoit Jacob bool hasReductionIteratorApplyingOnBothSides = false;
1170f0c3fd18SBenoit Jacob for (unsigned i = 0; i < iterators.size(); ++i) {
1171f0c3fd18SBenoit Jacob if (!isReductionIterator(iterators[i]))
1172f0c3fd18SBenoit Jacob continue;
1173f0c3fd18SBenoit Jacob if (getResultIndex(maps[0], i) && getResultIndex(maps[1], i)) {
1174f0c3fd18SBenoit Jacob hasReductionIteratorApplyingOnBothSides = true;
1175f0c3fd18SBenoit Jacob break;
1176f0c3fd18SBenoit Jacob }
1177f0c3fd18SBenoit Jacob }
1178f0c3fd18SBenoit Jacob if (!hasReductionIteratorApplyingOnBothSides)
1179694ad3eaSBenoit Jacob return failure();
1180f0c3fd18SBenoit Jacob
1181c3839c0bSBenoit Jacob // If the compressed maps have a dimension that is not used by either LHS or
1182c3839c0bSBenoit Jacob // RHS then the ContractionOp verifier would fail.
1183c3839c0bSBenoit Jacob if (getUnusedDimsBitVector({maps[0], maps[1]}).any())
1184c3839c0bSBenoit Jacob return failure();
118599ef9eebSMatthias Springer rewriter.replaceOpWithNewOp<vector::ContractionOp>(
11867c38fd60SJacques Pienaar contractOp, lhs, rhs, contractOp.getAcc(),
1187694ad3eaSBenoit Jacob rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators));
118899ef9eebSMatthias Springer return success();
118999ef9eebSMatthias Springer }
119099ef9eebSMatthias Springer };
119199ef9eebSMatthias Springer
11921538bd51SHanhan Wang /// Reorders cast(broadcast) to broadcast(cast). This makes broadcast ops and
11931538bd51SHanhan Wang /// contraction ops closer, which kicks in CombineContractBroadcast pattern when
11941538bd51SHanhan Wang /// casting ops are around these operations.
11951538bd51SHanhan Wang /// Ex:
11961538bd51SHanhan Wang /// ```
11971538bd51SHanhan Wang /// %0 = vector.broadcast %arg0 : vector<32x16xi8> to vector<8x32x16xi8>
11981538bd51SHanhan Wang /// %1 = arith.extsi %0 : vector<8x32x16xi8> to vector<8x32x16xi32>
11991538bd51SHanhan Wang /// ```
12001538bd51SHanhan Wang /// Gets converted to:
12011538bd51SHanhan Wang /// ```
12021538bd51SHanhan Wang /// %0 = arith.extsi %0 : vector<32x16xi8> to vector<32x16xi32>
12031538bd51SHanhan Wang /// %1 = vector.broadcast %arg0 : vector<32x16xi32> to vector<8x32x16xi32>
12041538bd51SHanhan Wang /// ```
12051538bd51SHanhan Wang struct ReorderCastOpsOnBroadcast
12061538bd51SHanhan Wang : public OpInterfaceRewritePattern<CastOpInterface> {
12071538bd51SHanhan Wang using OpInterfaceRewritePattern<CastOpInterface>::OpInterfaceRewritePattern;
12081538bd51SHanhan Wang
matchAndRewrite__anon5c5a5b800211::ReorderCastOpsOnBroadcast12091538bd51SHanhan Wang LogicalResult matchAndRewrite(CastOpInterface op,
12101538bd51SHanhan Wang PatternRewriter &rewriter) const override {
12111538bd51SHanhan Wang if (op->getNumOperands() != 1)
12121538bd51SHanhan Wang return failure();
12131538bd51SHanhan Wang auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
12141538bd51SHanhan Wang if (!bcastOp)
12151538bd51SHanhan Wang return failure();
12161538bd51SHanhan Wang
12171538bd51SHanhan Wang Type castResTy = getElementTypeOrSelf(op->getResult(0));
12181538bd51SHanhan Wang if (auto vecTy = bcastOp.getSourceType().dyn_cast<VectorType>())
12191538bd51SHanhan Wang castResTy = VectorType::get(vecTy.getShape(), castResTy);
122035f48edbSMehdi Amini auto *castOp =
12217c38fd60SJacques Pienaar rewriter.create(op->getLoc(), op->getName().getIdentifier(),
12227c38fd60SJacques Pienaar bcastOp.getSource(), castResTy, op->getAttrs());
12231538bd51SHanhan Wang rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
12241538bd51SHanhan Wang op, op->getResult(0).getType(), castOp->getResult(0));
12251538bd51SHanhan Wang return success();
12261538bd51SHanhan Wang }
12271538bd51SHanhan Wang };
12281538bd51SHanhan Wang
12294db65e27SLei Zhang /// Reorders elementwise(transpose) to transpose(elementwise). This makes
12304db65e27SLei Zhang /// transpose ops and contraction ops closer, which kicks in
12314db65e27SLei Zhang /// CombineContractTranspose pattern when elementwise ops are between these
12324db65e27SLei Zhang /// operations. Ex:
12331538bd51SHanhan Wang /// ```
12344db65e27SLei Zhang /// %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
12354db65e27SLei Zhang /// %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
12364db65e27SLei Zhang /// %r = arith.addf %at, %bt : vector<2x4xf32>
12371538bd51SHanhan Wang /// ```
12381538bd51SHanhan Wang /// Gets converted to:
12391538bd51SHanhan Wang /// ```
12404db65e27SLei Zhang /// %0 = arith.addf %a, %b : vector<4x2xf32>
12414db65e27SLei Zhang /// %r = vector.transpose %0, [1, 0] : vector<2x4xf32>
12421538bd51SHanhan Wang /// ```
12434db65e27SLei Zhang struct ReorderElementwiseOpsOnTranspose final
12444db65e27SLei Zhang : public OpTraitRewritePattern<OpTrait::Elementwise> {
12454db65e27SLei Zhang using OpTraitRewritePattern::OpTraitRewritePattern;
matchAndRewrite__anon5c5a5b800211::ReorderElementwiseOpsOnTranspose12464db65e27SLei Zhang LogicalResult matchAndRewrite(Operation *op,
12471538bd51SHanhan Wang PatternRewriter &rewriter) const override {
12484db65e27SLei Zhang if (op->getNumResults() != 1 || op->getNumRegions() != 0)
12491538bd51SHanhan Wang return failure();
12501538bd51SHanhan Wang
12514db65e27SLei Zhang // Make sure all operands are transpose/constant ops and collect their
12524db65e27SLei Zhang // transposition maps.
12534db65e27SLei Zhang SmallVector<ArrayAttr, 4> transposeMaps;
12544db65e27SLei Zhang transposeMaps.reserve(op->getNumOperands());
12554db65e27SLei Zhang // Record the initial type before transposition. We'll use its shape later.
12564db65e27SLei Zhang // Any type will do here as we will check all transpose maps are the same.
12574db65e27SLei Zhang VectorType srcType;
12584db65e27SLei Zhang for (Value operand : op->getOperands()) {
12594db65e27SLei Zhang auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
12604db65e27SLei Zhang if (transposeOp) {
12614db65e27SLei Zhang transposeMaps.push_back(transposeOp.getTransp());
12624db65e27SLei Zhang srcType = transposeOp.getVectorType();
12634db65e27SLei Zhang } else if (!matchPattern(operand, m_Constant())) {
12644db65e27SLei Zhang return failure();
12654db65e27SLei Zhang }
12664db65e27SLei Zhang }
12674db65e27SLei Zhang if (transposeMaps.empty())
12684db65e27SLei Zhang return failure();
12694db65e27SLei Zhang // This is an elementwise op, so all transposed operands should have the
12704db65e27SLei Zhang // same type. We need to additionally check that all transposes uses the
12714db65e27SLei Zhang // same map.
12724db65e27SLei Zhang if (!llvm::is_splat(transposeMaps))
12734db65e27SLei Zhang return rewriter.notifyMatchFailure(op, "different transpose map");
12744db65e27SLei Zhang
12754db65e27SLei Zhang SmallVector<Value, 4> srcValues;
12764db65e27SLei Zhang srcValues.reserve(op->getNumOperands());
12774db65e27SLei Zhang
12784db65e27SLei Zhang // If there are constant operands, we need to insert inverse transposes for
12794db65e27SLei Zhang // them. Calculate the inverse order first.
12804db65e27SLei Zhang auto order = extractVector<unsigned>(transposeMaps.front());
12814db65e27SLei Zhang SmallVector<int64_t> invOrder(order.size());
12824db65e27SLei Zhang for (int i = 0, e = order.size(); i < e; ++i)
12834db65e27SLei Zhang invOrder[order[i]] = i;
12844db65e27SLei Zhang
12854db65e27SLei Zhang for (Value operand : op->getOperands()) {
12864db65e27SLei Zhang auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
12874db65e27SLei Zhang if (transposeOp) {
12884db65e27SLei Zhang srcValues.push_back(transposeOp.getVector());
12894db65e27SLei Zhang } else {
12904db65e27SLei Zhang // This is a constant. Create a reverse transpose op for it.
12914db65e27SLei Zhang auto vectorType = VectorType::get(
12924db65e27SLei Zhang srcType.getShape(),
12934db65e27SLei Zhang operand.getType().cast<VectorType>().getElementType());
12944db65e27SLei Zhang srcValues.push_back(rewriter.create<vector::TransposeOp>(
12954db65e27SLei Zhang operand.getLoc(), vectorType, operand,
12964db65e27SLei Zhang rewriter.getI64ArrayAttr(invOrder)));
12974db65e27SLei Zhang }
12984db65e27SLei Zhang }
12994db65e27SLei Zhang
13004db65e27SLei Zhang auto vectorType = VectorType::get(
13014db65e27SLei Zhang srcType.getShape(),
13024db65e27SLei Zhang op->getResultTypes()[0].cast<VectorType>().getElementType());
13034db65e27SLei Zhang Operation *elementwiseOp =
13044db65e27SLei Zhang rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
13054db65e27SLei Zhang vectorType, op->getAttrs());
13061538bd51SHanhan Wang rewriter.replaceOpWithNewOp<vector::TransposeOp>(
13074db65e27SLei Zhang op, op->getResultTypes()[0], elementwiseOp->getResult(0),
13084db65e27SLei Zhang transposeMaps.front());
13091538bd51SHanhan Wang return success();
13101538bd51SHanhan Wang }
13111538bd51SHanhan Wang };
13121538bd51SHanhan Wang
131399ef9eebSMatthias Springer } // namespace
131499ef9eebSMatthias Springer
131599ef9eebSMatthias Springer /// Creates an AddIOp if `isInt` is true otherwise create an arith::AddFOp using
131699ef9eebSMatthias Springer /// operands `x` and `y`.
createAdd(Location loc,Value x,Value y,bool isInt,PatternRewriter & rewriter)131799ef9eebSMatthias Springer static Value createAdd(Location loc, Value x, Value y, bool isInt,
131899ef9eebSMatthias Springer PatternRewriter &rewriter) {
131999ef9eebSMatthias Springer if (isInt)
132099ef9eebSMatthias Springer return rewriter.create<arith::AddIOp>(loc, x, y);
132199ef9eebSMatthias Springer return rewriter.create<arith::AddFOp>(loc, x, y);
132299ef9eebSMatthias Springer }
132399ef9eebSMatthias Springer
132499ef9eebSMatthias Springer /// Creates a MulIOp if `isInt` is true otherwise create an MulFOp using
132599ef9eebSMatthias Springer /// operands `x and `y`.
createMul(Location loc,Value x,Value y,bool isInt,PatternRewriter & rewriter)132699ef9eebSMatthias Springer static Value createMul(Location loc, Value x, Value y, bool isInt,
132799ef9eebSMatthias Springer PatternRewriter &rewriter) {
132899ef9eebSMatthias Springer if (isInt)
132999ef9eebSMatthias Springer return rewriter.create<arith::MulIOp>(loc, x, y);
133099ef9eebSMatthias Springer return rewriter.create<arith::MulFOp>(loc, x, y);
133199ef9eebSMatthias Springer }
133299ef9eebSMatthias Springer
133399ef9eebSMatthias Springer namespace mlir {
133499ef9eebSMatthias Springer
133599ef9eebSMatthias Springer /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
133699ef9eebSMatthias Springer /// semantics to:
133799ef9eebSMatthias Springer /// ```
133899ef9eebSMatthias Springer /// %mta = maybe_transpose
133999ef9eebSMatthias Springer /// %mtb = maybe_transpose
134099ef9eebSMatthias Springer /// %flattened_a = vector.shape_cast %mta
134199ef9eebSMatthias Springer /// %flattened_b = vector.shape_cast %mtb
134299ef9eebSMatthias Springer /// %flattened_d = vector.matmul %flattened_a, %flattened_b
134399ef9eebSMatthias Springer /// %mtd = vector.shape_cast %flattened_d
134499ef9eebSMatthias Springer /// %d = maybe_untranspose %mtd
134599ef9eebSMatthias Springer /// %e = add %c, %d
134699ef9eebSMatthias Springer /// ```
134799ef9eebSMatthias Springer /// `vector.matmul` later lowers to `llvm.matrix.multiply`.
134899ef9eebSMatthias Springer //
134999ef9eebSMatthias Springer /// This only kicks in when VectorTransformsOptions is set to `Matmul`.
135099ef9eebSMatthias Springer /// vector.transpose operations are inserted if the vector.contract op is not a
135199ef9eebSMatthias Springer /// row-major matrix multiply.
135299ef9eebSMatthias Springer LogicalResult
matchAndRewrite(vector::ContractionOp op,PatternRewriter & rew) const135399ef9eebSMatthias Springer ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
135499ef9eebSMatthias Springer PatternRewriter &rew) const {
135599ef9eebSMatthias Springer // TODO: implement masks
13567c38fd60SJacques Pienaar if (llvm::size(op.getMasks()) != 0)
135799ef9eebSMatthias Springer return failure();
135899ef9eebSMatthias Springer if (vectorTransformOptions.vectorContractLowering !=
135999ef9eebSMatthias Springer vector::VectorContractLowering::Matmul)
136099ef9eebSMatthias Springer return failure();
136199ef9eebSMatthias Springer if (failed(filter(op)))
136299ef9eebSMatthias Springer return failure();
136399ef9eebSMatthias Springer
13647c38fd60SJacques Pienaar auto iteratorTypes = op.getIteratorTypes().getValue();
136599ef9eebSMatthias Springer if (!isParallelIterator(iteratorTypes[0]) ||
136699ef9eebSMatthias Springer !isParallelIterator(iteratorTypes[1]) ||
136799ef9eebSMatthias Springer !isReductionIterator(iteratorTypes[2]))
136899ef9eebSMatthias Springer return failure();
136999ef9eebSMatthias Springer
137099ef9eebSMatthias Springer Type elementType = op.getLhsType().getElementType();
137199ef9eebSMatthias Springer if (!elementType.isIntOrFloat())
137299ef9eebSMatthias Springer return failure();
137399ef9eebSMatthias Springer
1374f011d32cSThomas Raoux Type dstElementType = op.getType();
1375f011d32cSThomas Raoux if (auto vecType = dstElementType.dyn_cast<VectorType>())
1376f011d32cSThomas Raoux dstElementType = vecType.getElementType();
1377f011d32cSThomas Raoux if (elementType != dstElementType)
1378f011d32cSThomas Raoux return failure();
1379f011d32cSThomas Raoux
138099ef9eebSMatthias Springer // Perform lhs + rhs transpositions to conform to matmul row-major semantics.
138199ef9eebSMatthias Springer // Bail out if the contraction cannot be put in this form.
138299ef9eebSMatthias Springer MLIRContext *ctx = op.getContext();
138399ef9eebSMatthias Springer Location loc = op.getLoc();
138499ef9eebSMatthias Springer AffineExpr m, n, k;
138599ef9eebSMatthias Springer bindDims(rew.getContext(), m, n, k);
138699ef9eebSMatthias Springer // LHS must be A(m, k) or A(k, m).
13877c38fd60SJacques Pienaar Value lhs = op.getLhs();
1388*d2c0572bSJacques Pienaar auto lhsMap = op.getIndexingMapsArray()[0];
138999ef9eebSMatthias Springer if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx))
139099ef9eebSMatthias Springer lhs = rew.create<vector::TransposeOp>(loc, lhs, ArrayRef<int64_t>{1, 0});
139199ef9eebSMatthias Springer else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx))
139299ef9eebSMatthias Springer return failure();
139399ef9eebSMatthias Springer
139499ef9eebSMatthias Springer // RHS must be B(k, n) or B(n, k).
13957c38fd60SJacques Pienaar Value rhs = op.getRhs();
1396*d2c0572bSJacques Pienaar auto rhsMap = op.getIndexingMapsArray()[1];
139799ef9eebSMatthias Springer if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx))
139899ef9eebSMatthias Springer rhs = rew.create<vector::TransposeOp>(loc, rhs, ArrayRef<int64_t>{1, 0});
139999ef9eebSMatthias Springer else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx))
140099ef9eebSMatthias Springer return failure();
140199ef9eebSMatthias Springer
140299ef9eebSMatthias Springer // At this point lhs and rhs are in row-major.
140399ef9eebSMatthias Springer VectorType lhsType = lhs.getType().cast<VectorType>();
140499ef9eebSMatthias Springer VectorType rhsType = rhs.getType().cast<VectorType>();
140599ef9eebSMatthias Springer int64_t lhsRows = lhsType.getDimSize(0);
140699ef9eebSMatthias Springer int64_t lhsColumns = lhsType.getDimSize(1);
140799ef9eebSMatthias Springer int64_t rhsColumns = rhsType.getDimSize(1);
140899ef9eebSMatthias Springer
140999ef9eebSMatthias Springer Type flattenedLHSType =
141099ef9eebSMatthias Springer VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
141199ef9eebSMatthias Springer lhs = rew.create<vector::ShapeCastOp>(loc, flattenedLHSType, lhs);
141299ef9eebSMatthias Springer
141399ef9eebSMatthias Springer Type flattenedRHSType =
141499ef9eebSMatthias Springer VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
141599ef9eebSMatthias Springer rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);
141699ef9eebSMatthias Springer
141799ef9eebSMatthias Springer Value mul = rew.create<vector::MatmulOp>(loc, lhs, rhs, lhsRows, lhsColumns,
141899ef9eebSMatthias Springer rhsColumns);
141999ef9eebSMatthias Springer mul = rew.create<vector::ShapeCastOp>(
142099ef9eebSMatthias Springer loc,
142199ef9eebSMatthias Springer VectorType::get({lhsRows, rhsColumns},
14227c38fd60SJacques Pienaar getElementTypeOrSelf(op.getAcc().getType())),
142399ef9eebSMatthias Springer mul);
142499ef9eebSMatthias Springer
142599ef9eebSMatthias Springer // ACC must be C(m, n) or C(n, m).
1426*d2c0572bSJacques Pienaar auto accMap = op.getIndexingMapsArray()[2];
142799ef9eebSMatthias Springer if (accMap == AffineMap::get(3, 0, {n, m}, ctx))
142899ef9eebSMatthias Springer mul = rew.create<vector::TransposeOp>(loc, mul, ArrayRef<int64_t>{1, 0});
142999ef9eebSMatthias Springer else if (accMap != AffineMap::get(3, 0, {m, n}, ctx))
143099ef9eebSMatthias Springer llvm_unreachable("invalid contraction semantics");
143199ef9eebSMatthias Springer
143299ef9eebSMatthias Springer Value res =
143399ef9eebSMatthias Springer elementType.isa<IntegerType>()
14347c38fd60SJacques Pienaar ? static_cast<Value>(rew.create<arith::AddIOp>(loc, op.getAcc(), mul))
14357c38fd60SJacques Pienaar : static_cast<Value>(
14367c38fd60SJacques Pienaar rew.create<arith::AddFOp>(loc, op.getAcc(), mul));
143799ef9eebSMatthias Springer
143899ef9eebSMatthias Springer rew.replaceOp(op, res);
143999ef9eebSMatthias Springer return success();
144099ef9eebSMatthias Springer }
144199ef9eebSMatthias Springer
144299ef9eebSMatthias Springer namespace {
144399ef9eebSMatthias Springer struct IteratorType {
IteratorTypemlir::__anon5c5a5b800311::IteratorType144499ef9eebSMatthias Springer IteratorType(StringRef strRef) : strRef(strRef) {}
isOfTypemlir::__anon5c5a5b800311::IteratorType144599ef9eebSMatthias Springer bool isOfType(Attribute attr) const {
144699ef9eebSMatthias Springer auto sAttr = attr.dyn_cast<StringAttr>();
144799ef9eebSMatthias Springer return sAttr && sAttr.getValue() == strRef;
144899ef9eebSMatthias Springer }
144999ef9eebSMatthias Springer StringRef strRef;
145099ef9eebSMatthias Springer };
145199ef9eebSMatthias Springer struct Par : public IteratorType {
Parmlir::__anon5c5a5b800311::Par145299ef9eebSMatthias Springer Par() : IteratorType(getParallelIteratorTypeName()) {}
145399ef9eebSMatthias Springer };
145499ef9eebSMatthias Springer struct Red : public IteratorType {
Redmlir::__anon5c5a5b800311::Red145599ef9eebSMatthias Springer Red() : IteratorType(getReductionIteratorTypeName()) {}
145699ef9eebSMatthias Springer };
145799ef9eebSMatthias Springer
145899ef9eebSMatthias Springer /// Generate a vector implementation for matmat, matvec and tmatvec.
145999ef9eebSMatthias Springer /// This unrolls outer-products along the reduction dimension.
146099ef9eebSMatthias Springer struct UnrolledOuterProductGenerator
146199ef9eebSMatthias Springer : public StructuredGenerator<vector::ContractionOp> {
UnrolledOuterProductGeneratormlir::__anon5c5a5b800311::UnrolledOuterProductGenerator146299ef9eebSMatthias Springer UnrolledOuterProductGenerator(OpBuilder &builder, vector::ContractionOp op)
146399ef9eebSMatthias Springer : StructuredGenerator<vector::ContractionOp>(builder, op),
14647c38fd60SJacques Pienaar kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()),
14657c38fd60SJacques Pienaar res(op.getAcc()), lhsType(op.getLhsType()) {}
146699ef9eebSMatthias Springer
tmlir::__anon5c5a5b800311::UnrolledOuterProductGenerator146799ef9eebSMatthias Springer Value t(Value v) {
146899ef9eebSMatthias Springer static constexpr std::array<int64_t, 2> perm = {1, 0};
146999ef9eebSMatthias Springer return builder.create<vector::TransposeOp>(loc, v, perm);
147099ef9eebSMatthias Springer }
147199ef9eebSMatthias Springer
promotemlir::__anon5c5a5b800311::UnrolledOuterProductGenerator1472f011d32cSThomas Raoux Value promote(Value v, Type dstElementType) {
1473f011d32cSThomas Raoux Type elementType = v.getType();
1474f011d32cSThomas Raoux auto vecType = elementType.dyn_cast<VectorType>();
1475f011d32cSThomas Raoux if (vecType)
1476f011d32cSThomas Raoux elementType = vecType.getElementType();
1477f011d32cSThomas Raoux if (elementType == dstElementType)
1478f011d32cSThomas Raoux return v;
1479f011d32cSThomas Raoux Type promotedType = dstElementType;
1480f011d32cSThomas Raoux if (vecType)
1481f011d32cSThomas Raoux promotedType = VectorType::get(vecType.getShape(), promotedType);
1482f011d32cSThomas Raoux if (dstElementType.isa<FloatType>())
1483f011d32cSThomas Raoux return builder.create<arith::ExtFOp>(loc, promotedType, v);
1484f011d32cSThomas Raoux return builder.create<arith::ExtSIOp>(loc, promotedType, v);
1485f011d32cSThomas Raoux }
1486f011d32cSThomas Raoux
outerProdmlir::__anon5c5a5b800311::UnrolledOuterProductGenerator148799ef9eebSMatthias Springer Value outerProd(Value lhs, Value rhs, Value res, int reductionSize) {
148899ef9eebSMatthias Springer assert(reductionSize > 0);
1489f011d32cSThomas Raoux Type resElementType = res.getType().cast<VectorType>().getElementType();
149099ef9eebSMatthias Springer for (int64_t k = 0; k < reductionSize; ++k) {
149199ef9eebSMatthias Springer Value a = builder.create<vector::ExtractOp>(loc, lhs, k);
149299ef9eebSMatthias Springer Value b = builder.create<vector::ExtractOp>(loc, rhs, k);
1493f011d32cSThomas Raoux a = promote(a, resElementType);
1494f011d32cSThomas Raoux b = promote(b, resElementType);
149599ef9eebSMatthias Springer res = builder.create<vector::OuterProductOp>(loc, res.getType(), a, b,
149699ef9eebSMatthias Springer res, kind);
149799ef9eebSMatthias Springer }
149899ef9eebSMatthias Springer return res;
149999ef9eebSMatthias Springer }
150099ef9eebSMatthias Springer
150199ef9eebSMatthias Springer /// Two outer parallel, one inner reduction (matmat flavor).
matmatmlir::__anon5c5a5b800311::UnrolledOuterProductGenerator150299ef9eebSMatthias Springer FailureOr<Value> matmat() {
150399ef9eebSMatthias Springer if (!iters({Par(), Par(), Red()}))
150499ef9eebSMatthias Springer return failure();
150599ef9eebSMatthias Springer // Set up the parallel/reduction structure in the right form.
150699ef9eebSMatthias Springer AffineExpr m, n, k;
150799ef9eebSMatthias Springer bindDims(builder.getContext(), m, n, k);
150899ef9eebSMatthias Springer // Classical row-major matmul: Just permute the lhs.
150999ef9eebSMatthias Springer if (layout({{m, k}, {k, n}, {m, n}}))
151099ef9eebSMatthias Springer return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
151199ef9eebSMatthias Springer // TODO: may be better to fail and use some vector<k> -> scalar reduction.
151299ef9eebSMatthias Springer if (layout({{m, k}, {n, k}, {m, n}})) {
151399ef9eebSMatthias Springer Value tlhs = t(lhs);
151499ef9eebSMatthias Springer return outerProd(tlhs, t(rhs), res, lhsType.getDimSize(1));
151599ef9eebSMatthias Springer }
151699ef9eebSMatthias Springer // No need to permute anything.
151799ef9eebSMatthias Springer if (layout({{k, m}, {k, n}, {m, n}}))
151899ef9eebSMatthias Springer return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
151999ef9eebSMatthias Springer // Just permute the rhs.
152099ef9eebSMatthias Springer if (layout({{k, m}, {n, k}, {m, n}}))
152199ef9eebSMatthias Springer return outerProd(lhs, t(rhs), res, lhsType.getDimSize(0));
152299ef9eebSMatthias Springer // Transposed output: swap RHS and LHS.
152399ef9eebSMatthias Springer // Classical row-major matmul: permute the lhs.
152499ef9eebSMatthias Springer if (layout({{m, k}, {k, n}, {n, m}}))
152599ef9eebSMatthias Springer return outerProd(rhs, t(lhs), res, lhsType.getDimSize(1));
152699ef9eebSMatthias Springer // TODO: may be better to fail and use some vector<k> -> scalar reduction.
152799ef9eebSMatthias Springer if (layout({{m, k}, {n, k}, {n, m}})) {
152899ef9eebSMatthias Springer Value trhs = t(rhs);
152999ef9eebSMatthias Springer return outerProd(trhs, t(lhs), res, lhsType.getDimSize(1));
153099ef9eebSMatthias Springer }
153199ef9eebSMatthias Springer if (layout({{k, m}, {k, n}, {n, m}}))
153299ef9eebSMatthias Springer return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
153399ef9eebSMatthias Springer if (layout({{k, m}, {n, k}, {n, m}}))
153499ef9eebSMatthias Springer return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
153599ef9eebSMatthias Springer return failure();
153699ef9eebSMatthias Springer }
153799ef9eebSMatthias Springer
153899ef9eebSMatthias Springer /// One outer parallel, one inner reduction (matvec flavor)
matvecmlir::__anon5c5a5b800311::UnrolledOuterProductGenerator153999ef9eebSMatthias Springer FailureOr<Value> matvec() {
154099ef9eebSMatthias Springer if (!iters({Par(), Red()}))
154199ef9eebSMatthias Springer return failure();
154299ef9eebSMatthias Springer AffineExpr m, k;
154399ef9eebSMatthias Springer bindDims(builder.getContext(), m, k);
154499ef9eebSMatthias Springer
154599ef9eebSMatthias Springer // Case mat-vec: transpose.
154699ef9eebSMatthias Springer if (layout({{m, k}, {k}, {m}}))
154799ef9eebSMatthias Springer return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
154899ef9eebSMatthias Springer // Case mat-trans-vec: ready to go.
154999ef9eebSMatthias Springer if (layout({{k, m}, {k}, {m}}))
155099ef9eebSMatthias Springer return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
155199ef9eebSMatthias Springer // Case vec-mat: swap and transpose.
155299ef9eebSMatthias Springer if (layout({{k}, {m, k}, {m}}))
155399ef9eebSMatthias Springer return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
155499ef9eebSMatthias Springer // Case vec-mat-trans: swap and ready to go.
155599ef9eebSMatthias Springer if (layout({{k}, {k, m}, {m}}))
155699ef9eebSMatthias Springer return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
155799ef9eebSMatthias Springer return failure();
155899ef9eebSMatthias Springer }
155999ef9eebSMatthias Springer
156099ef9eebSMatthias Springer //
156199ef9eebSMatthias Springer // One outer reduction, one inner parallel (tmatvec flavor)
156299ef9eebSMatthias Springer //
tmatvecmlir::__anon5c5a5b800311::UnrolledOuterProductGenerator156399ef9eebSMatthias Springer FailureOr<Value> tmatvec() {
156499ef9eebSMatthias Springer if (!iters({Red(), Par()}))
156599ef9eebSMatthias Springer return failure();
156699ef9eebSMatthias Springer AffineExpr k, m;
156799ef9eebSMatthias Springer bindDims(builder.getContext(), k, m);
156899ef9eebSMatthias Springer
156999ef9eebSMatthias Springer // Case mat-vec: transpose.
157099ef9eebSMatthias Springer if (layout({{m, k}, {k}, {m}}))
157199ef9eebSMatthias Springer return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
157299ef9eebSMatthias Springer // Case mat-trans-vec: ready to go.
157399ef9eebSMatthias Springer if (layout({{k, m}, {k}, {m}}))
157499ef9eebSMatthias Springer return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
157599ef9eebSMatthias Springer // Case vec-mat: swap and transpose.
157699ef9eebSMatthias Springer if (layout({{k}, {m, k}, {m}}))
157799ef9eebSMatthias Springer return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
157899ef9eebSMatthias Springer // Case vec-mat-trans: swap and ready to go.
157999ef9eebSMatthias Springer if (layout({{k}, {k, m}, {m}}))
158099ef9eebSMatthias Springer return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
158199ef9eebSMatthias Springer return failure();
158299ef9eebSMatthias Springer }
158399ef9eebSMatthias Springer
158499ef9eebSMatthias Springer private:
158599ef9eebSMatthias Springer vector::CombiningKind kind;
158699ef9eebSMatthias Springer Value lhs, rhs, res;
158799ef9eebSMatthias Springer VectorType lhsType;
158899ef9eebSMatthias Springer };
158999ef9eebSMatthias Springer } // namespace
159099ef9eebSMatthias Springer
159199ef9eebSMatthias Springer /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
159299ef9eebSMatthias Springer /// semantics to a reduction_size-unrolled sequence:
159399ef9eebSMatthias Springer /// ```
159499ef9eebSMatthias Springer /// %at = vector.transpose %a, [1, 0]
159599ef9eebSMatthias Springer /// %bRow0 = vector.extract %b[0]
159699ef9eebSMatthias Springer /// %atRow0 = vector.extract %at[0]
159799ef9eebSMatthias Springer /// %c0 = vector.outerproduct %atRow0, %bRow0, %c
159899ef9eebSMatthias Springer /// ...
159999ef9eebSMatthias Springer /// %bRowK = vector.extract %b[K]
160099ef9eebSMatthias Springer /// %atRowK = vector.extract %at[K]
160199ef9eebSMatthias Springer /// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
160299ef9eebSMatthias Springer /// ```
160399ef9eebSMatthias Springer ///
160499ef9eebSMatthias Springer /// This only kicks in when VectorTransformsOptions is set to OuterProduct but
160599ef9eebSMatthias Springer /// otherwise supports any layout permutation of the matrix-multiply.
matchAndRewrite(vector::ContractionOp op,PatternRewriter & rewriter) const160699ef9eebSMatthias Springer LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
160799ef9eebSMatthias Springer vector::ContractionOp op, PatternRewriter &rewriter) const {
160899ef9eebSMatthias Springer // TODO: implement masks
16097c38fd60SJacques Pienaar if (llvm::size(op.getMasks()) != 0)
161099ef9eebSMatthias Springer return failure();
161199ef9eebSMatthias Springer
161299ef9eebSMatthias Springer if (vectorTransformOptions.vectorContractLowering !=
161399ef9eebSMatthias Springer vector::VectorContractLowering::OuterProduct)
161499ef9eebSMatthias Springer return failure();
161599ef9eebSMatthias Springer
161699ef9eebSMatthias Springer if (failed(filter(op)))
161799ef9eebSMatthias Springer return failure();
161899ef9eebSMatthias Springer
161999ef9eebSMatthias Springer UnrolledOuterProductGenerator e(rewriter, op);
162099ef9eebSMatthias Springer FailureOr<Value> matmatRes = e.matmat();
162199ef9eebSMatthias Springer if (succeeded(matmatRes)) {
162299ef9eebSMatthias Springer rewriter.replaceOp(op, *matmatRes);
162399ef9eebSMatthias Springer return success();
162499ef9eebSMatthias Springer }
162599ef9eebSMatthias Springer FailureOr<Value> matvecRes = e.matvec();
162699ef9eebSMatthias Springer if (succeeded(matvecRes)) {
162799ef9eebSMatthias Springer rewriter.replaceOp(op, *matvecRes);
162899ef9eebSMatthias Springer return success();
162999ef9eebSMatthias Springer }
163099ef9eebSMatthias Springer FailureOr<Value> tmatvecRes = e.tmatvec();
163199ef9eebSMatthias Springer if (succeeded(tmatvecRes)) {
163299ef9eebSMatthias Springer rewriter.replaceOp(op, *tmatvecRes);
163399ef9eebSMatthias Springer return success();
163499ef9eebSMatthias Springer }
163599ef9eebSMatthias Springer
163699ef9eebSMatthias Springer return failure();
163799ef9eebSMatthias Springer }
163899ef9eebSMatthias Springer
163999ef9eebSMatthias Springer LogicalResult
matchAndRewrite(vector::ContractionOp op,PatternRewriter & rewriter) const164099ef9eebSMatthias Springer ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
164199ef9eebSMatthias Springer PatternRewriter &rewriter) const {
164299ef9eebSMatthias Springer // TODO: implement masks
16437c38fd60SJacques Pienaar if (llvm::size(op.getMasks()) != 0)
164499ef9eebSMatthias Springer return failure();
164599ef9eebSMatthias Springer
164699ef9eebSMatthias Springer if (failed(filter(op)))
164799ef9eebSMatthias Springer return failure();
164899ef9eebSMatthias Springer
164999ef9eebSMatthias Springer if (vectorTransformOptions.vectorContractLowering !=
165099ef9eebSMatthias Springer vector::VectorContractLowering::Dot)
165199ef9eebSMatthias Springer return failure();
165299ef9eebSMatthias Springer
16537c38fd60SJacques Pienaar auto iteratorTypes = op.getIteratorTypes().getValue();
165499ef9eebSMatthias Springer static constexpr std::array<int64_t, 2> perm = {1, 0};
165599ef9eebSMatthias Springer Location loc = op.getLoc();
16567c38fd60SJacques Pienaar Value lhs = op.getLhs(), rhs = op.getRhs();
165799ef9eebSMatthias Springer
165899ef9eebSMatthias Springer using MapList = ArrayRef<ArrayRef<AffineExpr>>;
165999ef9eebSMatthias Springer auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
166099ef9eebSMatthias Springer AffineExpr m, n, k;
166199ef9eebSMatthias Springer bindDims(rewriter.getContext(), m, n, k);
1662*d2c0572bSJacques Pienaar SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
166399ef9eebSMatthias Springer //
166499ef9eebSMatthias Springer // In the following we wish to make the reduction dimension innermost so we
166599ef9eebSMatthias Springer // can load vectors and just fmul + reduce into a scalar.
166699ef9eebSMatthias Springer //
166799ef9eebSMatthias Springer if (isParallelIterator(iteratorTypes[0]) &&
166899ef9eebSMatthias Springer isParallelIterator(iteratorTypes[1]) &&
166999ef9eebSMatthias Springer isReductionIterator(iteratorTypes[2])) {
167099ef9eebSMatthias Springer //
167199ef9eebSMatthias Springer // Two outer parallel, one inner reduction (matmat flavor).
167299ef9eebSMatthias Springer //
167399ef9eebSMatthias Springer if (maps == infer({{m, k}, {k, n}, {m, n}})) {
167499ef9eebSMatthias Springer rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
167599ef9eebSMatthias Springer } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
167699ef9eebSMatthias Springer // No need to permute anything.
167799ef9eebSMatthias Springer } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
167899ef9eebSMatthias Springer lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
167999ef9eebSMatthias Springer rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
168099ef9eebSMatthias Springer } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
168199ef9eebSMatthias Springer lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
168299ef9eebSMatthias Springer } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
168399ef9eebSMatthias Springer // This is the classical row-major matmul. Just permute the lhs.
168499ef9eebSMatthias Springer Value tmp = lhs;
168599ef9eebSMatthias Springer lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
168699ef9eebSMatthias Springer rhs = tmp;
168799ef9eebSMatthias Springer } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
168899ef9eebSMatthias Springer std::swap(lhs, rhs);
168999ef9eebSMatthias Springer } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
169099ef9eebSMatthias Springer Value tmp = lhs;
169199ef9eebSMatthias Springer lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
169299ef9eebSMatthias Springer rhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
169399ef9eebSMatthias Springer } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
169499ef9eebSMatthias Springer Value tmp = rhs;
169599ef9eebSMatthias Springer rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
169699ef9eebSMatthias Springer lhs = tmp;
169799ef9eebSMatthias Springer } else {
169899ef9eebSMatthias Springer return failure();
169999ef9eebSMatthias Springer }
170099ef9eebSMatthias Springer } else if (isParallelIterator(iteratorTypes[0]) &&
170199ef9eebSMatthias Springer isReductionIterator(iteratorTypes[1])) {
170299ef9eebSMatthias Springer //
170399ef9eebSMatthias Springer // One outer parallel, one inner reduction (matvec flavor)
170499ef9eebSMatthias Springer //
170599ef9eebSMatthias Springer if (maps == infer({{m, n}, {n}, {m}})) {
170699ef9eebSMatthias Springer // No need to permute anything.
170799ef9eebSMatthias Springer } else if (maps == infer({{n, m}, {n}, {m}})) {
170899ef9eebSMatthias Springer lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
170999ef9eebSMatthias Springer } else if (maps == infer({{n}, {m, n}, {m}})) {
171099ef9eebSMatthias Springer std::swap(lhs, rhs);
171199ef9eebSMatthias Springer } else if (maps == infer({{n}, {n, m}, {m}})) {
171299ef9eebSMatthias Springer std::swap(lhs, rhs);
171399ef9eebSMatthias Springer lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
171499ef9eebSMatthias Springer } else {
171599ef9eebSMatthias Springer return failure();
171699ef9eebSMatthias Springer }
171799ef9eebSMatthias Springer } else {
171899ef9eebSMatthias Springer return failure();
171999ef9eebSMatthias Springer }
172099ef9eebSMatthias Springer
172199ef9eebSMatthias Springer VectorType dstType = op.getResultType().cast<VectorType>();
172299ef9eebSMatthias Springer assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
172399ef9eebSMatthias Springer "Expected dst type of rank 1 or 2");
172499ef9eebSMatthias Springer
172599ef9eebSMatthias Springer unsigned rank = dstType.getRank();
172699ef9eebSMatthias Springer unsigned dstRows = dstType.getShape()[0];
172799ef9eebSMatthias Springer unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];
172899ef9eebSMatthias Springer
172999ef9eebSMatthias Springer // ExtractOp does not allow dynamic indexing, we must unroll explicitly.
173099ef9eebSMatthias Springer Value res = rewriter.create<arith::ConstantOp>(loc, dstType,
173199ef9eebSMatthias Springer rewriter.getZeroAttr(dstType));
173299ef9eebSMatthias Springer bool isInt = dstType.getElementType().isa<IntegerType>();
173399ef9eebSMatthias Springer for (unsigned r = 0; r < dstRows; ++r) {
173499ef9eebSMatthias Springer Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
173599ef9eebSMatthias Springer for (unsigned c = 0; c < dstColumns; ++c) {
173699ef9eebSMatthias Springer Value b = rank == 1
173799ef9eebSMatthias Springer ? rhs
173899ef9eebSMatthias Springer : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
173999ef9eebSMatthias Springer Value m = createMul(op.getLoc(), a, b, isInt, rewriter);
174099ef9eebSMatthias Springer Value reduced = rewriter.create<vector::ReductionOp>(
1741fe0bf7d4SMatthias Springer op.getLoc(), vector::CombiningKind::ADD, m);
174299ef9eebSMatthias Springer
174399ef9eebSMatthias Springer SmallVector<int64_t, 2> pos = rank == 1 ? SmallVector<int64_t, 2>{r}
174499ef9eebSMatthias Springer : SmallVector<int64_t, 2>{r, c};
174599ef9eebSMatthias Springer res = rewriter.create<vector::InsertOp>(op.getLoc(), reduced, res, pos);
174699ef9eebSMatthias Springer }
174799ef9eebSMatthias Springer }
17487c38fd60SJacques Pienaar if (auto acc = op.getAcc())
174999ef9eebSMatthias Springer res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
175099ef9eebSMatthias Springer rewriter.replaceOp(op, res);
175199ef9eebSMatthias Springer return success();
175299ef9eebSMatthias Springer }
175399ef9eebSMatthias Springer
175499ef9eebSMatthias Springer /// Progressive lowering of ContractionOp.
175599ef9eebSMatthias Springer /// One:
175699ef9eebSMatthias Springer /// %x = vector.contract with at least one free/batch dimension
175799ef9eebSMatthias Springer /// is replaced by:
175899ef9eebSMatthias Springer /// %a = vector.contract with one less free/batch dimension
175999ef9eebSMatthias Springer /// %b = vector.contract with one less free/batch dimension
176099ef9eebSMatthias Springer /// ..
176199ef9eebSMatthias Springer /// %x = combine %a %b ..
176299ef9eebSMatthias Springer /// until a pure contraction is reached (no free/batch dimensions),
176399ef9eebSMatthias Springer /// which is replaced by a dot-product.
176499ef9eebSMatthias Springer ///
176599ef9eebSMatthias Springer /// This only kicks in when either VectorTransformsOptions is set
176699ef9eebSMatthias Springer /// to DOT or when other contraction patterns fail.
176799ef9eebSMatthias Springer //
176899ef9eebSMatthias Springer // TODO: break down into transpose/reshape/cast ops
176999ef9eebSMatthias Springer // when they become available to avoid code dup
177099ef9eebSMatthias Springer // TODO: investigate lowering order impact on performance
177199ef9eebSMatthias Springer LogicalResult
matchAndRewrite(vector::ContractionOp op,PatternRewriter & rewriter) const177299ef9eebSMatthias Springer ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
177399ef9eebSMatthias Springer PatternRewriter &rewriter) const {
177499ef9eebSMatthias Springer // TODO: implement masks.
17757c38fd60SJacques Pienaar if (llvm::size(op.getMasks()) != 0)
177699ef9eebSMatthias Springer return failure();
177799ef9eebSMatthias Springer
177899ef9eebSMatthias Springer if (failed(filter(op)))
177999ef9eebSMatthias Springer return failure();
178099ef9eebSMatthias Springer
178199ef9eebSMatthias Springer // TODO: support mixed mode contract lowering.
178299ef9eebSMatthias Springer if (op.getLhsType().getElementType() !=
178399ef9eebSMatthias Springer getElementTypeOrSelf(op.getAccType()) ||
178499ef9eebSMatthias Springer op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType()))
178599ef9eebSMatthias Springer return failure();
178699ef9eebSMatthias Springer
178799ef9eebSMatthias Springer // TODO: implement benefits, cost models.
178899ef9eebSMatthias Springer MLIRContext *ctx = op.getContext();
178999ef9eebSMatthias Springer ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx);
179099ef9eebSMatthias Springer if (succeeded(pat1.matchAndRewrite(op, rewriter)))
179199ef9eebSMatthias Springer return success();
179299ef9eebSMatthias Springer ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx);
179399ef9eebSMatthias Springer if (succeeded(pat2.matchAndRewrite(op, rewriter)))
179499ef9eebSMatthias Springer return success();
179599ef9eebSMatthias Springer ContractionOpToDotLowering pat3(vectorTransformOptions, ctx);
179699ef9eebSMatthias Springer if (succeeded(pat3.matchAndRewrite(op, rewriter)))
179799ef9eebSMatthias Springer return success();
179889aaa2d0SThomas Raoux ContractOpToElementwise pat4(vectorTransformOptions, ctx);
179989aaa2d0SThomas Raoux if (succeeded(pat4.matchAndRewrite(op, rewriter)))
180089aaa2d0SThomas Raoux return success();
180199ef9eebSMatthias Springer
180299ef9eebSMatthias Springer // Find first batch dimension in LHS/RHS, and lower when found.
180399ef9eebSMatthias Springer std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
180499ef9eebSMatthias Springer if (!batchDimMap.empty()) {
180599ef9eebSMatthias Springer int64_t lhsIndex = batchDimMap[0].first;
180699ef9eebSMatthias Springer int64_t rhsIndex = batchDimMap[0].second;
18076870a50fSBenoit Jacob auto newOp = lowerParallel(op, lhsIndex, rhsIndex, rewriter);
18086870a50fSBenoit Jacob if (failed(newOp))
18096870a50fSBenoit Jacob return failure();
18106870a50fSBenoit Jacob rewriter.replaceOp(op, newOp.value());
181199ef9eebSMatthias Springer return success();
181299ef9eebSMatthias Springer }
181399ef9eebSMatthias Springer
181499ef9eebSMatthias Springer // Collect contracting dimensions.
181599ef9eebSMatthias Springer std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
181699ef9eebSMatthias Springer op.getContractingDimMap();
181799ef9eebSMatthias Springer DenseSet<int64_t> lhsContractingDimSet;
181899ef9eebSMatthias Springer DenseSet<int64_t> rhsContractingDimSet;
181999ef9eebSMatthias Springer for (auto &dimPair : contractingDimMap) {
182099ef9eebSMatthias Springer lhsContractingDimSet.insert(dimPair.first);
182199ef9eebSMatthias Springer rhsContractingDimSet.insert(dimPair.second);
182299ef9eebSMatthias Springer }
182399ef9eebSMatthias Springer
182499ef9eebSMatthias Springer // Find first free dimension in LHS, and lower when found.
182599ef9eebSMatthias Springer VectorType lhsType = op.getLhsType();
182699ef9eebSMatthias Springer for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
182799ef9eebSMatthias Springer if (lhsContractingDimSet.count(lhsIndex) == 0) {
18286870a50fSBenoit Jacob auto newOp = lowerParallel(op, lhsIndex, /*rhsIndex=*/-1, rewriter);
18296870a50fSBenoit Jacob if (failed(newOp))
18306870a50fSBenoit Jacob return failure();
18316870a50fSBenoit Jacob rewriter.replaceOp(op, newOp.value());
183299ef9eebSMatthias Springer return success();
183399ef9eebSMatthias Springer }
183499ef9eebSMatthias Springer }
183599ef9eebSMatthias Springer
183699ef9eebSMatthias Springer // Find first free dimension in RHS, and lower when found.
183799ef9eebSMatthias Springer VectorType rhsType = op.getRhsType();
183899ef9eebSMatthias Springer for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
183999ef9eebSMatthias Springer if (rhsContractingDimSet.count(rhsIndex) == 0) {
18406870a50fSBenoit Jacob auto newOp = lowerParallel(op, /*lhsIndex=*/-1, rhsIndex, rewriter);
18416870a50fSBenoit Jacob if (failed(newOp))
18426870a50fSBenoit Jacob return failure();
18436870a50fSBenoit Jacob rewriter.replaceOp(op, newOp.value());
184499ef9eebSMatthias Springer return success();
184599ef9eebSMatthias Springer }
184699ef9eebSMatthias Springer }
184799ef9eebSMatthias Springer
184899ef9eebSMatthias Springer // Lower the first remaining reduction dimension.
184999ef9eebSMatthias Springer if (!contractingDimMap.empty()) {
18506870a50fSBenoit Jacob auto newOp = lowerReduction(op, rewriter);
18516870a50fSBenoit Jacob if (failed(newOp))
18526870a50fSBenoit Jacob return failure();
18536870a50fSBenoit Jacob rewriter.replaceOp(op, newOp.value());
185499ef9eebSMatthias Springer return success();
185599ef9eebSMatthias Springer }
185699ef9eebSMatthias Springer
185799ef9eebSMatthias Springer return failure();
185899ef9eebSMatthias Springer }
185999ef9eebSMatthias Springer
186099ef9eebSMatthias Springer // Lower one parallel dimension.
18616870a50fSBenoit Jacob // Incidentally also tolerates unit-size (hence trivial) reduction dimensions.
186299ef9eebSMatthias Springer // TODO: consider reusing existing contract unrolling
18636870a50fSBenoit Jacob FailureOr<Value>
lowerParallel(vector::ContractionOp op,int64_t lhsIndex,int64_t rhsIndex,PatternRewriter & rewriter) const18646870a50fSBenoit Jacob ContractionOpLowering::lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
18656870a50fSBenoit Jacob int64_t rhsIndex,
186699ef9eebSMatthias Springer PatternRewriter &rewriter) const {
186799ef9eebSMatthias Springer VectorType lhsType = op.getLhsType();
186899ef9eebSMatthias Springer VectorType rhsType = op.getRhsType();
186999ef9eebSMatthias Springer VectorType resType = op.getResultType().cast<VectorType>();
187099ef9eebSMatthias Springer // Find the iterator type index and result index.
1871*d2c0572bSJacques Pienaar SmallVector<AffineMap, 4> iMap = op.getIndexingMapsArray();
187299ef9eebSMatthias Springer int64_t iterIndex = -1;
187399ef9eebSMatthias Springer int64_t dimSize = -1;
187499ef9eebSMatthias Springer if (lhsIndex >= 0) {
187599ef9eebSMatthias Springer iterIndex = iMap[0].getDimPosition(lhsIndex);
18766870a50fSBenoit Jacob if (rhsIndex >= 0 && iterIndex != iMap[1].getDimPosition(rhsIndex))
18776870a50fSBenoit Jacob return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
18786870a50fSBenoit Jacob diag << "expected lhsIndex=" << lhsIndex << " and rhsIndex=" << rhsIndex
18796870a50fSBenoit Jacob << " to map to the same dimension";
18806870a50fSBenoit Jacob });
188199ef9eebSMatthias Springer dimSize = lhsType.getDimSize(lhsIndex);
18826870a50fSBenoit Jacob } else if (rhsIndex >= 0) {
188399ef9eebSMatthias Springer iterIndex = iMap[1].getDimPosition(rhsIndex);
188499ef9eebSMatthias Springer dimSize = rhsType.getDimSize(rhsIndex);
188599ef9eebSMatthias Springer }
18866870a50fSBenoit Jacob if (iterIndex < 0)
18876870a50fSBenoit Jacob return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
18886870a50fSBenoit Jacob diag << "expected either lhsIndex=" << lhsIndex
18896870a50fSBenoit Jacob << " or rhsIndex=" << rhsIndex << " to be nonnegative";
18906870a50fSBenoit Jacob });
18913c849d0aSFangrui Song // value_or(-1) means that we tolerate a dimension not appearing
18926870a50fSBenoit Jacob // in the result map. That can't happen for actual parallel iterators, but
18936870a50fSBenoit Jacob // the caller ContractionOpLowering::matchAndRewrite is currently calling
18946870a50fSBenoit Jacob // lowerParallel also for the case of unit-size reduction dims appearing only
18956870a50fSBenoit Jacob // on one of LHS or RHS, not both. At the moment, such cases are created by
18966870a50fSBenoit Jacob // CastAwayContractionLeadingOneDim, so we need to either support that or
18976870a50fSBenoit Jacob // modify that pattern.
18983c849d0aSFangrui Song int64_t resIndex = getResultIndex(iMap[2], iterIndex).value_or(-1);
18996870a50fSBenoit Jacob if (resIndex == -1 && dimSize != 1)
19006870a50fSBenoit Jacob return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
19016870a50fSBenoit Jacob diag << "expected the dimension for iterIndex=" << iterIndex
19026870a50fSBenoit Jacob << " to either appear in the result map, or to be a unit dimension";
19036870a50fSBenoit Jacob });
190499ef9eebSMatthias Springer // Construct new iterator types and affine map array attribute.
190599ef9eebSMatthias Springer std::array<AffineMap, 3> lowIndexingMaps = {
190699ef9eebSMatthias Springer adjustMap(iMap[0], iterIndex, rewriter),
190799ef9eebSMatthias Springer adjustMap(iMap[1], iterIndex, rewriter),
190899ef9eebSMatthias Springer adjustMap(iMap[2], iterIndex, rewriter)};
190999ef9eebSMatthias Springer auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
191099ef9eebSMatthias Springer auto lowIter =
19117c38fd60SJacques Pienaar rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
191299ef9eebSMatthias Springer // Unroll into a series of lower dimensional vector.contract ops.
191399ef9eebSMatthias Springer Location loc = op.getLoc();
191499ef9eebSMatthias Springer Value result = rewriter.create<arith::ConstantOp>(
191599ef9eebSMatthias Springer loc, resType, rewriter.getZeroAttr(resType));
191699ef9eebSMatthias Springer for (int64_t d = 0; d < dimSize; ++d) {
19177c38fd60SJacques Pienaar auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
19187c38fd60SJacques Pienaar auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
19197c38fd60SJacques Pienaar auto acc = reshapeLoad(loc, op.getAcc(), resType, resIndex, d, rewriter);
192099ef9eebSMatthias Springer Value lowContract = rewriter.create<vector::ContractionOp>(
192199ef9eebSMatthias Springer loc, lhs, rhs, acc, lowAffine, lowIter);
192299ef9eebSMatthias Springer result =
192399ef9eebSMatthias Springer reshapeStore(loc, lowContract, result, resType, resIndex, d, rewriter);
192499ef9eebSMatthias Springer }
192599ef9eebSMatthias Springer return result;
192699ef9eebSMatthias Springer }
192799ef9eebSMatthias Springer
192899ef9eebSMatthias Springer // Lower one reduction dimension.
19296870a50fSBenoit Jacob FailureOr<Value>
lowerReduction(vector::ContractionOp op,PatternRewriter & rewriter) const19306870a50fSBenoit Jacob ContractionOpLowering::lowerReduction(vector::ContractionOp op,
193199ef9eebSMatthias Springer PatternRewriter &rewriter) const {
193299ef9eebSMatthias Springer auto loc = op.getLoc();
193399ef9eebSMatthias Springer VectorType lhsType = op.getLhsType();
193499ef9eebSMatthias Springer VectorType rhsType = op.getRhsType();
193599ef9eebSMatthias Springer Type resType = op.getResultType();
19366870a50fSBenoit Jacob if (resType.isa<VectorType>())
19376870a50fSBenoit Jacob return rewriter.notifyMatchFailure(op,
19386870a50fSBenoit Jacob "did not expect a VectorType result");
193999ef9eebSMatthias Springer bool isInt = resType.isa<IntegerType>();
194099ef9eebSMatthias Springer // Use iterator index 0.
194199ef9eebSMatthias Springer int64_t iterIndex = 0;
1942*d2c0572bSJacques Pienaar SmallVector<AffineMap, 4> iMap = op.getIndexingMapsArray();
194399ef9eebSMatthias Springer Optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex);
194499ef9eebSMatthias Springer Optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex);
19453c849d0aSFangrui Song if (!lookupLhs.has_value())
19466870a50fSBenoit Jacob return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
19476870a50fSBenoit Jacob diag << "expected iterIndex=" << iterIndex << "to map to a LHS dimension";
19486870a50fSBenoit Jacob });
19493c849d0aSFangrui Song if (!lookupRhs.has_value())
19506870a50fSBenoit Jacob return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
19516870a50fSBenoit Jacob diag << "expected iterIndex=" << iterIndex << "to map to a RHS dimension";
19526870a50fSBenoit Jacob });
1953c27d8152SKazu Hirata int64_t lhsIndex = lookupLhs.value();
1954c27d8152SKazu Hirata int64_t rhsIndex = lookupRhs.value();
195599ef9eebSMatthias Springer int64_t dimSize = lhsType.getDimSize(lhsIndex);
19566870a50fSBenoit Jacob if (dimSize != rhsType.getDimSize(rhsIndex))
19576870a50fSBenoit Jacob return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
19586870a50fSBenoit Jacob diag << "expect LHS dimension " << lhsIndex
19596870a50fSBenoit Jacob << " to have the same size as RHS dimension " << rhsIndex;
19606870a50fSBenoit Jacob });
196199ef9eebSMatthias Springer // Base case.
196299ef9eebSMatthias Springer if (lhsType.getRank() == 1) {
19636870a50fSBenoit Jacob if (rhsType.getRank() != 1)
19646870a50fSBenoit Jacob return rewriter.notifyMatchFailure(
19656870a50fSBenoit Jacob op, "When LHS has rank 1, expected also RHS to have rank 1");
19667c38fd60SJacques Pienaar Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter);
1967fe0bf7d4SMatthias Springer auto kind = vector::CombiningKind::ADD;
19687c38fd60SJacques Pienaar if (auto acc = op.getAcc())
19696870a50fSBenoit Jacob return rewriter.create<vector::ReductionOp>(loc, kind, m, acc)
19706870a50fSBenoit Jacob .getResult();
19716870a50fSBenoit Jacob return rewriter.create<vector::ReductionOp>(loc, kind, m).getResult();
197299ef9eebSMatthias Springer }
197399ef9eebSMatthias Springer // Construct new iterator types and affine map array attribute.
197499ef9eebSMatthias Springer std::array<AffineMap, 3> lowIndexingMaps = {
197599ef9eebSMatthias Springer adjustMap(iMap[0], iterIndex, rewriter),
197699ef9eebSMatthias Springer adjustMap(iMap[1], iterIndex, rewriter),
197799ef9eebSMatthias Springer adjustMap(iMap[2], iterIndex, rewriter)};
197899ef9eebSMatthias Springer auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
197999ef9eebSMatthias Springer auto lowIter =
19807c38fd60SJacques Pienaar rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
198199ef9eebSMatthias Springer // Unroll into a series of lower dimensional vector.contract ops.
198299ef9eebSMatthias Springer // By feeding the initial accumulator into the first contraction,
198399ef9eebSMatthias Springer // and the result of each contraction into the next, eventually
198499ef9eebSMatthias Springer // the sum of all reductions is computed.
19857c38fd60SJacques Pienaar Value result = op.getAcc();
198699ef9eebSMatthias Springer for (int64_t d = 0; d < dimSize; ++d) {
19877c38fd60SJacques Pienaar auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
19887c38fd60SJacques Pienaar auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
198999ef9eebSMatthias Springer result = rewriter.create<vector::ContractionOp>(loc, lhs, rhs, result,
199099ef9eebSMatthias Springer lowAffine, lowIter);
199199ef9eebSMatthias Springer }
199299ef9eebSMatthias Springer return result;
199399ef9eebSMatthias Springer }
199499ef9eebSMatthias Springer
199599ef9eebSMatthias Springer } // namespace mlir
199699ef9eebSMatthias Springer
distributPointwiseVectorOp(OpBuilder & builder,Operation * op,ArrayRef<Value> ids,ArrayRef<int64_t> multiplicity,const AffineMap & map)199799ef9eebSMatthias Springer Optional<mlir::vector::DistributeOps> mlir::vector::distributPointwiseVectorOp(
199899ef9eebSMatthias Springer OpBuilder &builder, Operation *op, ArrayRef<Value> ids,
199999ef9eebSMatthias Springer ArrayRef<int64_t> multiplicity, const AffineMap &map) {
200099ef9eebSMatthias Springer OpBuilder::InsertionGuard guard(builder);
200199ef9eebSMatthias Springer builder.setInsertionPointAfter(op);
200299ef9eebSMatthias Springer Location loc = op->getLoc();
200399ef9eebSMatthias Springer if (op->getNumResults() != 1)
200499ef9eebSMatthias Springer return {};
200599ef9eebSMatthias Springer Value result = op->getResult(0);
200699ef9eebSMatthias Springer VectorType type = op->getResult(0).getType().dyn_cast<VectorType>();
200799ef9eebSMatthias Springer if (!type || map.getNumResults() != multiplicity.size())
200899ef9eebSMatthias Springer return {};
200999ef9eebSMatthias Springer // For each dimension being distributed check that the size is a multiple of
201099ef9eebSMatthias Springer // the multiplicity. To handle more sizes we would need to support masking.
201199ef9eebSMatthias Springer unsigned multiplictyCount = 0;
201299ef9eebSMatthias Springer for (auto exp : map.getResults()) {
201399ef9eebSMatthias Springer auto affinExp = exp.dyn_cast<AffineDimExpr>();
201499ef9eebSMatthias Springer if (!affinExp || affinExp.getPosition() >= type.getRank() ||
201599ef9eebSMatthias Springer type.getDimSize(affinExp.getPosition()) %
201699ef9eebSMatthias Springer multiplicity[multiplictyCount++] !=
201799ef9eebSMatthias Springer 0)
201899ef9eebSMatthias Springer return {};
201999ef9eebSMatthias Springer }
202099ef9eebSMatthias Springer DistributeOps ops;
202199ef9eebSMatthias Springer ops.extract =
202299ef9eebSMatthias Springer builder.create<vector::ExtractMapOp>(loc, result, ids, multiplicity, map);
202399ef9eebSMatthias Springer ops.insert =
202499ef9eebSMatthias Springer builder.create<vector::InsertMapOp>(loc, ops.extract, result, ids);
202599ef9eebSMatthias Springer return ops;
202699ef9eebSMatthias Springer }
202799ef9eebSMatthias Springer
202899ef9eebSMatthias Springer /// Progressive lowering of transfer_read. This pattern supports lowering of
202999ef9eebSMatthias Springer /// `vector.transfer_read` to a combination of `vector.load` and
203099ef9eebSMatthias Springer /// `vector.broadcast` if all of the following hold:
203199ef9eebSMatthias Springer /// - Stride of most minor memref dimension must be 1.
203299ef9eebSMatthias Springer /// - Out-of-bounds masking is not required.
203399ef9eebSMatthias Springer /// - If the memref's element type is a vector type then it coincides with the
203499ef9eebSMatthias Springer /// result type.
203599ef9eebSMatthias Springer /// - The permutation map doesn't perform permutation (broadcasting is allowed).
203699ef9eebSMatthias Springer struct TransferReadToVectorLoadLowering
203799ef9eebSMatthias Springer : public OpRewritePattern<vector::TransferReadOp> {
TransferReadToVectorLoadLoweringTransferReadToVectorLoadLowering203899ef9eebSMatthias Springer TransferReadToVectorLoadLowering(MLIRContext *context,
203999ef9eebSMatthias Springer llvm::Optional<unsigned> maxRank)
204099ef9eebSMatthias Springer : OpRewritePattern<vector::TransferReadOp>(context),
204199ef9eebSMatthias Springer maxTransferRank(maxRank) {}
204299ef9eebSMatthias Springer
matchAndRewriteTransferReadToVectorLoadLowering204399ef9eebSMatthias Springer LogicalResult matchAndRewrite(vector::TransferReadOp read,
204499ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
204599ef9eebSMatthias Springer if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank)
204699ef9eebSMatthias Springer return failure();
204799ef9eebSMatthias Springer
204899ef9eebSMatthias Springer SmallVector<unsigned, 4> broadcastedDims;
204999ef9eebSMatthias Springer // Permutations are handled by VectorToSCF or
205099ef9eebSMatthias Springer // populateVectorTransferPermutationMapLoweringPatterns.
205199ef9eebSMatthias Springer // We let the 0-d corner case pass-through as it is supported.
20527c38fd60SJacques Pienaar if (!read.getPermutationMap().isMinorIdentityWithBroadcasting(
205399ef9eebSMatthias Springer &broadcastedDims))
205499ef9eebSMatthias Springer return failure();
205599ef9eebSMatthias Springer
205699ef9eebSMatthias Springer auto memRefType = read.getShapedType().dyn_cast<MemRefType>();
205799ef9eebSMatthias Springer if (!memRefType)
205899ef9eebSMatthias Springer return failure();
205999ef9eebSMatthias Springer
206099ef9eebSMatthias Springer // Non-unit strides are handled by VectorToSCF.
206199ef9eebSMatthias Springer if (!vector::isLastMemrefDimUnitStride(memRefType))
206299ef9eebSMatthias Springer return failure();
206399ef9eebSMatthias Springer
206499ef9eebSMatthias Springer // If there is broadcasting involved then we first load the unbroadcasted
206599ef9eebSMatthias Springer // vector, and then broadcast it with `vector.broadcast`.
206699ef9eebSMatthias Springer ArrayRef<int64_t> vectorShape = read.getVectorType().getShape();
206799ef9eebSMatthias Springer SmallVector<int64_t, 4> unbroadcastedVectorShape(vectorShape.begin(),
206899ef9eebSMatthias Springer vectorShape.end());
206999ef9eebSMatthias Springer for (unsigned i : broadcastedDims)
207099ef9eebSMatthias Springer unbroadcastedVectorShape[i] = 1;
207199ef9eebSMatthias Springer VectorType unbroadcastedVectorType = VectorType::get(
207299ef9eebSMatthias Springer unbroadcastedVectorShape, read.getVectorType().getElementType());
207399ef9eebSMatthias Springer
207499ef9eebSMatthias Springer // `vector.load` supports vector types as memref's elements only when the
207599ef9eebSMatthias Springer // resulting vector type is the same as the element type.
207699ef9eebSMatthias Springer auto memrefElTy = memRefType.getElementType();
207799ef9eebSMatthias Springer if (memrefElTy.isa<VectorType>() && memrefElTy != unbroadcastedVectorType)
207899ef9eebSMatthias Springer return failure();
207999ef9eebSMatthias Springer
208099ef9eebSMatthias Springer // Otherwise, element types of the memref and the vector must match.
208199ef9eebSMatthias Springer if (!memrefElTy.isa<VectorType>() &&
208299ef9eebSMatthias Springer memrefElTy != read.getVectorType().getElementType())
208399ef9eebSMatthias Springer return failure();
208499ef9eebSMatthias Springer
208599ef9eebSMatthias Springer // Out-of-bounds dims are handled by MaterializeTransferMask.
208699ef9eebSMatthias Springer if (read.hasOutOfBoundsDim())
208799ef9eebSMatthias Springer return failure();
208899ef9eebSMatthias Springer
208999ef9eebSMatthias Springer // Create vector load op.
209099ef9eebSMatthias Springer Operation *loadOp;
20917c38fd60SJacques Pienaar if (read.getMask()) {
20926a8ba318SRiver Riddle Value fill = rewriter.create<vector::SplatOp>(
20937c38fd60SJacques Pienaar read.getLoc(), unbroadcastedVectorType, read.getPadding());
209499ef9eebSMatthias Springer loadOp = rewriter.create<vector::MaskedLoadOp>(
20957c38fd60SJacques Pienaar read.getLoc(), unbroadcastedVectorType, read.getSource(),
20967c38fd60SJacques Pienaar read.getIndices(), read.getMask(), fill);
209799ef9eebSMatthias Springer } else {
20987c38fd60SJacques Pienaar loadOp = rewriter.create<vector::LoadOp>(
20997c38fd60SJacques Pienaar read.getLoc(), unbroadcastedVectorType, read.getSource(),
21007c38fd60SJacques Pienaar read.getIndices());
210199ef9eebSMatthias Springer }
210299ef9eebSMatthias Springer
210399ef9eebSMatthias Springer // Insert a broadcasting op if required.
210499ef9eebSMatthias Springer if (!broadcastedDims.empty()) {
210599ef9eebSMatthias Springer rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
210699ef9eebSMatthias Springer read, read.getVectorType(), loadOp->getResult(0));
210799ef9eebSMatthias Springer } else {
210899ef9eebSMatthias Springer rewriter.replaceOp(read, loadOp->getResult(0));
210999ef9eebSMatthias Springer }
211099ef9eebSMatthias Springer
211199ef9eebSMatthias Springer return success();
211299ef9eebSMatthias Springer }
211399ef9eebSMatthias Springer
211499ef9eebSMatthias Springer llvm::Optional<unsigned> maxTransferRank;
211599ef9eebSMatthias Springer };
211699ef9eebSMatthias Springer
211799ef9eebSMatthias Springer /// Replace a 0-d vector.load with a memref.load + vector.broadcast.
211899ef9eebSMatthias Springer // TODO: we shouldn't cross the vector/scalar domains just for this
211999ef9eebSMatthias Springer // but atm we lack the infra to avoid it. Possible solutions include:
212099ef9eebSMatthias Springer // - go directly to LLVM + bitcast
212199ef9eebSMatthias Springer // - introduce a bitcast op and likely a new pointer dialect
212299ef9eebSMatthias Springer // - let memref.load/store additionally support the 0-d vector case
212399ef9eebSMatthias Springer // There are still deeper data layout issues lingering even in this
212499ef9eebSMatthias Springer // trivial case (for architectures for which this matters).
212599ef9eebSMatthias Springer struct VectorLoadToMemrefLoadLowering
212699ef9eebSMatthias Springer : public OpRewritePattern<vector::LoadOp> {
212799ef9eebSMatthias Springer using OpRewritePattern<vector::LoadOp>::OpRewritePattern;
212899ef9eebSMatthias Springer
matchAndRewriteVectorLoadToMemrefLoadLowering212999ef9eebSMatthias Springer LogicalResult matchAndRewrite(vector::LoadOp loadOp,
213099ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
213199ef9eebSMatthias Springer auto vecType = loadOp.getVectorType();
213299ef9eebSMatthias Springer if (vecType.getNumElements() != 1)
213399ef9eebSMatthias Springer return failure();
213499ef9eebSMatthias Springer auto memrefLoad = rewriter.create<memref::LoadOp>(
21357c38fd60SJacques Pienaar loadOp.getLoc(), loadOp.getBase(), loadOp.getIndices());
213699ef9eebSMatthias Springer rewriter.replaceOpWithNewOp<vector::BroadcastOp>(loadOp, vecType,
213799ef9eebSMatthias Springer memrefLoad);
213899ef9eebSMatthias Springer return success();
213999ef9eebSMatthias Springer }
214099ef9eebSMatthias Springer };
214199ef9eebSMatthias Springer
214299ef9eebSMatthias Springer /// Replace a 0-d vector.store with a vector.extractelement + memref.store.
214399ef9eebSMatthias Springer struct VectorStoreToMemrefStoreLowering
214499ef9eebSMatthias Springer : public OpRewritePattern<vector::StoreOp> {
214599ef9eebSMatthias Springer using OpRewritePattern<vector::StoreOp>::OpRewritePattern;
214699ef9eebSMatthias Springer
matchAndRewriteVectorStoreToMemrefStoreLowering214799ef9eebSMatthias Springer LogicalResult matchAndRewrite(vector::StoreOp storeOp,
214899ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
214999ef9eebSMatthias Springer auto vecType = storeOp.getVectorType();
215099ef9eebSMatthias Springer if (vecType.getNumElements() != 1)
215199ef9eebSMatthias Springer return failure();
215299ef9eebSMatthias Springer Value extracted;
215399ef9eebSMatthias Springer if (vecType.getRank() == 0) {
215499ef9eebSMatthias Springer // TODO: Unifiy once ExtractOp supports 0-d vectors.
215599ef9eebSMatthias Springer extracted = rewriter.create<vector::ExtractElementOp>(
21567c38fd60SJacques Pienaar storeOp.getLoc(), storeOp.getValueToStore());
215799ef9eebSMatthias Springer } else {
215899ef9eebSMatthias Springer SmallVector<int64_t> indices(vecType.getRank(), 0);
215999ef9eebSMatthias Springer extracted = rewriter.create<vector::ExtractOp>(
21607c38fd60SJacques Pienaar storeOp.getLoc(), storeOp.getValueToStore(), indices);
216199ef9eebSMatthias Springer }
216299ef9eebSMatthias Springer
216399ef9eebSMatthias Springer rewriter.replaceOpWithNewOp<memref::StoreOp>(
21647c38fd60SJacques Pienaar storeOp, extracted, storeOp.getBase(), storeOp.getIndices());
216599ef9eebSMatthias Springer return success();
216699ef9eebSMatthias Springer }
216799ef9eebSMatthias Springer };
216899ef9eebSMatthias Springer
216999ef9eebSMatthias Springer /// Progressive lowering of transfer_write. This pattern supports lowering of
217099ef9eebSMatthias Springer /// `vector.transfer_write` to `vector.store` if all of the following hold:
217199ef9eebSMatthias Springer /// - Stride of most minor memref dimension must be 1.
217299ef9eebSMatthias Springer /// - Out-of-bounds masking is not required.
217399ef9eebSMatthias Springer /// - If the memref's element type is a vector type then it coincides with the
217499ef9eebSMatthias Springer /// type of the written value.
217599ef9eebSMatthias Springer /// - The permutation map is the minor identity map (neither permutation nor
217699ef9eebSMatthias Springer /// broadcasting is allowed).
217799ef9eebSMatthias Springer struct TransferWriteToVectorStoreLowering
217899ef9eebSMatthias Springer : public OpRewritePattern<vector::TransferWriteOp> {
TransferWriteToVectorStoreLoweringTransferWriteToVectorStoreLowering217999ef9eebSMatthias Springer TransferWriteToVectorStoreLowering(MLIRContext *context,
218099ef9eebSMatthias Springer llvm::Optional<unsigned> maxRank)
218199ef9eebSMatthias Springer : OpRewritePattern<vector::TransferWriteOp>(context),
218299ef9eebSMatthias Springer maxTransferRank(maxRank) {}
218399ef9eebSMatthias Springer
matchAndRewriteTransferWriteToVectorStoreLowering218499ef9eebSMatthias Springer LogicalResult matchAndRewrite(vector::TransferWriteOp write,
218599ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
218699ef9eebSMatthias Springer if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank)
218799ef9eebSMatthias Springer return failure();
218899ef9eebSMatthias Springer
218999ef9eebSMatthias Springer // Permutations are handled by VectorToSCF or
219099ef9eebSMatthias Springer // populateVectorTransferPermutationMapLoweringPatterns.
219199ef9eebSMatthias Springer if ( // pass-through for the 0-d corner case.
21927c38fd60SJacques Pienaar !write.getPermutationMap().isMinorIdentity())
219399ef9eebSMatthias Springer return failure();
219499ef9eebSMatthias Springer
219599ef9eebSMatthias Springer auto memRefType = write.getShapedType().dyn_cast<MemRefType>();
219699ef9eebSMatthias Springer if (!memRefType)
219799ef9eebSMatthias Springer return failure();
219899ef9eebSMatthias Springer
219999ef9eebSMatthias Springer // Non-unit strides are handled by VectorToSCF.
220099ef9eebSMatthias Springer if (!vector::isLastMemrefDimUnitStride(memRefType))
220199ef9eebSMatthias Springer return failure();
220299ef9eebSMatthias Springer
220399ef9eebSMatthias Springer // `vector.store` supports vector types as memref's elements only when the
220499ef9eebSMatthias Springer // type of the vector value being written is the same as the element type.
220599ef9eebSMatthias Springer auto memrefElTy = memRefType.getElementType();
220699ef9eebSMatthias Springer if (memrefElTy.isa<VectorType>() && memrefElTy != write.getVectorType())
220799ef9eebSMatthias Springer return failure();
220899ef9eebSMatthias Springer
220999ef9eebSMatthias Springer // Otherwise, element types of the memref and the vector must match.
221099ef9eebSMatthias Springer if (!memrefElTy.isa<VectorType>() &&
221199ef9eebSMatthias Springer memrefElTy != write.getVectorType().getElementType())
221299ef9eebSMatthias Springer return failure();
221399ef9eebSMatthias Springer
221499ef9eebSMatthias Springer // Out-of-bounds dims are handled by MaterializeTransferMask.
221599ef9eebSMatthias Springer if (write.hasOutOfBoundsDim())
221699ef9eebSMatthias Springer return failure();
22177c38fd60SJacques Pienaar if (write.getMask()) {
221899ef9eebSMatthias Springer rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
22197c38fd60SJacques Pienaar write, write.getSource(), write.getIndices(), write.getMask(),
22207c38fd60SJacques Pienaar write.getVector());
222199ef9eebSMatthias Springer } else {
222299ef9eebSMatthias Springer rewriter.replaceOpWithNewOp<vector::StoreOp>(
22237c38fd60SJacques Pienaar write, write.getVector(), write.getSource(), write.getIndices());
222499ef9eebSMatthias Springer }
222599ef9eebSMatthias Springer return success();
222699ef9eebSMatthias Springer }
222799ef9eebSMatthias Springer
222899ef9eebSMatthias Springer llvm::Optional<unsigned> maxTransferRank;
222999ef9eebSMatthias Springer };
223099ef9eebSMatthias Springer
223199ef9eebSMatthias Springer // Returns the values in `arrayAttr` as an integer vector.
getIntValueVector(ArrayAttr arrayAttr)223299ef9eebSMatthias Springer static SmallVector<int64_t, 4> getIntValueVector(ArrayAttr arrayAttr) {
223399ef9eebSMatthias Springer return llvm::to_vector<4>(
223499ef9eebSMatthias Springer llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(),
223599ef9eebSMatthias Springer [](IntegerAttr attr) { return attr.getInt(); }));
223699ef9eebSMatthias Springer }
223799ef9eebSMatthias Springer
223899ef9eebSMatthias Springer // Shuffles vector.bitcast op after vector.extract op.
223999ef9eebSMatthias Springer //
224099ef9eebSMatthias Springer // This transforms IR like:
224199ef9eebSMatthias Springer // %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16>
224299ef9eebSMatthias Springer // %1 = vector.extract %0[3] : vector<8xf16>
224399ef9eebSMatthias Springer // Into:
224499ef9eebSMatthias Springer // %0 = vector.extract %src[1] : vector<4xf32>
224599ef9eebSMatthias Springer // %1 = vector.bitcast %0: vector<1xf32> to vector<2xf16>
224699ef9eebSMatthias Springer // %2 = vector.extract %1[1] : vector<2xf16>
224799ef9eebSMatthias Springer struct BubbleDownVectorBitCastForExtract
224899ef9eebSMatthias Springer : public OpRewritePattern<vector::ExtractOp> {
224999ef9eebSMatthias Springer using OpRewritePattern::OpRewritePattern;
225099ef9eebSMatthias Springer
matchAndRewriteBubbleDownVectorBitCastForExtract225199ef9eebSMatthias Springer LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
225299ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
225399ef9eebSMatthias Springer // Only support extracting scalars for now.
225499ef9eebSMatthias Springer if (extractOp.getVectorType().getRank() != 1)
225599ef9eebSMatthias Springer return failure();
225699ef9eebSMatthias Springer
22577c38fd60SJacques Pienaar auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
225899ef9eebSMatthias Springer if (!castOp)
225999ef9eebSMatthias Springer return failure();
226099ef9eebSMatthias Springer
226199ef9eebSMatthias Springer VectorType castSrcType = castOp.getSourceVectorType();
226299ef9eebSMatthias Springer VectorType castDstType = castOp.getResultVectorType();
226399ef9eebSMatthias Springer assert(castSrcType.getRank() == castDstType.getRank());
226499ef9eebSMatthias Springer
226599ef9eebSMatthias Springer // Fail to match if we only have one element in the cast op source.
226699ef9eebSMatthias Springer // This is to avoid infinite loop given that this pattern can generate
226799ef9eebSMatthias Springer // such cases.
226899ef9eebSMatthias Springer if (castSrcType.getNumElements() == 1)
226999ef9eebSMatthias Springer return failure();
227099ef9eebSMatthias Springer
227199ef9eebSMatthias Springer // Only support casting to a larger number of elements or now.
227299ef9eebSMatthias Springer // E.g., vector<4xf32> -> vector<8xf16>.
227399ef9eebSMatthias Springer if (castSrcType.getNumElements() > castDstType.getNumElements())
227499ef9eebSMatthias Springer return failure();
227599ef9eebSMatthias Springer
227699ef9eebSMatthias Springer unsigned expandRatio =
227799ef9eebSMatthias Springer castDstType.getNumElements() / castSrcType.getNumElements();
227899ef9eebSMatthias Springer
227999ef9eebSMatthias Springer auto getFirstIntValue = [](ArrayAttr attr) -> uint64_t {
228099ef9eebSMatthias Springer return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
228199ef9eebSMatthias Springer };
228299ef9eebSMatthias Springer
22837c38fd60SJacques Pienaar uint64_t index = getFirstIntValue(extractOp.getPosition());
228499ef9eebSMatthias Springer
228599ef9eebSMatthias Springer // Get the single scalar (as a vector) in the source value that packs the
228699ef9eebSMatthias Springer // desired scalar. E.g. extract vector<1xf32> from vector<4xf32>
228799ef9eebSMatthias Springer VectorType oneScalarType =
228899ef9eebSMatthias Springer VectorType::get({1}, castSrcType.getElementType());
228999ef9eebSMatthias Springer Value packedValue = rewriter.create<vector::ExtractOp>(
22907c38fd60SJacques Pienaar extractOp.getLoc(), oneScalarType, castOp.getSource(),
229199ef9eebSMatthias Springer rewriter.getI64ArrayAttr(index / expandRatio));
229299ef9eebSMatthias Springer
229399ef9eebSMatthias Springer // Cast it to a vector with the desired scalar's type.
229499ef9eebSMatthias Springer // E.g. f32 -> vector<2xf16>
229599ef9eebSMatthias Springer VectorType packedType =
229699ef9eebSMatthias Springer VectorType::get({expandRatio}, castDstType.getElementType());
229799ef9eebSMatthias Springer Value castedValue = rewriter.create<vector::BitCastOp>(
229899ef9eebSMatthias Springer extractOp.getLoc(), packedType, packedValue);
229999ef9eebSMatthias Springer
230099ef9eebSMatthias Springer // Finally extract the desired scalar.
230199ef9eebSMatthias Springer rewriter.replaceOpWithNewOp<vector::ExtractOp>(
230299ef9eebSMatthias Springer extractOp, extractOp.getType(), castedValue,
230399ef9eebSMatthias Springer rewriter.getI64ArrayAttr(index % expandRatio));
230499ef9eebSMatthias Springer
230599ef9eebSMatthias Springer return success();
230699ef9eebSMatthias Springer }
230799ef9eebSMatthias Springer };
230899ef9eebSMatthias Springer
230999ef9eebSMatthias Springer // Shuffles vector.bitcast op after vector.extract_strided_slice op.
231099ef9eebSMatthias Springer //
231199ef9eebSMatthias Springer // This transforms IR like:
231299ef9eebSMatthias Springer // %cast = vector.bitcast %arg0: vector<4xf32> to vector<8xf16>
231399ef9eebSMatthias Springer // %0 = vector.extract_strided_slice %cast {
231499ef9eebSMatthias Springer // offsets = [4], sizes = [4], strides = [1]
231599ef9eebSMatthias Springer // } : vector<8xf16> to vector<4xf16>
231699ef9eebSMatthias Springer // Into:
231799ef9eebSMatthias Springer // %0 = vector.extract_strided_slice %src {
231899ef9eebSMatthias Springer // offsets = [2], sizes = [2], strides = [1]
231999ef9eebSMatthias Springer // } : vector<4xf32> to vector<2xf32>
232099ef9eebSMatthias Springer // %1 = vector.bitcast %0 : vector<2xf32> to vector<4xf16>
232199ef9eebSMatthias Springer struct BubbleDownBitCastForStridedSliceExtract
232299ef9eebSMatthias Springer : public OpRewritePattern<vector::ExtractStridedSliceOp> {
232399ef9eebSMatthias Springer using OpRewritePattern::OpRewritePattern;
232499ef9eebSMatthias Springer
matchAndRewriteBubbleDownBitCastForStridedSliceExtract232599ef9eebSMatthias Springer LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
232699ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
23277c38fd60SJacques Pienaar auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
232899ef9eebSMatthias Springer if (!castOp)
232999ef9eebSMatthias Springer return failure();
233099ef9eebSMatthias Springer
233199ef9eebSMatthias Springer VectorType castSrcType = castOp.getSourceVectorType();
233299ef9eebSMatthias Springer VectorType castDstType = castOp.getResultVectorType();
233399ef9eebSMatthias Springer assert(castSrcType.getRank() == castDstType.getRank());
233499ef9eebSMatthias Springer
233599ef9eebSMatthias Springer int64_t castSrcLastDim = castSrcType.getShape().back();
233699ef9eebSMatthias Springer int64_t castDstLastDim = castDstType.getShape().back();
233799ef9eebSMatthias Springer // Require casting to more elements for now; other cases to be implemented.
233899ef9eebSMatthias Springer if (castSrcLastDim > castDstLastDim)
233999ef9eebSMatthias Springer return failure();
234099ef9eebSMatthias Springer
234199ef9eebSMatthias Springer // Only accept all one strides for now.
23427c38fd60SJacques Pienaar if (llvm::any_of(extractOp.getStrides().getAsValueRange<IntegerAttr>(),
234399ef9eebSMatthias Springer [](const APInt &val) { return !val.isOneValue(); }))
234499ef9eebSMatthias Springer return failure();
234599ef9eebSMatthias Springer
234699ef9eebSMatthias Springer unsigned rank = extractOp.getVectorType().getRank();
234799ef9eebSMatthias Springer assert(castDstLastDim % castSrcLastDim == 0);
234899ef9eebSMatthias Springer int64_t expandRatio = castDstLastDim / castSrcLastDim;
234999ef9eebSMatthias Springer
235099ef9eebSMatthias Springer // If we have a less number of offsets than the rank, then implicitly we
235199ef9eebSMatthias Springer // are selecting the full range for the last bitcasted dimension; other
235299ef9eebSMatthias Springer // dimensions aren't affected. Otherwise, we need to scale down the last
235399ef9eebSMatthias Springer // dimension's offset given we are extracting from less elements now.
23547c38fd60SJacques Pienaar ArrayAttr newOffsets = extractOp.getOffsets();
235599ef9eebSMatthias Springer if (newOffsets.size() == rank) {
235699ef9eebSMatthias Springer SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets);
235799ef9eebSMatthias Springer if (offsets.back() % expandRatio != 0)
235899ef9eebSMatthias Springer return failure();
235999ef9eebSMatthias Springer offsets.back() = offsets.back() / expandRatio;
236099ef9eebSMatthias Springer newOffsets = rewriter.getI64ArrayAttr(offsets);
236199ef9eebSMatthias Springer }
236299ef9eebSMatthias Springer
236399ef9eebSMatthias Springer // Similarly for sizes.
23647c38fd60SJacques Pienaar ArrayAttr newSizes = extractOp.getSizes();
236599ef9eebSMatthias Springer if (newSizes.size() == rank) {
236699ef9eebSMatthias Springer SmallVector<int64_t, 4> sizes = getIntValueVector(newSizes);
236799ef9eebSMatthias Springer if (sizes.back() % expandRatio != 0)
236899ef9eebSMatthias Springer return failure();
236999ef9eebSMatthias Springer sizes.back() = sizes.back() / expandRatio;
237099ef9eebSMatthias Springer newSizes = rewriter.getI64ArrayAttr(sizes);
237199ef9eebSMatthias Springer }
237299ef9eebSMatthias Springer
237399ef9eebSMatthias Springer SmallVector<int64_t, 4> dims =
237499ef9eebSMatthias Springer llvm::to_vector<4>(extractOp.getType().cast<VectorType>().getShape());
237599ef9eebSMatthias Springer dims.back() = dims.back() / expandRatio;
237699ef9eebSMatthias Springer VectorType newExtractType =
237799ef9eebSMatthias Springer VectorType::get(dims, castSrcType.getElementType());
237899ef9eebSMatthias Springer
237999ef9eebSMatthias Springer auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
23807c38fd60SJacques Pienaar extractOp.getLoc(), newExtractType, castOp.getSource(), newOffsets,
23817c38fd60SJacques Pienaar newSizes, extractOp.getStrides());
238299ef9eebSMatthias Springer
238399ef9eebSMatthias Springer rewriter.replaceOpWithNewOp<vector::BitCastOp>(
238499ef9eebSMatthias Springer extractOp, extractOp.getType(), newExtractOp);
238599ef9eebSMatthias Springer
238699ef9eebSMatthias Springer return success();
238799ef9eebSMatthias Springer }
238899ef9eebSMatthias Springer };
238999ef9eebSMatthias Springer
239099ef9eebSMatthias Springer // Shuffles vector.bitcast op before vector.insert_strided_slice op.
239199ef9eebSMatthias Springer //
239299ef9eebSMatthias Springer // This transforms IR like:
239399ef9eebSMatthias Springer // %0 = vector.insert_strided_slice %src, %dst {
239499ef9eebSMatthias Springer // offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
239599ef9eebSMatthias Springer // %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32>
239699ef9eebSMatthias Springer // Into:
239799ef9eebSMatthias Springer // %0 = vector.bitcast %src : vector<4xf16> to vector<2xf32>
239899ef9eebSMatthias Springer // %1 = vector.bitcast %dst : vector<8xf16> to vector<4xf32>
239999ef9eebSMatthias Springer // %2 = vector.insert_strided_slice %src, %dst {
240099ef9eebSMatthias Springer // offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
240199ef9eebSMatthias Springer struct BubbleUpBitCastForStridedSliceInsert
240299ef9eebSMatthias Springer : public OpRewritePattern<vector::BitCastOp> {
240399ef9eebSMatthias Springer using OpRewritePattern::OpRewritePattern;
matchAndRewriteBubbleUpBitCastForStridedSliceInsert240499ef9eebSMatthias Springer LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
240599ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
240699ef9eebSMatthias Springer VectorType castSrcType = bitcastOp.getSourceVectorType();
240799ef9eebSMatthias Springer VectorType castDstType = bitcastOp.getResultVectorType();
240899ef9eebSMatthias Springer assert(castSrcType.getRank() == castDstType.getRank());
240999ef9eebSMatthias Springer
241099ef9eebSMatthias Springer int64_t castSrcLastDim = castSrcType.getShape().back();
241199ef9eebSMatthias Springer int64_t castDstLastDim = castDstType.getShape().back();
241299ef9eebSMatthias Springer // Require casting to less elements for now; other cases to be implemented.
241399ef9eebSMatthias Springer if (castSrcLastDim < castDstLastDim)
241499ef9eebSMatthias Springer return failure();
241599ef9eebSMatthias Springer
241699ef9eebSMatthias Springer assert(castSrcLastDim % castDstLastDim == 0);
241799ef9eebSMatthias Springer int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
241899ef9eebSMatthias Springer
241999ef9eebSMatthias Springer auto insertOp =
24207c38fd60SJacques Pienaar bitcastOp.getSource().getDefiningOp<vector::InsertStridedSliceOp>();
242199ef9eebSMatthias Springer if (!insertOp)
242299ef9eebSMatthias Springer return failure();
242399ef9eebSMatthias Springer
242499ef9eebSMatthias Springer // Only accept all one strides for now.
24257c38fd60SJacques Pienaar if (llvm::any_of(insertOp.getStrides().getAsValueRange<IntegerAttr>(),
242699ef9eebSMatthias Springer [](const APInt &val) { return !val.isOneValue(); }))
242799ef9eebSMatthias Springer return failure();
242899ef9eebSMatthias Springer
242999ef9eebSMatthias Springer unsigned rank = insertOp.getSourceVectorType().getRank();
243099ef9eebSMatthias Springer // Require insert op to have the same rank for the source and destination
243199ef9eebSMatthias Springer // vector; other cases to be implemented.
243299ef9eebSMatthias Springer if (rank != insertOp.getDestVectorType().getRank())
243399ef9eebSMatthias Springer return failure();
243499ef9eebSMatthias Springer
24357c38fd60SJacques Pienaar ArrayAttr newOffsets = insertOp.getOffsets();
243699ef9eebSMatthias Springer assert(newOffsets.size() == rank);
243799ef9eebSMatthias Springer SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets);
243899ef9eebSMatthias Springer if (offsets.back() % shrinkRatio != 0)
243999ef9eebSMatthias Springer return failure();
244099ef9eebSMatthias Springer offsets.back() = offsets.back() / shrinkRatio;
244199ef9eebSMatthias Springer newOffsets = rewriter.getI64ArrayAttr(offsets);
244299ef9eebSMatthias Springer
244399ef9eebSMatthias Springer SmallVector<int64_t, 4> srcDims =
244499ef9eebSMatthias Springer llvm::to_vector<4>(insertOp.getSourceVectorType().getShape());
244599ef9eebSMatthias Springer srcDims.back() = srcDims.back() / shrinkRatio;
244699ef9eebSMatthias Springer VectorType newCastSrcType =
244799ef9eebSMatthias Springer VectorType::get(srcDims, castDstType.getElementType());
244899ef9eebSMatthias Springer
244999ef9eebSMatthias Springer auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
24507c38fd60SJacques Pienaar bitcastOp.getLoc(), newCastSrcType, insertOp.getSource());
245199ef9eebSMatthias Springer
245299ef9eebSMatthias Springer SmallVector<int64_t, 4> dstDims =
245399ef9eebSMatthias Springer llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
245499ef9eebSMatthias Springer dstDims.back() = dstDims.back() / shrinkRatio;
245599ef9eebSMatthias Springer VectorType newCastDstType =
245699ef9eebSMatthias Springer VectorType::get(dstDims, castDstType.getElementType());
245799ef9eebSMatthias Springer
245899ef9eebSMatthias Springer auto newCastDstOp = rewriter.create<vector::BitCastOp>(
24597c38fd60SJacques Pienaar bitcastOp.getLoc(), newCastDstType, insertOp.getDest());
246099ef9eebSMatthias Springer
246199ef9eebSMatthias Springer rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>(
246299ef9eebSMatthias Springer bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets,
24637c38fd60SJacques Pienaar insertOp.getStrides());
246499ef9eebSMatthias Springer
246599ef9eebSMatthias Springer return success();
246699ef9eebSMatthias Springer }
246799ef9eebSMatthias Springer };
246899ef9eebSMatthias Springer
246999ef9eebSMatthias Springer // Helper that returns a vector comparison that constructs a mask:
247099ef9eebSMatthias Springer // mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
247199ef9eebSMatthias Springer //
247299ef9eebSMatthias Springer // If `dim == 0` then the result will be a 0-D vector.
247399ef9eebSMatthias Springer //
247499ef9eebSMatthias Springer // NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
247599ef9eebSMatthias Springer // much more compact, IR for this operation, but LLVM eventually
247699ef9eebSMatthias Springer // generates more elaborate instructions for this intrinsic since it
247799ef9eebSMatthias Springer // is very conservative on the boundary conditions.
buildVectorComparison(PatternRewriter & rewriter,Operation * op,bool force32BitVectorIndices,int64_t dim,Value b,Value * off=nullptr)247899ef9eebSMatthias Springer static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
24797bc8ad51SJavier Setoain bool force32BitVectorIndices, int64_t dim,
248099ef9eebSMatthias Springer Value b, Value *off = nullptr) {
248199ef9eebSMatthias Springer auto loc = op->getLoc();
248299ef9eebSMatthias Springer // If we can assume all indices fit in 32-bit, we perform the vector
248399ef9eebSMatthias Springer // comparison in 32-bit to get a higher degree of SIMD parallelism.
248499ef9eebSMatthias Springer // Otherwise we perform the vector comparison using 64-bit indices.
248599ef9eebSMatthias Springer Type idxType =
24867bc8ad51SJavier Setoain force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
248799ef9eebSMatthias Springer DenseIntElementsAttr indicesAttr;
24887bc8ad51SJavier Setoain if (dim == 0 && force32BitVectorIndices) {
248999ef9eebSMatthias Springer indicesAttr = DenseIntElementsAttr::get(
249099ef9eebSMatthias Springer VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int32_t>{0});
249199ef9eebSMatthias Springer } else if (dim == 0) {
249299ef9eebSMatthias Springer indicesAttr = DenseIntElementsAttr::get(
249399ef9eebSMatthias Springer VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int64_t>{0});
24947bc8ad51SJavier Setoain } else if (force32BitVectorIndices) {
249599ef9eebSMatthias Springer indicesAttr = rewriter.getI32VectorAttr(
249699ef9eebSMatthias Springer llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)));
249799ef9eebSMatthias Springer } else {
249899ef9eebSMatthias Springer indicesAttr = rewriter.getI64VectorAttr(
249999ef9eebSMatthias Springer llvm::to_vector<4>(llvm::seq<int64_t>(0, dim)));
250099ef9eebSMatthias Springer }
250199ef9eebSMatthias Springer Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
250299ef9eebSMatthias Springer // Add in an offset if requested.
250399ef9eebSMatthias Springer if (off) {
2504a75a46dbSJavier Setoain Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off);
25056a8ba318SRiver Riddle Value ov = rewriter.create<vector::SplatOp>(loc, indices.getType(), o);
250699ef9eebSMatthias Springer indices = rewriter.create<arith::AddIOp>(loc, ov, indices);
250799ef9eebSMatthias Springer }
250899ef9eebSMatthias Springer // Construct the vector comparison.
2509a75a46dbSJavier Setoain Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b);
25106a8ba318SRiver Riddle Value bounds =
25116a8ba318SRiver Riddle rewriter.create<vector::SplatOp>(loc, indices.getType(), bound);
251299ef9eebSMatthias Springer return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, indices,
251399ef9eebSMatthias Springer bounds);
251499ef9eebSMatthias Springer }
251599ef9eebSMatthias Springer
251699ef9eebSMatthias Springer template <typename ConcreteOp>
251799ef9eebSMatthias Springer struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
251899ef9eebSMatthias Springer public:
MaterializeTransferMaskMaterializeTransferMask251999ef9eebSMatthias Springer explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt)
252099ef9eebSMatthias Springer : mlir::OpRewritePattern<ConcreteOp>(context),
25217bc8ad51SJavier Setoain force32BitVectorIndices(enableIndexOpt) {}
252299ef9eebSMatthias Springer
matchAndRewriteMaterializeTransferMask252399ef9eebSMatthias Springer LogicalResult matchAndRewrite(ConcreteOp xferOp,
252499ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
252599ef9eebSMatthias Springer if (!xferOp.hasOutOfBoundsDim())
252699ef9eebSMatthias Springer return failure();
252799ef9eebSMatthias Springer
252899ef9eebSMatthias Springer if (xferOp.getVectorType().getRank() > 1 ||
25297c38fd60SJacques Pienaar llvm::size(xferOp.getIndices()) == 0)
253099ef9eebSMatthias Springer return failure();
253199ef9eebSMatthias Springer
253299ef9eebSMatthias Springer Location loc = xferOp->getLoc();
253399ef9eebSMatthias Springer VectorType vtp = xferOp.getVectorType();
253499ef9eebSMatthias Springer
2535f2b89c7aSJavier Setoain // Create the in-bounds mask with all elements between [0 .. dim - offset)
2536f2b89c7aSJavier Setoain // set and [dim - offset .. vector_length) unset.
253799ef9eebSMatthias Springer //
253899ef9eebSMatthias Springer // TODO: when the leaf transfer rank is k > 1, we need the last `k`
253999ef9eebSMatthias Springer // dimensions here.
25407c38fd60SJacques Pienaar unsigned lastIndex = llvm::size(xferOp.getIndices()) - 1;
25417c38fd60SJacques Pienaar Value off = xferOp.getIndices()[lastIndex];
254299ef9eebSMatthias Springer Value dim =
25437c38fd60SJacques Pienaar vector::createOrFoldDimOp(rewriter, loc, xferOp.getSource(), lastIndex);
2544f2b89c7aSJavier Setoain Value b = rewriter.create<arith::SubIOp>(loc, dim.getType(), dim, off);
2545f2b89c7aSJavier Setoain Value mask = rewriter.create<vector::CreateMaskOp>(
2546f2b89c7aSJavier Setoain loc,
2547f2b89c7aSJavier Setoain VectorType::get(vtp.getShape(), rewriter.getI1Type(),
2548f2b89c7aSJavier Setoain vtp.getNumScalableDims()),
2549f2b89c7aSJavier Setoain b);
25507c38fd60SJacques Pienaar if (xferOp.getMask()) {
255199ef9eebSMatthias Springer // Intersect the in-bounds with the mask specified as an op parameter.
25527c38fd60SJacques Pienaar mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.getMask());
255399ef9eebSMatthias Springer }
255499ef9eebSMatthias Springer
255599ef9eebSMatthias Springer rewriter.updateRootInPlace(xferOp, [&]() {
25567c38fd60SJacques Pienaar xferOp.getMaskMutable().assign(mask);
25577c38fd60SJacques Pienaar xferOp.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
255899ef9eebSMatthias Springer });
255999ef9eebSMatthias Springer
256099ef9eebSMatthias Springer return success();
256199ef9eebSMatthias Springer }
256299ef9eebSMatthias Springer
256399ef9eebSMatthias Springer private:
25647bc8ad51SJavier Setoain const bool force32BitVectorIndices;
256599ef9eebSMatthias Springer };
256699ef9eebSMatthias Springer
256799ef9eebSMatthias Springer /// Conversion pattern for a `vector.create_mask` (0-D and 1-D only).
256899ef9eebSMatthias Springer class VectorCreateMaskOpConversion
256999ef9eebSMatthias Springer : public OpRewritePattern<vector::CreateMaskOp> {
257099ef9eebSMatthias Springer public:
VectorCreateMaskOpConversion(MLIRContext * context,bool enableIndexOpt)257199ef9eebSMatthias Springer explicit VectorCreateMaskOpConversion(MLIRContext *context,
257299ef9eebSMatthias Springer bool enableIndexOpt)
257399ef9eebSMatthias Springer : mlir::OpRewritePattern<vector::CreateMaskOp>(context),
25747bc8ad51SJavier Setoain force32BitVectorIndices(enableIndexOpt) {}
257599ef9eebSMatthias Springer
matchAndRewrite(vector::CreateMaskOp op,PatternRewriter & rewriter) const257699ef9eebSMatthias Springer LogicalResult matchAndRewrite(vector::CreateMaskOp op,
257799ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
257899ef9eebSMatthias Springer auto dstType = op.getType();
2579a75a46dbSJavier Setoain if (dstType.cast<VectorType>().isScalable())
2580a75a46dbSJavier Setoain return failure();
258199ef9eebSMatthias Springer int64_t rank = dstType.getRank();
258299ef9eebSMatthias Springer if (rank > 1)
258399ef9eebSMatthias Springer return failure();
258499ef9eebSMatthias Springer rewriter.replaceOp(
25857bc8ad51SJavier Setoain op, buildVectorComparison(rewriter, op, force32BitVectorIndices,
258699ef9eebSMatthias Springer rank == 0 ? 0 : dstType.getDimSize(0),
258799ef9eebSMatthias Springer op.getOperand(0)));
258899ef9eebSMatthias Springer return success();
258999ef9eebSMatthias Springer }
259099ef9eebSMatthias Springer
259199ef9eebSMatthias Springer private:
25927bc8ad51SJavier Setoain const bool force32BitVectorIndices;
259399ef9eebSMatthias Springer };
259499ef9eebSMatthias Springer
259599ef9eebSMatthias Springer // Drop inner most contiguous unit dimensions from transfer_read operand.
259699ef9eebSMatthias Springer class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
259799ef9eebSMatthias Springer using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
259899ef9eebSMatthias Springer
matchAndRewrite(vector::TransferReadOp readOp,PatternRewriter & rewriter) const259999ef9eebSMatthias Springer LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
260099ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
260199ef9eebSMatthias Springer // TODO: support 0-d corner case.
260299ef9eebSMatthias Springer if (readOp.getTransferRank() == 0)
260399ef9eebSMatthias Springer return failure();
260499ef9eebSMatthias Springer
260599ef9eebSMatthias Springer // TODO: support mask.
26067c38fd60SJacques Pienaar if (readOp.getMask())
260799ef9eebSMatthias Springer return failure();
260899ef9eebSMatthias Springer
26097c38fd60SJacques Pienaar auto srcType = readOp.getSource().getType().dyn_cast<MemRefType>();
261099ef9eebSMatthias Springer if (!srcType || !srcType.hasStaticShape())
261199ef9eebSMatthias Springer return failure();
261299ef9eebSMatthias Springer
26137c38fd60SJacques Pienaar if (!readOp.getPermutationMap().isMinorIdentity())
261499ef9eebSMatthias Springer return failure();
261599ef9eebSMatthias Springer
261699ef9eebSMatthias Springer auto targetType = readOp.getVectorType();
261799ef9eebSMatthias Springer if (targetType.getRank() <= 1)
261899ef9eebSMatthias Springer return failure();
261999ef9eebSMatthias Springer
262099ef9eebSMatthias Springer SmallVector<int64_t> srcStrides;
262199ef9eebSMatthias Springer int64_t srcOffset;
262299ef9eebSMatthias Springer if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
262399ef9eebSMatthias Springer return failure();
262499ef9eebSMatthias Springer
262599ef9eebSMatthias Springer size_t dimsToDrop = 0;
262699ef9eebSMatthias Springer for (size_t i = 1; i < srcStrides.size(); ++i) {
262799ef9eebSMatthias Springer int dim = srcType.getRank() - i - 1;
262899ef9eebSMatthias Springer if (srcStrides[dim] == 1) {
262999ef9eebSMatthias Springer dimsToDrop++;
263099ef9eebSMatthias Springer } else {
263199ef9eebSMatthias Springer break;
263299ef9eebSMatthias Springer }
263399ef9eebSMatthias Springer }
263499ef9eebSMatthias Springer if (dimsToDrop == 0)
263599ef9eebSMatthias Springer return failure();
263699ef9eebSMatthias Springer
263799ef9eebSMatthias Springer auto resultTargetVecType =
263899ef9eebSMatthias Springer VectorType::get(targetType.getShape().drop_back(dimsToDrop),
263999ef9eebSMatthias Springer targetType.getElementType());
264099ef9eebSMatthias Springer
264199ef9eebSMatthias Springer MemRefType resultMemrefType;
264299ef9eebSMatthias Springer if (srcType.getLayout().getAffineMap().isIdentity()) {
264399ef9eebSMatthias Springer resultMemrefType = MemRefType::get(
264499ef9eebSMatthias Springer srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
264599ef9eebSMatthias Springer {}, srcType.getMemorySpaceAsInt());
264699ef9eebSMatthias Springer } else {
264799ef9eebSMatthias Springer AffineMap map = srcType.getLayout().getAffineMap();
264899ef9eebSMatthias Springer int numSymbols = map.getNumSymbols();
264999ef9eebSMatthias Springer for (size_t i = 0; i < dimsToDrop; ++i) {
265099ef9eebSMatthias Springer int dim = srcType.getRank() - i - 1;
265199ef9eebSMatthias Springer map = map.replace(rewriter.getAffineDimExpr(dim),
26524c1b65e7SThomas Raoux rewriter.getAffineConstantExpr(0),
26534c1b65e7SThomas Raoux map.getNumDims() - 1, numSymbols);
265499ef9eebSMatthias Springer }
265599ef9eebSMatthias Springer resultMemrefType = MemRefType::get(
265699ef9eebSMatthias Springer srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
265799ef9eebSMatthias Springer map, srcType.getMemorySpaceAsInt());
265899ef9eebSMatthias Springer }
265999ef9eebSMatthias Springer
266099ef9eebSMatthias Springer auto loc = readOp.getLoc();
266199ef9eebSMatthias Springer SmallVector<int64_t> offsets(srcType.getRank(), 0);
266299ef9eebSMatthias Springer SmallVector<int64_t> strides(srcType.getRank(), 1);
266399ef9eebSMatthias Springer
266499ef9eebSMatthias Springer ArrayAttr inBoundsAttr =
26657c38fd60SJacques Pienaar readOp.getInBounds()
266699ef9eebSMatthias Springer ? rewriter.getArrayAttr(
26677c38fd60SJacques Pienaar readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop))
266899ef9eebSMatthias Springer : ArrayAttr();
266999ef9eebSMatthias Springer Value rankedReducedView = rewriter.create<memref::SubViewOp>(
26707c38fd60SJacques Pienaar loc, resultMemrefType, readOp.getSource(), offsets, srcType.getShape(),
267199ef9eebSMatthias Springer strides);
267299ef9eebSMatthias Springer auto permMap = getTransferMinorIdentityMap(
267399ef9eebSMatthias Springer rankedReducedView.getType().cast<ShapedType>(), resultTargetVecType);
267499ef9eebSMatthias Springer Value result = rewriter.create<vector::TransferReadOp>(
267599ef9eebSMatthias Springer loc, resultTargetVecType, rankedReducedView,
26767c38fd60SJacques Pienaar readOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
26777c38fd60SJacques Pienaar readOp.getPadding(),
267899ef9eebSMatthias Springer // TODO: support mask.
267999ef9eebSMatthias Springer /*mask=*/Value(), inBoundsAttr);
268099ef9eebSMatthias Springer rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType,
268199ef9eebSMatthias Springer result);
268299ef9eebSMatthias Springer return success();
268399ef9eebSMatthias Springer }
268499ef9eebSMatthias Springer };
268599ef9eebSMatthias Springer
268699ef9eebSMatthias Springer namespace {
268799ef9eebSMatthias Springer
268899ef9eebSMatthias Springer /// This function checks to see if the vector combining kind
268999ef9eebSMatthias Springer /// is consistent with the integer or float element type.
isValidKind(bool isInt,vector::CombiningKind kind)269099ef9eebSMatthias Springer static bool isValidKind(bool isInt, vector::CombiningKind kind) {
269199ef9eebSMatthias Springer using vector::CombiningKind;
269299ef9eebSMatthias Springer enum class KindType { FLOAT, INT, INVALID };
269399ef9eebSMatthias Springer KindType type{KindType::INVALID};
269499ef9eebSMatthias Springer switch (kind) {
269599ef9eebSMatthias Springer case CombiningKind::MINF:
269699ef9eebSMatthias Springer case CombiningKind::MAXF:
269799ef9eebSMatthias Springer type = KindType::FLOAT;
269899ef9eebSMatthias Springer break;
269999ef9eebSMatthias Springer case CombiningKind::MINUI:
270099ef9eebSMatthias Springer case CombiningKind::MINSI:
270199ef9eebSMatthias Springer case CombiningKind::MAXUI:
270299ef9eebSMatthias Springer case CombiningKind::MAXSI:
270399ef9eebSMatthias Springer case CombiningKind::AND:
270499ef9eebSMatthias Springer case CombiningKind::OR:
270599ef9eebSMatthias Springer case CombiningKind::XOR:
270699ef9eebSMatthias Springer type = KindType::INT;
270799ef9eebSMatthias Springer break;
270899ef9eebSMatthias Springer case CombiningKind::ADD:
270999ef9eebSMatthias Springer case CombiningKind::MUL:
271099ef9eebSMatthias Springer type = isInt ? KindType::INT : KindType::FLOAT;
271199ef9eebSMatthias Springer break;
271299ef9eebSMatthias Springer }
271399ef9eebSMatthias Springer bool isValidIntKind = (type == KindType::INT) && isInt;
271499ef9eebSMatthias Springer bool isValidFloatKind = (type == KindType::FLOAT) && (!isInt);
271599ef9eebSMatthias Springer return (isValidIntKind || isValidFloatKind);
271699ef9eebSMatthias Springer }
271799ef9eebSMatthias Springer
271899ef9eebSMatthias Springer /// This function constructs the appropriate integer or float
271999ef9eebSMatthias Springer /// operation given the vector combining kind and operands. The
272099ef9eebSMatthias Springer /// supported int operations are : add, mul, min (signed/unsigned),
272199ef9eebSMatthias Springer /// max(signed/unsigned), and, or, xor. The supported float
272299ef9eebSMatthias Springer /// operations are : add, mul, min and max.
genOperator(Location loc,Value x,Value y,vector::CombiningKind kind,PatternRewriter & rewriter)272399ef9eebSMatthias Springer static Value genOperator(Location loc, Value x, Value y,
272499ef9eebSMatthias Springer vector::CombiningKind kind,
272599ef9eebSMatthias Springer PatternRewriter &rewriter) {
272699ef9eebSMatthias Springer using vector::CombiningKind;
272799ef9eebSMatthias Springer
272899ef9eebSMatthias Springer auto elType = x.getType().cast<VectorType>().getElementType();
272999ef9eebSMatthias Springer bool isInt = elType.isIntOrIndex();
273099ef9eebSMatthias Springer
273199ef9eebSMatthias Springer Value combinedResult{nullptr};
273299ef9eebSMatthias Springer switch (kind) {
273399ef9eebSMatthias Springer case CombiningKind::ADD:
273499ef9eebSMatthias Springer if (isInt)
273599ef9eebSMatthias Springer combinedResult = rewriter.create<arith::AddIOp>(loc, x, y);
273699ef9eebSMatthias Springer else
273799ef9eebSMatthias Springer combinedResult = rewriter.create<arith::AddFOp>(loc, x, y);
273899ef9eebSMatthias Springer break;
273999ef9eebSMatthias Springer case CombiningKind::MUL:
274099ef9eebSMatthias Springer if (isInt)
274199ef9eebSMatthias Springer combinedResult = rewriter.create<arith::MulIOp>(loc, x, y);
274299ef9eebSMatthias Springer else
274399ef9eebSMatthias Springer combinedResult = rewriter.create<arith::MulFOp>(loc, x, y);
274499ef9eebSMatthias Springer break;
274599ef9eebSMatthias Springer case CombiningKind::MINUI:
274699ef9eebSMatthias Springer combinedResult = rewriter.create<arith::MinUIOp>(loc, x, y);
274799ef9eebSMatthias Springer break;
274899ef9eebSMatthias Springer case CombiningKind::MINSI:
274999ef9eebSMatthias Springer combinedResult = rewriter.create<arith::MinSIOp>(loc, x, y);
275099ef9eebSMatthias Springer break;
275199ef9eebSMatthias Springer case CombiningKind::MAXUI:
275299ef9eebSMatthias Springer combinedResult = rewriter.create<arith::MaxUIOp>(loc, x, y);
275399ef9eebSMatthias Springer break;
275499ef9eebSMatthias Springer case CombiningKind::MAXSI:
275599ef9eebSMatthias Springer combinedResult = rewriter.create<arith::MaxSIOp>(loc, x, y);
275699ef9eebSMatthias Springer break;
275799ef9eebSMatthias Springer case CombiningKind::AND:
275899ef9eebSMatthias Springer combinedResult = rewriter.create<arith::AndIOp>(loc, x, y);
275999ef9eebSMatthias Springer break;
276099ef9eebSMatthias Springer case CombiningKind::OR:
276199ef9eebSMatthias Springer combinedResult = rewriter.create<arith::OrIOp>(loc, x, y);
276299ef9eebSMatthias Springer break;
276399ef9eebSMatthias Springer case CombiningKind::XOR:
276499ef9eebSMatthias Springer combinedResult = rewriter.create<arith::XOrIOp>(loc, x, y);
276599ef9eebSMatthias Springer break;
276699ef9eebSMatthias Springer case CombiningKind::MINF:
276799ef9eebSMatthias Springer combinedResult = rewriter.create<arith::MinFOp>(loc, x, y);
276899ef9eebSMatthias Springer break;
276999ef9eebSMatthias Springer case CombiningKind::MAXF:
277099ef9eebSMatthias Springer combinedResult = rewriter.create<arith::MaxFOp>(loc, x, y);
277199ef9eebSMatthias Springer break;
277299ef9eebSMatthias Springer }
277399ef9eebSMatthias Springer return combinedResult;
277499ef9eebSMatthias Springer }
277599ef9eebSMatthias Springer
277699ef9eebSMatthias Springer /// Convert vector.scan op into arith ops and
277799ef9eebSMatthias Springer /// vector.insert_strided_slice/extract_strided_slice
277899ef9eebSMatthias Springer ///
277999ef9eebSMatthias Springer /// Ex:
278099ef9eebSMatthias Springer /// ```
278199ef9eebSMatthias Springer /// %0:2 = vector.scan <add>, %arg0, %arg1 {inclusive = true, reduction_dim =
278299ef9eebSMatthias Springer /// 1} :
278399ef9eebSMatthias Springer /// (vector<2x3xi32>, vector<2xi32>) to (vector<2x3xi32>, vector<2xi32>)
278499ef9eebSMatthias Springer /// ```
278599ef9eebSMatthias Springer /// Gets converted to:
278699ef9eebSMatthias Springer /// ```
278799ef9eebSMatthias Springer /// %cst = arith.constant dense<0> : vector<2x3xi32>
278899ef9eebSMatthias Springer /// %0 = vector.extract_strided_slice %arg0 {offsets = [0, 0], sizes = [2, 1],
278999ef9eebSMatthias Springer /// strides = [1, 1]} : vector<2x3xi32> to vector<2x1xi32> %1 =
279099ef9eebSMatthias Springer /// vector.insert_strided_slice %0, %cst {offsets = [0, 0], strides = [1, 1]}
279199ef9eebSMatthias Springer /// : vector<2x1xi32> into vector<2x3xi32> %2 = vector.extract_strided_slice
279299ef9eebSMatthias Springer /// %arg0 {offsets = [0, 1], sizes = [2, 1], strides = [1, 1]} :
279399ef9eebSMatthias Springer /// vector<2x3xi32> to vector<2x1xi32> %3 = arith.muli %0, %2 :
279499ef9eebSMatthias Springer /// vector<2x1xi32> %4 = vector.insert_strided_slice %3, %1 {offsets = [0, 1],
279599ef9eebSMatthias Springer /// strides = [1, 1]} : vector<2x1xi32> into vector<2x3xi32> %5 =
279699ef9eebSMatthias Springer /// vector.extract_strided_slice %arg0 {offsets = [0, 2], sizes = [2, 1],
279799ef9eebSMatthias Springer /// strides = [1, 1]} : vector<2x3xi32> to vector<2x1xi32> %6 = arith.muli %3,
279899ef9eebSMatthias Springer /// %5 : vector<2x1xi32> %7 = vector.insert_strided_slice %6, %4 {offsets =
279999ef9eebSMatthias Springer /// [0, 2], strides = [1, 1]} : vector<2x1xi32> into vector<2x3xi32> %8 =
280099ef9eebSMatthias Springer /// vector.shape_cast %6 : vector<2x1xi32> to vector<2xi32> return %7, %8 :
280199ef9eebSMatthias Springer /// vector<2x3xi32>, vector<2xi32>
280299ef9eebSMatthias Springer /// ```
280399ef9eebSMatthias Springer struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
280499ef9eebSMatthias Springer using OpRewritePattern<vector::ScanOp>::OpRewritePattern;
280599ef9eebSMatthias Springer
matchAndRewrite__anon5c5a5b801011::ScanToArithOps280699ef9eebSMatthias Springer LogicalResult matchAndRewrite(vector::ScanOp scanOp,
280799ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
280899ef9eebSMatthias Springer auto loc = scanOp.getLoc();
280999ef9eebSMatthias Springer VectorType destType = scanOp.getDestType();
281099ef9eebSMatthias Springer ArrayRef<int64_t> destShape = destType.getShape();
281199ef9eebSMatthias Springer auto elType = destType.getElementType();
281299ef9eebSMatthias Springer bool isInt = elType.isIntOrIndex();
28137c38fd60SJacques Pienaar if (!isValidKind(isInt, scanOp.getKind()))
281499ef9eebSMatthias Springer return failure();
281599ef9eebSMatthias Springer
281699ef9eebSMatthias Springer VectorType resType = VectorType::get(destShape, elType);
281799ef9eebSMatthias Springer Value result = rewriter.create<arith::ConstantOp>(
281899ef9eebSMatthias Springer loc, resType, rewriter.getZeroAttr(resType));
28197c38fd60SJacques Pienaar int64_t reductionDim = scanOp.getReductionDim();
28207c38fd60SJacques Pienaar bool inclusive = scanOp.getInclusive();
282199ef9eebSMatthias Springer int64_t destRank = destType.getRank();
282299ef9eebSMatthias Springer VectorType initialValueType = scanOp.getInitialValueType();
282399ef9eebSMatthias Springer int64_t initialValueRank = initialValueType.getRank();
282499ef9eebSMatthias Springer
282599ef9eebSMatthias Springer SmallVector<int64_t> reductionShape(destShape.begin(), destShape.end());
282699ef9eebSMatthias Springer reductionShape[reductionDim] = 1;
282799ef9eebSMatthias Springer VectorType reductionType = VectorType::get(reductionShape, elType);
282899ef9eebSMatthias Springer SmallVector<int64_t> offsets(destRank, 0);
282999ef9eebSMatthias Springer SmallVector<int64_t> strides(destRank, 1);
283099ef9eebSMatthias Springer SmallVector<int64_t> sizes(destShape.begin(), destShape.end());
283199ef9eebSMatthias Springer sizes[reductionDim] = 1;
283299ef9eebSMatthias Springer ArrayAttr scanSizes = rewriter.getI64ArrayAttr(sizes);
283399ef9eebSMatthias Springer ArrayAttr scanStrides = rewriter.getI64ArrayAttr(strides);
283499ef9eebSMatthias Springer
283599ef9eebSMatthias Springer Value lastOutput, lastInput;
283699ef9eebSMatthias Springer for (int i = 0; i < destShape[reductionDim]; i++) {
283799ef9eebSMatthias Springer offsets[reductionDim] = i;
283899ef9eebSMatthias Springer ArrayAttr scanOffsets = rewriter.getI64ArrayAttr(offsets);
283999ef9eebSMatthias Springer Value input = rewriter.create<vector::ExtractStridedSliceOp>(
28407c38fd60SJacques Pienaar loc, reductionType, scanOp.getSource(), scanOffsets, scanSizes,
284199ef9eebSMatthias Springer scanStrides);
284299ef9eebSMatthias Springer Value output;
284399ef9eebSMatthias Springer if (i == 0) {
284499ef9eebSMatthias Springer if (inclusive) {
284599ef9eebSMatthias Springer output = input;
284699ef9eebSMatthias Springer } else {
284799ef9eebSMatthias Springer if (initialValueRank == 0) {
284899ef9eebSMatthias Springer // ShapeCastOp cannot handle 0-D vectors
284999ef9eebSMatthias Springer output = rewriter.create<vector::BroadcastOp>(
28507c38fd60SJacques Pienaar loc, input.getType(), scanOp.getInitialValue());
285199ef9eebSMatthias Springer } else {
285299ef9eebSMatthias Springer output = rewriter.create<vector::ShapeCastOp>(
28537c38fd60SJacques Pienaar loc, input.getType(), scanOp.getInitialValue());
285499ef9eebSMatthias Springer }
285599ef9eebSMatthias Springer }
285699ef9eebSMatthias Springer } else {
285799ef9eebSMatthias Springer Value y = inclusive ? input : lastInput;
28587c38fd60SJacques Pienaar output = genOperator(loc, lastOutput, y, scanOp.getKind(), rewriter);
285999ef9eebSMatthias Springer assert(output != nullptr);
286099ef9eebSMatthias Springer }
286199ef9eebSMatthias Springer result = rewriter.create<vector::InsertStridedSliceOp>(
286299ef9eebSMatthias Springer loc, output, result, offsets, strides);
286399ef9eebSMatthias Springer lastOutput = output;
286499ef9eebSMatthias Springer lastInput = input;
286599ef9eebSMatthias Springer }
286699ef9eebSMatthias Springer
286799ef9eebSMatthias Springer Value reduction;
286899ef9eebSMatthias Springer if (initialValueRank == 0) {
286999ef9eebSMatthias Springer Value v = rewriter.create<vector::ExtractOp>(loc, lastOutput, 0);
287099ef9eebSMatthias Springer reduction =
287199ef9eebSMatthias Springer rewriter.create<vector::BroadcastOp>(loc, initialValueType, v);
287299ef9eebSMatthias Springer } else {
287399ef9eebSMatthias Springer reduction = rewriter.create<vector::ShapeCastOp>(loc, initialValueType,
287499ef9eebSMatthias Springer lastOutput);
287599ef9eebSMatthias Springer }
287699ef9eebSMatthias Springer
287799ef9eebSMatthias Springer rewriter.replaceOp(scanOp, {result, reduction});
287899ef9eebSMatthias Springer return success();
287999ef9eebSMatthias Springer }
288099ef9eebSMatthias Springer };
288199ef9eebSMatthias Springer
288299ef9eebSMatthias Springer } // namespace
288399ef9eebSMatthias Springer
populateVectorMaskMaterializationPatterns(RewritePatternSet & patterns,bool force32BitVectorIndices)288499ef9eebSMatthias Springer void mlir::vector::populateVectorMaskMaterializationPatterns(
28857bc8ad51SJavier Setoain RewritePatternSet &patterns, bool force32BitVectorIndices) {
288699ef9eebSMatthias Springer patterns.add<VectorCreateMaskOpConversion,
288799ef9eebSMatthias Springer MaterializeTransferMask<vector::TransferReadOp>,
288899ef9eebSMatthias Springer MaterializeTransferMask<vector::TransferWriteOp>>(
28897bc8ad51SJavier Setoain patterns.getContext(), force32BitVectorIndices);
289099ef9eebSMatthias Springer }
289199ef9eebSMatthias Springer
populateShapeCastFoldingPatterns(RewritePatternSet & patterns)289299ef9eebSMatthias Springer void mlir::vector::populateShapeCastFoldingPatterns(
289399ef9eebSMatthias Springer RewritePatternSet &patterns) {
289499ef9eebSMatthias Springer patterns.add<ShapeCastOpFolder>(patterns.getContext());
289599ef9eebSMatthias Springer }
289699ef9eebSMatthias Springer
populateBubbleVectorBitCastOpPatterns(RewritePatternSet & patterns)289799ef9eebSMatthias Springer void mlir::vector::populateBubbleVectorBitCastOpPatterns(
289899ef9eebSMatthias Springer RewritePatternSet &patterns) {
289999ef9eebSMatthias Springer patterns.add<BubbleDownVectorBitCastForExtract,
290099ef9eebSMatthias Springer BubbleDownBitCastForStridedSliceExtract,
290199ef9eebSMatthias Springer BubbleUpBitCastForStridedSliceInsert>(patterns.getContext());
290299ef9eebSMatthias Springer }
290399ef9eebSMatthias Springer
populateVectorBroadcastLoweringPatterns(RewritePatternSet & patterns)290499ef9eebSMatthias Springer void mlir::vector::populateVectorBroadcastLoweringPatterns(
290599ef9eebSMatthias Springer RewritePatternSet &patterns) {
290699ef9eebSMatthias Springer patterns.add<BroadcastOpLowering>(patterns.getContext());
290799ef9eebSMatthias Springer }
290899ef9eebSMatthias Springer
populateVectorMaskOpLoweringPatterns(RewritePatternSet & patterns)290999ef9eebSMatthias Springer void mlir::vector::populateVectorMaskOpLoweringPatterns(
291099ef9eebSMatthias Springer RewritePatternSet &patterns) {
291199ef9eebSMatthias Springer patterns.add<CreateMaskOpLowering, ConstantMaskOpLowering>(
291299ef9eebSMatthias Springer patterns.getContext());
291399ef9eebSMatthias Springer }
291499ef9eebSMatthias Springer
populateVectorShapeCastLoweringPatterns(RewritePatternSet & patterns)291599ef9eebSMatthias Springer void mlir::vector::populateVectorShapeCastLoweringPatterns(
291699ef9eebSMatthias Springer RewritePatternSet &patterns) {
291799ef9eebSMatthias Springer patterns.add<ShapeCastOp2DDownCastRewritePattern,
291899ef9eebSMatthias Springer ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern>(
291999ef9eebSMatthias Springer patterns.getContext());
292099ef9eebSMatthias Springer }
292199ef9eebSMatthias Springer
populateVectorContractLoweringPatterns(RewritePatternSet & patterns,VectorTransformsOptions options)292299ef9eebSMatthias Springer void mlir::vector::populateVectorContractLoweringPatterns(
292399ef9eebSMatthias Springer RewritePatternSet &patterns, VectorTransformsOptions options) {
292499ef9eebSMatthias Springer patterns.add<OuterProductOpLowering>(patterns.getContext());
292599ef9eebSMatthias Springer patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering,
292699ef9eebSMatthias Springer ContractionOpToOuterProductOpLowering>(options,
292799ef9eebSMatthias Springer patterns.getContext());
292899ef9eebSMatthias Springer }
292999ef9eebSMatthias Springer
populateVectorTransposeLoweringPatterns(RewritePatternSet & patterns,VectorTransformsOptions options)293099ef9eebSMatthias Springer void mlir::vector::populateVectorTransposeLoweringPatterns(
293199ef9eebSMatthias Springer RewritePatternSet &patterns, VectorTransformsOptions options) {
293299ef9eebSMatthias Springer patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
293399ef9eebSMatthias Springer options, patterns.getContext());
293499ef9eebSMatthias Springer }
293599ef9eebSMatthias Springer
populateVectorReductionToContractPatterns(RewritePatternSet & patterns)293699ef9eebSMatthias Springer void mlir::vector::populateVectorReductionToContractPatterns(
293799ef9eebSMatthias Springer RewritePatternSet &patterns) {
293899ef9eebSMatthias Springer patterns.add<MultiReduceToContract, CombineContractBroadcast,
29391538bd51SHanhan Wang CombineContractTranspose, ReorderCastOpsOnBroadcast,
29404db65e27SLei Zhang ReorderElementwiseOpsOnTranspose>(patterns.getContext());
294199ef9eebSMatthias Springer }
294299ef9eebSMatthias Springer
294399ef9eebSMatthias Springer void mlir::vector::
populateVectorTransferCollapseInnerMostContiguousDimsPatterns(RewritePatternSet & patterns)294499ef9eebSMatthias Springer populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
294599ef9eebSMatthias Springer RewritePatternSet &patterns) {
294699ef9eebSMatthias Springer patterns.add<DropInnerMostUnitDims>(patterns.getContext());
294799ef9eebSMatthias Springer }
294899ef9eebSMatthias Springer
populateVectorTransferLoweringPatterns(RewritePatternSet & patterns,llvm::Optional<unsigned> maxTransferRank)294999ef9eebSMatthias Springer void mlir::vector::populateVectorTransferLoweringPatterns(
295099ef9eebSMatthias Springer RewritePatternSet &patterns, llvm::Optional<unsigned> maxTransferRank) {
295199ef9eebSMatthias Springer patterns.add<TransferReadToVectorLoadLowering,
295299ef9eebSMatthias Springer TransferWriteToVectorStoreLowering>(patterns.getContext(),
295399ef9eebSMatthias Springer maxTransferRank);
295499ef9eebSMatthias Springer patterns
295599ef9eebSMatthias Springer .add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>(
295699ef9eebSMatthias Springer patterns.getContext());
295799ef9eebSMatthias Springer }
295899ef9eebSMatthias Springer
populateVectorScanLoweringPatterns(RewritePatternSet & patterns)295999ef9eebSMatthias Springer void mlir::vector::populateVectorScanLoweringPatterns(
296099ef9eebSMatthias Springer RewritePatternSet &patterns) {
296199ef9eebSMatthias Springer patterns.add<ScanToArithOps>(patterns.getContext());
296299ef9eebSMatthias Springer }
2963