//===- VectorToSPIRV.cpp - Vector to SPIR-V Patterns ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert the Vector dialect to the SPIR-V
// dialect.
//
//===----------------------------------------------------------------------===//
126e557bc4SThomas Raoux
13930c74f1SLei Zhang #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
14930c74f1SLei Zhang
156e557bc4SThomas Raoux #include "../PassDetail.h"
1601178654SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
1701178654SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
1801178654SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
1901178654SLei Zhang #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
2099ef9eebSMatthias Springer #include "mlir/Dialect/Vector/IR/VectorOps.h"
2136e68c11SLei Zhang #include "mlir/IR/BuiltinAttributes.h"
22d137c05fSLei Zhang #include "mlir/IR/BuiltinTypes.h"
236e557bc4SThomas Raoux #include "mlir/Transforms/DialectConversion.h"
24d137c05fSLei Zhang #include "llvm/ADT/ArrayRef.h"
2536e68c11SLei Zhang #include "llvm/ADT/STLExtras.h"
269f622b3dSLei Zhang #include <numeric>
276e557bc4SThomas Raoux
286e557bc4SThomas Raoux using namespace mlir;
296e557bc4SThomas Raoux
309f622b3dSLei Zhang /// Gets the first integer value from `attr`, assuming it is an integer array
319f622b3dSLei Zhang /// attribute.
getFirstIntValue(ArrayAttr attr)329f622b3dSLei Zhang static uint64_t getFirstIntValue(ArrayAttr attr) {
339f622b3dSLei Zhang return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
34801067f4SStella Stamenova }
359f622b3dSLei Zhang
366e557bc4SThomas Raoux namespace {
379f622b3dSLei Zhang
389f622b3dSLei Zhang struct VectorBitcastConvert final
399f622b3dSLei Zhang : public OpConversionPattern<vector::BitCastOp> {
409f622b3dSLei Zhang using OpConversionPattern::OpConversionPattern;
419f622b3dSLei Zhang
429f622b3dSLei Zhang LogicalResult
matchAndRewrite__anon0bb484910111::VectorBitcastConvert43b54c724bSRiver Riddle matchAndRewrite(vector::BitCastOp bitcastOp, OpAdaptor adaptor,
449f622b3dSLei Zhang ConversionPatternRewriter &rewriter) const override {
459f622b3dSLei Zhang auto dstType = getTypeConverter()->convertType(bitcastOp.getType());
469f622b3dSLei Zhang if (!dstType)
479f622b3dSLei Zhang return failure();
489f622b3dSLei Zhang
497c38fd60SJacques Pienaar if (dstType == adaptor.getSource().getType())
507c38fd60SJacques Pienaar rewriter.replaceOp(bitcastOp, adaptor.getSource());
519f622b3dSLei Zhang else
529f622b3dSLei Zhang rewriter.replaceOpWithNewOp<spirv::BitcastOp>(bitcastOp, dstType,
537c38fd60SJacques Pienaar adaptor.getSource());
549f622b3dSLei Zhang
559f622b3dSLei Zhang return success();
569f622b3dSLei Zhang }
579f622b3dSLei Zhang };
589f622b3dSLei Zhang
596e557bc4SThomas Raoux struct VectorBroadcastConvert final
607c3ae48fSLei Zhang : public OpConversionPattern<vector::BroadcastOp> {
617c3ae48fSLei Zhang using OpConversionPattern::OpConversionPattern;
627c3ae48fSLei Zhang
636e557bc4SThomas Raoux LogicalResult
matchAndRewrite__anon0bb484910111::VectorBroadcastConvert64b54c724bSRiver Riddle matchAndRewrite(vector::BroadcastOp broadcastOp, OpAdaptor adaptor,
656e557bc4SThomas Raoux ConversionPatternRewriter &rewriter) const override {
667c38fd60SJacques Pienaar if (broadcastOp.getSource().getType().isa<VectorType>() ||
676e557bc4SThomas Raoux !spirv::CompositeType::isValid(broadcastOp.getVectorType()))
686e557bc4SThomas Raoux return failure();
696e557bc4SThomas Raoux SmallVector<Value, 4> source(broadcastOp.getVectorType().getNumElements(),
707c38fd60SJacques Pienaar adaptor.getSource());
713a56a966SLei Zhang rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
723a56a966SLei Zhang broadcastOp, broadcastOp.getVectorType(), source);
736e557bc4SThomas Raoux return success();
746e557bc4SThomas Raoux }
756e557bc4SThomas Raoux };
766e557bc4SThomas Raoux
776e557bc4SThomas Raoux struct VectorExtractOpConvert final
787c3ae48fSLei Zhang : public OpConversionPattern<vector::ExtractOp> {
797c3ae48fSLei Zhang using OpConversionPattern::OpConversionPattern;
807c3ae48fSLei Zhang
816e557bc4SThomas Raoux LogicalResult
matchAndRewrite__anon0bb484910111::VectorExtractOpConvert82b54c724bSRiver Riddle matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
836e557bc4SThomas Raoux ConversionPatternRewriter &rewriter) const override {
849f622b3dSLei Zhang // Only support extracting a scalar value now.
859f622b3dSLei Zhang VectorType resultVectorType = extractOp.getType().dyn_cast<VectorType>();
869f622b3dSLei Zhang if (resultVectorType && resultVectorType.getNumElements() > 1)
876e557bc4SThomas Raoux return failure();
889f622b3dSLei Zhang
899f622b3dSLei Zhang auto dstType = getTypeConverter()->convertType(extractOp.getType());
909f622b3dSLei Zhang if (!dstType)
919f622b3dSLei Zhang return failure();
929f622b3dSLei Zhang
937c38fd60SJacques Pienaar if (adaptor.getVector().getType().isa<spirv::ScalarType>()) {
947c38fd60SJacques Pienaar rewriter.replaceOp(extractOp, adaptor.getVector());
95b2e72cd3Sthomasraoux return success();
96b2e72cd3Sthomasraoux }
97b2e72cd3Sthomasraoux
987c38fd60SJacques Pienaar int32_t id = getFirstIntValue(extractOp.getPosition());
993a56a966SLei Zhang rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
1007c38fd60SJacques Pienaar extractOp, adaptor.getVector(), id);
1013a56a966SLei Zhang return success();
1023a56a966SLei Zhang }
1033a56a966SLei Zhang };
1043a56a966SLei Zhang
1059f622b3dSLei Zhang struct VectorExtractStridedSliceOpConvert final
1069f622b3dSLei Zhang : public OpConversionPattern<vector::ExtractStridedSliceOp> {
1079f622b3dSLei Zhang using OpConversionPattern::OpConversionPattern;
1089f622b3dSLei Zhang
1099f622b3dSLei Zhang LogicalResult
matchAndRewrite__anon0bb484910111::VectorExtractStridedSliceOpConvert110b54c724bSRiver Riddle matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
1119f622b3dSLei Zhang ConversionPatternRewriter &rewriter) const override {
1129f622b3dSLei Zhang auto dstType = getTypeConverter()->convertType(extractOp.getType());
1139f622b3dSLei Zhang if (!dstType)
1149f622b3dSLei Zhang return failure();
1159f622b3dSLei Zhang
1167c38fd60SJacques Pienaar uint64_t offset = getFirstIntValue(extractOp.getOffsets());
1177c38fd60SJacques Pienaar uint64_t size = getFirstIntValue(extractOp.getSizes());
1187c38fd60SJacques Pienaar uint64_t stride = getFirstIntValue(extractOp.getStrides());
1199f622b3dSLei Zhang if (stride != 1)
1209f622b3dSLei Zhang return failure();
1219f622b3dSLei Zhang
122b54c724bSRiver Riddle Value srcVector = adaptor.getOperands().front();
1239f622b3dSLei Zhang
124565ee6afSthomasraoux // Extract vector<1xT> case.
125565ee6afSthomasraoux if (dstType.isa<spirv::ScalarType>()) {
126565ee6afSthomasraoux rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(extractOp,
127565ee6afSthomasraoux srcVector, offset);
128565ee6afSthomasraoux return success();
129565ee6afSthomasraoux }
130565ee6afSthomasraoux
1319f622b3dSLei Zhang SmallVector<int32_t, 2> indices(size);
1329f622b3dSLei Zhang std::iota(indices.begin(), indices.end(), offset);
1339f622b3dSLei Zhang
1349f622b3dSLei Zhang rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
1359f622b3dSLei Zhang extractOp, dstType, srcVector, srcVector,
1369f622b3dSLei Zhang rewriter.getI32ArrayAttr(indices));
1379f622b3dSLei Zhang
1389f622b3dSLei Zhang return success();
1399f622b3dSLei Zhang }
1409f622b3dSLei Zhang };
1419f622b3dSLei Zhang
1423a56a966SLei Zhang struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> {
1433a56a966SLei Zhang using OpConversionPattern::OpConversionPattern;
1443a56a966SLei Zhang
1453a56a966SLei Zhang LogicalResult
matchAndRewrite__anon0bb484910111::VectorFmaOpConvert146b54c724bSRiver Riddle matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
1473a56a966SLei Zhang ConversionPatternRewriter &rewriter) const override {
148a4360efbSLei Zhang Type dstType = getTypeConverter()->convertType(fmaOp.getType());
149a4360efbSLei Zhang if (!dstType)
1503a56a966SLei Zhang return failure();
151*52b630daSJakub Kuderski rewriter.replaceOpWithNewOp<spirv::GLFmaOp>(
152a4360efbSLei Zhang fmaOp, dstType, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
1536e557bc4SThomas Raoux return success();
1546e557bc4SThomas Raoux }
1556e557bc4SThomas Raoux };
1566e557bc4SThomas Raoux
1577c3ae48fSLei Zhang struct VectorInsertOpConvert final
1587c3ae48fSLei Zhang : public OpConversionPattern<vector::InsertOp> {
1597c3ae48fSLei Zhang using OpConversionPattern::OpConversionPattern;
1607c3ae48fSLei Zhang
1616e557bc4SThomas Raoux LogicalResult
matchAndRewrite__anon0bb484910111::VectorInsertOpConvert162b54c724bSRiver Riddle matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
1636e557bc4SThomas Raoux ConversionPatternRewriter &rewriter) const override {
16447107508SLei Zhang // Special case for inserting scalar values into size-1 vectors.
16547107508SLei Zhang if (insertOp.getSourceType().isIntOrFloat() &&
16647107508SLei Zhang insertOp.getDestVectorType().getNumElements() == 1) {
1677c38fd60SJacques Pienaar rewriter.replaceOp(insertOp, adaptor.getSource());
16847107508SLei Zhang return success();
16947107508SLei Zhang }
17047107508SLei Zhang
1716e557bc4SThomas Raoux if (insertOp.getSourceType().isa<VectorType>() ||
1726e557bc4SThomas Raoux !spirv::CompositeType::isValid(insertOp.getDestVectorType()))
1736e557bc4SThomas Raoux return failure();
1747c38fd60SJacques Pienaar int32_t id = getFirstIntValue(insertOp.getPosition());
1753a56a966SLei Zhang rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
1767c38fd60SJacques Pienaar insertOp, adaptor.getSource(), adaptor.getDest(), id);
1776e557bc4SThomas Raoux return success();
1786e557bc4SThomas Raoux }
1796e557bc4SThomas Raoux };
180f9dca103SArtur Bialas
181f9dca103SArtur Bialas struct VectorExtractElementOpConvert final
1827c3ae48fSLei Zhang : public OpConversionPattern<vector::ExtractElementOp> {
1837c3ae48fSLei Zhang using OpConversionPattern::OpConversionPattern;
1847c3ae48fSLei Zhang
185f9dca103SArtur Bialas LogicalResult
matchAndRewrite__anon0bb484910111::VectorExtractElementOpConvert186b54c724bSRiver Riddle matchAndRewrite(vector::ExtractElementOp extractElementOp, OpAdaptor adaptor,
187f9dca103SArtur Bialas ConversionPatternRewriter &rewriter) const override {
188f9dca103SArtur Bialas if (!spirv::CompositeType::isValid(extractElementOp.getVectorType()))
189f9dca103SArtur Bialas return failure();
1903a56a966SLei Zhang rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
1917c38fd60SJacques Pienaar extractElementOp, extractElementOp.getType(), adaptor.getVector(),
1927c38fd60SJacques Pienaar extractElementOp.getPosition());
193f9dca103SArtur Bialas return success();
194f9dca103SArtur Bialas }
195f9dca103SArtur Bialas };
196f9dca103SArtur Bialas
1973035e676SArtur Bialas struct VectorInsertElementOpConvert final
1987c3ae48fSLei Zhang : public OpConversionPattern<vector::InsertElementOp> {
1997c3ae48fSLei Zhang using OpConversionPattern::OpConversionPattern;
2007c3ae48fSLei Zhang
2013035e676SArtur Bialas LogicalResult
matchAndRewrite__anon0bb484910111::VectorInsertElementOpConvert202b54c724bSRiver Riddle matchAndRewrite(vector::InsertElementOp insertElementOp, OpAdaptor adaptor,
2033035e676SArtur Bialas ConversionPatternRewriter &rewriter) const override {
2043035e676SArtur Bialas if (!spirv::CompositeType::isValid(insertElementOp.getDestVectorType()))
2053035e676SArtur Bialas return failure();
2063a56a966SLei Zhang rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
2077c38fd60SJacques Pienaar insertElementOp, insertElementOp.getType(), insertElementOp.getDest(),
2087c38fd60SJacques Pienaar adaptor.getSource(), insertElementOp.getPosition());
2093035e676SArtur Bialas return success();
2103035e676SArtur Bialas }
2113035e676SArtur Bialas };
2123035e676SArtur Bialas
2139f622b3dSLei Zhang struct VectorInsertStridedSliceOpConvert final
2149f622b3dSLei Zhang : public OpConversionPattern<vector::InsertStridedSliceOp> {
2159f622b3dSLei Zhang using OpConversionPattern::OpConversionPattern;
2169f622b3dSLei Zhang
2179f622b3dSLei Zhang LogicalResult
matchAndRewrite__anon0bb484910111::VectorInsertStridedSliceOpConvert218b54c724bSRiver Riddle matchAndRewrite(vector::InsertStridedSliceOp insertOp, OpAdaptor adaptor,
2199f622b3dSLei Zhang ConversionPatternRewriter &rewriter) const override {
220b54c724bSRiver Riddle Value srcVector = adaptor.getOperands().front();
221b54c724bSRiver Riddle Value dstVector = adaptor.getOperands().back();
2229f622b3dSLei Zhang
2237c38fd60SJacques Pienaar uint64_t stride = getFirstIntValue(insertOp.getStrides());
2249f622b3dSLei Zhang if (stride != 1)
2259f622b3dSLei Zhang return failure();
2267c38fd60SJacques Pienaar uint64_t offset = getFirstIntValue(insertOp.getOffsets());
22747107508SLei Zhang
22847107508SLei Zhang if (srcVector.getType().isa<spirv::ScalarType>()) {
22947107508SLei Zhang assert(!dstVector.getType().isa<spirv::ScalarType>());
23047107508SLei Zhang rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
23147107508SLei Zhang insertOp, dstVector.getType(), srcVector, dstVector,
23247107508SLei Zhang rewriter.getI32ArrayAttr(offset));
23347107508SLei Zhang return success();
23447107508SLei Zhang }
2359f622b3dSLei Zhang
2369f622b3dSLei Zhang uint64_t totalSize =
2379f622b3dSLei Zhang dstVector.getType().cast<VectorType>().getNumElements();
2389f622b3dSLei Zhang uint64_t insertSize =
2399f622b3dSLei Zhang srcVector.getType().cast<VectorType>().getNumElements();
2409f622b3dSLei Zhang
2419f622b3dSLei Zhang SmallVector<int32_t, 2> indices(totalSize);
2429f622b3dSLei Zhang std::iota(indices.begin(), indices.end(), 0);
2439f622b3dSLei Zhang std::iota(indices.begin() + offset, indices.begin() + offset + insertSize,
2449f622b3dSLei Zhang totalSize);
2459f622b3dSLei Zhang
2469f622b3dSLei Zhang rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
2479f622b3dSLei Zhang insertOp, dstVector.getType(), dstVector, srcVector,
2489f622b3dSLei Zhang rewriter.getI32ArrayAttr(indices));
2499f622b3dSLei Zhang
2509f622b3dSLei Zhang return success();
2519f622b3dSLei Zhang }
2529f622b3dSLei Zhang };
2539f622b3dSLei Zhang
254d137c05fSLei Zhang struct VectorReductionPattern final
255d137c05fSLei Zhang : public OpConversionPattern<vector::ReductionOp> {
256d137c05fSLei Zhang using OpConversionPattern::OpConversionPattern;
257d137c05fSLei Zhang
258d137c05fSLei Zhang LogicalResult
matchAndRewrite__anon0bb484910111::VectorReductionPattern259d137c05fSLei Zhang matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
260d137c05fSLei Zhang ConversionPatternRewriter &rewriter) const override {
261d137c05fSLei Zhang Type resultType = typeConverter->convertType(reduceOp.getType());
262d137c05fSLei Zhang if (!resultType)
263d137c05fSLei Zhang return failure();
264d137c05fSLei Zhang
265d137c05fSLei Zhang auto srcVectorType = adaptor.getVector().getType().dyn_cast<VectorType>();
266d137c05fSLei Zhang if (!srcVectorType || srcVectorType.getRank() != 1)
267d137c05fSLei Zhang return rewriter.notifyMatchFailure(reduceOp, "not 1-D vector source");
268d137c05fSLei Zhang
269d137c05fSLei Zhang // Extract all elements.
270d137c05fSLei Zhang int numElements = srcVectorType.getDimSize(0);
271d137c05fSLei Zhang SmallVector<Value, 4> values;
272d137c05fSLei Zhang values.reserve(numElements + (adaptor.getAcc() != nullptr));
273d137c05fSLei Zhang Location loc = reduceOp.getLoc();
274d137c05fSLei Zhang for (int i = 0; i < numElements; ++i) {
275d137c05fSLei Zhang values.push_back(rewriter.create<spirv::CompositeExtractOp>(
276d137c05fSLei Zhang loc, srcVectorType.getElementType(), adaptor.getVector(),
277d137c05fSLei Zhang rewriter.getI32ArrayAttr({i})));
278d137c05fSLei Zhang }
279d137c05fSLei Zhang if (Value acc = adaptor.getAcc())
280d137c05fSLei Zhang values.push_back(acc);
281d137c05fSLei Zhang
282d137c05fSLei Zhang // Reduce them.
283d137c05fSLei Zhang Value result = values.front();
284d137c05fSLei Zhang for (Value next : llvm::makeArrayRef(values).drop_front()) {
285d137c05fSLei Zhang switch (reduceOp.getKind()) {
286d137c05fSLei Zhang #define INT_FLOAT_CASE(kind, iop, fop) \
287d137c05fSLei Zhang case vector::CombiningKind::kind: \
288d137c05fSLei Zhang if (resultType.isa<IntegerType>()) { \
289d137c05fSLei Zhang result = rewriter.create<spirv::iop>(loc, resultType, result, next); \
290d137c05fSLei Zhang } else { \
291d137c05fSLei Zhang assert(resultType.isa<FloatType>()); \
292d137c05fSLei Zhang result = rewriter.create<spirv::fop>(loc, resultType, result, next); \
293d137c05fSLei Zhang } \
294d137c05fSLei Zhang break
295d137c05fSLei Zhang
296d137c05fSLei Zhang INT_FLOAT_CASE(ADD, IAddOp, FAddOp);
297d137c05fSLei Zhang INT_FLOAT_CASE(MUL, IMulOp, FMulOp);
298d137c05fSLei Zhang
299d137c05fSLei Zhang case vector::CombiningKind::MINUI:
300d137c05fSLei Zhang case vector::CombiningKind::MINSI:
301d137c05fSLei Zhang case vector::CombiningKind::MINF:
302d137c05fSLei Zhang case vector::CombiningKind::MAXUI:
303d137c05fSLei Zhang case vector::CombiningKind::MAXSI:
304d137c05fSLei Zhang case vector::CombiningKind::MAXF:
305d137c05fSLei Zhang case vector::CombiningKind::AND:
306d137c05fSLei Zhang case vector::CombiningKind::OR:
307d137c05fSLei Zhang case vector::CombiningKind::XOR:
308d137c05fSLei Zhang return rewriter.notifyMatchFailure(reduceOp, "unimplemented");
309d137c05fSLei Zhang }
310d137c05fSLei Zhang }
311d137c05fSLei Zhang
312d137c05fSLei Zhang rewriter.replaceOp(reduceOp, result);
313d137c05fSLei Zhang return success();
314d137c05fSLei Zhang }
315d137c05fSLei Zhang };
316d137c05fSLei Zhang
3176a8ba318SRiver Riddle class VectorSplatPattern final : public OpConversionPattern<vector::SplatOp> {
3186a8ba318SRiver Riddle public:
3196a8ba318SRiver Riddle using OpConversionPattern<vector::SplatOp>::OpConversionPattern;
3206a8ba318SRiver Riddle
3216a8ba318SRiver Riddle LogicalResult
matchAndRewrite(vector::SplatOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const3226a8ba318SRiver Riddle matchAndRewrite(vector::SplatOp op, OpAdaptor adaptor,
3236a8ba318SRiver Riddle ConversionPatternRewriter &rewriter) const override {
324a4360efbSLei Zhang Type dstType = getTypeConverter()->convertType(op.getType());
325a4360efbSLei Zhang if (!dstType)
3266a8ba318SRiver Riddle return failure();
327a4360efbSLei Zhang if (dstType.isa<spirv::ScalarType>()) {
328a4360efbSLei Zhang rewriter.replaceOp(op, adaptor.getInput());
329a4360efbSLei Zhang } else {
330a4360efbSLei Zhang auto dstVecType = dstType.cast<VectorType>();
3317c38fd60SJacques Pienaar SmallVector<Value, 4> source(dstVecType.getNumElements(),
3327c38fd60SJacques Pienaar adaptor.getInput());
333a4360efbSLei Zhang rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, dstType,
3346a8ba318SRiver Riddle source);
335a4360efbSLei Zhang }
3366a8ba318SRiver Riddle return success();
3376a8ba318SRiver Riddle }
3386a8ba318SRiver Riddle };
3396a8ba318SRiver Riddle
34036e68c11SLei Zhang struct VectorShuffleOpConvert final
34136e68c11SLei Zhang : public OpConversionPattern<vector::ShuffleOp> {
34236e68c11SLei Zhang using OpConversionPattern::OpConversionPattern;
34336e68c11SLei Zhang
34436e68c11SLei Zhang LogicalResult
matchAndRewrite__anon0bb484910111::VectorShuffleOpConvert34536e68c11SLei Zhang matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
34636e68c11SLei Zhang ConversionPatternRewriter &rewriter) const override {
34736e68c11SLei Zhang auto oldResultType = shuffleOp.getVectorType();
34836e68c11SLei Zhang if (!spirv::CompositeType::isValid(oldResultType))
34936e68c11SLei Zhang return failure();
35036e68c11SLei Zhang auto newResultType = getTypeConverter()->convertType(oldResultType);
35136e68c11SLei Zhang
35236e68c11SLei Zhang auto oldSourceType = shuffleOp.getV1VectorType();
35336e68c11SLei Zhang if (oldSourceType.getNumElements() > 1) {
35436e68c11SLei Zhang SmallVector<int32_t, 4> components = llvm::to_vector<4>(
3557c38fd60SJacques Pienaar llvm::map_range(shuffleOp.getMask(), [](Attribute attr) -> int32_t {
35636e68c11SLei Zhang return attr.cast<IntegerAttr>().getValue().getZExtValue();
35736e68c11SLei Zhang }));
35836e68c11SLei Zhang rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
3597c38fd60SJacques Pienaar shuffleOp, newResultType, adaptor.getV1(), adaptor.getV2(),
36036e68c11SLei Zhang rewriter.getI32ArrayAttr(components));
36136e68c11SLei Zhang return success();
36236e68c11SLei Zhang }
36336e68c11SLei Zhang
3647c38fd60SJacques Pienaar SmallVector<Value, 2> oldOperands = {adaptor.getV1(), adaptor.getV2()};
36536e68c11SLei Zhang SmallVector<Value, 4> newOperands;
36636e68c11SLei Zhang newOperands.reserve(oldResultType.getNumElements());
3677c38fd60SJacques Pienaar for (const APInt &i : shuffleOp.getMask().getAsValueRange<IntegerAttr>()) {
36836e68c11SLei Zhang newOperands.push_back(oldOperands[i.getZExtValue()]);
36936e68c11SLei Zhang }
37036e68c11SLei Zhang rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
37136e68c11SLei Zhang shuffleOp, newResultType, newOperands);
37236e68c11SLei Zhang
37336e68c11SLei Zhang return success();
37436e68c11SLei Zhang }
37536e68c11SLei Zhang };
37636e68c11SLei Zhang
3776e557bc4SThomas Raoux } // namespace
3786e557bc4SThomas Raoux
populateVectorToSPIRVPatterns(SPIRVTypeConverter & typeConverter,RewritePatternSet & patterns)3793a506b31SChris Lattner void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
380dc4e913bSChris Lattner RewritePatternSet &patterns) {
381dc4e913bSChris Lattner patterns.add<VectorBitcastConvert, VectorBroadcastConvert,
3829f622b3dSLei Zhang VectorExtractElementOpConvert, VectorExtractOpConvert,
3839f622b3dSLei Zhang VectorExtractStridedSliceOpConvert, VectorFmaOpConvert,
3849f622b3dSLei Zhang VectorInsertElementOpConvert, VectorInsertOpConvert,
385d137c05fSLei Zhang VectorReductionPattern, VectorInsertStridedSliceOpConvert,
386d137c05fSLei Zhang VectorShuffleOpConvert, VectorSplatPattern>(
387d137c05fSLei Zhang typeConverter, patterns.getContext());
3886e557bc4SThomas Raoux }
389