1930c74f1SLei Zhang //===- VectorToSPIRV.cpp - Vector to SPIR-V Patterns ----------------------===//
26e557bc4SThomas Raoux //
36e557bc4SThomas Raoux // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
46e557bc4SThomas Raoux // See https://llvm.org/LICENSE.txt for license information.
56e557bc4SThomas Raoux // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66e557bc4SThomas Raoux //
76e557bc4SThomas Raoux //===----------------------------------------------------------------------===//
86e557bc4SThomas Raoux //
9930c74f1SLei Zhang // This file implements patterns to convert Vector dialect to SPIRV dialect.
106e557bc4SThomas Raoux //
116e557bc4SThomas Raoux //===----------------------------------------------------------------------===//
126e557bc4SThomas Raoux 
13930c74f1SLei Zhang #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
14930c74f1SLei Zhang 
156e557bc4SThomas Raoux #include "../PassDetail.h"
1601178654SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
1701178654SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
1801178654SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
1901178654SLei Zhang #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
2099ef9eebSMatthias Springer #include "mlir/Dialect/Vector/IR/VectorOps.h"
2136e68c11SLei Zhang #include "mlir/IR/BuiltinAttributes.h"
22d137c05fSLei Zhang #include "mlir/IR/BuiltinTypes.h"
236e557bc4SThomas Raoux #include "mlir/Transforms/DialectConversion.h"
24d137c05fSLei Zhang #include "llvm/ADT/ArrayRef.h"
2536e68c11SLei Zhang #include "llvm/ADT/STLExtras.h"
269f622b3dSLei Zhang #include <numeric>
276e557bc4SThomas Raoux 
286e557bc4SThomas Raoux using namespace mlir;
296e557bc4SThomas Raoux 
309f622b3dSLei Zhang /// Gets the first integer value from `attr`, assuming it is an integer array
319f622b3dSLei Zhang /// attribute.
getFirstIntValue(ArrayAttr attr)329f622b3dSLei Zhang static uint64_t getFirstIntValue(ArrayAttr attr) {
339f622b3dSLei Zhang   return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
34801067f4SStella Stamenova }
359f622b3dSLei Zhang 
366e557bc4SThomas Raoux namespace {
379f622b3dSLei Zhang 
389f622b3dSLei Zhang struct VectorBitcastConvert final
399f622b3dSLei Zhang     : public OpConversionPattern<vector::BitCastOp> {
409f622b3dSLei Zhang   using OpConversionPattern::OpConversionPattern;
419f622b3dSLei Zhang 
429f622b3dSLei Zhang   LogicalResult
matchAndRewrite__anon0bb484910111::VectorBitcastConvert43b54c724bSRiver Riddle   matchAndRewrite(vector::BitCastOp bitcastOp, OpAdaptor adaptor,
449f622b3dSLei Zhang                   ConversionPatternRewriter &rewriter) const override {
459f622b3dSLei Zhang     auto dstType = getTypeConverter()->convertType(bitcastOp.getType());
469f622b3dSLei Zhang     if (!dstType)
479f622b3dSLei Zhang       return failure();
489f622b3dSLei Zhang 
497c38fd60SJacques Pienaar     if (dstType == adaptor.getSource().getType())
507c38fd60SJacques Pienaar       rewriter.replaceOp(bitcastOp, adaptor.getSource());
519f622b3dSLei Zhang     else
529f622b3dSLei Zhang       rewriter.replaceOpWithNewOp<spirv::BitcastOp>(bitcastOp, dstType,
537c38fd60SJacques Pienaar                                                     adaptor.getSource());
549f622b3dSLei Zhang 
559f622b3dSLei Zhang     return success();
569f622b3dSLei Zhang   }
579f622b3dSLei Zhang };
589f622b3dSLei Zhang 
596e557bc4SThomas Raoux struct VectorBroadcastConvert final
607c3ae48fSLei Zhang     : public OpConversionPattern<vector::BroadcastOp> {
617c3ae48fSLei Zhang   using OpConversionPattern::OpConversionPattern;
627c3ae48fSLei Zhang 
636e557bc4SThomas Raoux   LogicalResult
matchAndRewrite__anon0bb484910111::VectorBroadcastConvert64b54c724bSRiver Riddle   matchAndRewrite(vector::BroadcastOp broadcastOp, OpAdaptor adaptor,
656e557bc4SThomas Raoux                   ConversionPatternRewriter &rewriter) const override {
667c38fd60SJacques Pienaar     if (broadcastOp.getSource().getType().isa<VectorType>() ||
676e557bc4SThomas Raoux         !spirv::CompositeType::isValid(broadcastOp.getVectorType()))
686e557bc4SThomas Raoux       return failure();
696e557bc4SThomas Raoux     SmallVector<Value, 4> source(broadcastOp.getVectorType().getNumElements(),
707c38fd60SJacques Pienaar                                  adaptor.getSource());
713a56a966SLei Zhang     rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
723a56a966SLei Zhang         broadcastOp, broadcastOp.getVectorType(), source);
736e557bc4SThomas Raoux     return success();
746e557bc4SThomas Raoux   }
756e557bc4SThomas Raoux };
766e557bc4SThomas Raoux 
776e557bc4SThomas Raoux struct VectorExtractOpConvert final
787c3ae48fSLei Zhang     : public OpConversionPattern<vector::ExtractOp> {
797c3ae48fSLei Zhang   using OpConversionPattern::OpConversionPattern;
807c3ae48fSLei Zhang 
816e557bc4SThomas Raoux   LogicalResult
matchAndRewrite__anon0bb484910111::VectorExtractOpConvert82b54c724bSRiver Riddle   matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
836e557bc4SThomas Raoux                   ConversionPatternRewriter &rewriter) const override {
849f622b3dSLei Zhang     // Only support extracting a scalar value now.
859f622b3dSLei Zhang     VectorType resultVectorType = extractOp.getType().dyn_cast<VectorType>();
869f622b3dSLei Zhang     if (resultVectorType && resultVectorType.getNumElements() > 1)
876e557bc4SThomas Raoux       return failure();
889f622b3dSLei Zhang 
899f622b3dSLei Zhang     auto dstType = getTypeConverter()->convertType(extractOp.getType());
909f622b3dSLei Zhang     if (!dstType)
919f622b3dSLei Zhang       return failure();
929f622b3dSLei Zhang 
937c38fd60SJacques Pienaar     if (adaptor.getVector().getType().isa<spirv::ScalarType>()) {
947c38fd60SJacques Pienaar       rewriter.replaceOp(extractOp, adaptor.getVector());
95b2e72cd3Sthomasraoux       return success();
96b2e72cd3Sthomasraoux     }
97b2e72cd3Sthomasraoux 
987c38fd60SJacques Pienaar     int32_t id = getFirstIntValue(extractOp.getPosition());
993a56a966SLei Zhang     rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
1007c38fd60SJacques Pienaar         extractOp, adaptor.getVector(), id);
1013a56a966SLei Zhang     return success();
1023a56a966SLei Zhang   }
1033a56a966SLei Zhang };
1043a56a966SLei Zhang 
1059f622b3dSLei Zhang struct VectorExtractStridedSliceOpConvert final
1069f622b3dSLei Zhang     : public OpConversionPattern<vector::ExtractStridedSliceOp> {
1079f622b3dSLei Zhang   using OpConversionPattern::OpConversionPattern;
1089f622b3dSLei Zhang 
1099f622b3dSLei Zhang   LogicalResult
matchAndRewrite__anon0bb484910111::VectorExtractStridedSliceOpConvert110b54c724bSRiver Riddle   matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
1119f622b3dSLei Zhang                   ConversionPatternRewriter &rewriter) const override {
1129f622b3dSLei Zhang     auto dstType = getTypeConverter()->convertType(extractOp.getType());
1139f622b3dSLei Zhang     if (!dstType)
1149f622b3dSLei Zhang       return failure();
1159f622b3dSLei Zhang 
1167c38fd60SJacques Pienaar     uint64_t offset = getFirstIntValue(extractOp.getOffsets());
1177c38fd60SJacques Pienaar     uint64_t size = getFirstIntValue(extractOp.getSizes());
1187c38fd60SJacques Pienaar     uint64_t stride = getFirstIntValue(extractOp.getStrides());
1199f622b3dSLei Zhang     if (stride != 1)
1209f622b3dSLei Zhang       return failure();
1219f622b3dSLei Zhang 
122b54c724bSRiver Riddle     Value srcVector = adaptor.getOperands().front();
1239f622b3dSLei Zhang 
124565ee6afSthomasraoux     // Extract vector<1xT> case.
125565ee6afSthomasraoux     if (dstType.isa<spirv::ScalarType>()) {
126565ee6afSthomasraoux       rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(extractOp,
127565ee6afSthomasraoux                                                              srcVector, offset);
128565ee6afSthomasraoux       return success();
129565ee6afSthomasraoux     }
130565ee6afSthomasraoux 
1319f622b3dSLei Zhang     SmallVector<int32_t, 2> indices(size);
1329f622b3dSLei Zhang     std::iota(indices.begin(), indices.end(), offset);
1339f622b3dSLei Zhang 
1349f622b3dSLei Zhang     rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
1359f622b3dSLei Zhang         extractOp, dstType, srcVector, srcVector,
1369f622b3dSLei Zhang         rewriter.getI32ArrayAttr(indices));
1379f622b3dSLei Zhang 
1389f622b3dSLei Zhang     return success();
1399f622b3dSLei Zhang   }
1409f622b3dSLei Zhang };
1419f622b3dSLei Zhang 
1423a56a966SLei Zhang struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> {
1433a56a966SLei Zhang   using OpConversionPattern::OpConversionPattern;
1443a56a966SLei Zhang 
1453a56a966SLei Zhang   LogicalResult
matchAndRewrite__anon0bb484910111::VectorFmaOpConvert146b54c724bSRiver Riddle   matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
1473a56a966SLei Zhang                   ConversionPatternRewriter &rewriter) const override {
148a4360efbSLei Zhang     Type dstType = getTypeConverter()->convertType(fmaOp.getType());
149a4360efbSLei Zhang     if (!dstType)
1503a56a966SLei Zhang       return failure();
151*52b630daSJakub Kuderski     rewriter.replaceOpWithNewOp<spirv::GLFmaOp>(
152a4360efbSLei Zhang         fmaOp, dstType, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
1536e557bc4SThomas Raoux     return success();
1546e557bc4SThomas Raoux   }
1556e557bc4SThomas Raoux };
1566e557bc4SThomas Raoux 
1577c3ae48fSLei Zhang struct VectorInsertOpConvert final
1587c3ae48fSLei Zhang     : public OpConversionPattern<vector::InsertOp> {
1597c3ae48fSLei Zhang   using OpConversionPattern::OpConversionPattern;
1607c3ae48fSLei Zhang 
1616e557bc4SThomas Raoux   LogicalResult
matchAndRewrite__anon0bb484910111::VectorInsertOpConvert162b54c724bSRiver Riddle   matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
1636e557bc4SThomas Raoux                   ConversionPatternRewriter &rewriter) const override {
16447107508SLei Zhang     // Special case for inserting scalar values into size-1 vectors.
16547107508SLei Zhang     if (insertOp.getSourceType().isIntOrFloat() &&
16647107508SLei Zhang         insertOp.getDestVectorType().getNumElements() == 1) {
1677c38fd60SJacques Pienaar       rewriter.replaceOp(insertOp, adaptor.getSource());
16847107508SLei Zhang       return success();
16947107508SLei Zhang     }
17047107508SLei Zhang 
1716e557bc4SThomas Raoux     if (insertOp.getSourceType().isa<VectorType>() ||
1726e557bc4SThomas Raoux         !spirv::CompositeType::isValid(insertOp.getDestVectorType()))
1736e557bc4SThomas Raoux       return failure();
1747c38fd60SJacques Pienaar     int32_t id = getFirstIntValue(insertOp.getPosition());
1753a56a966SLei Zhang     rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
1767c38fd60SJacques Pienaar         insertOp, adaptor.getSource(), adaptor.getDest(), id);
1776e557bc4SThomas Raoux     return success();
1786e557bc4SThomas Raoux   }
1796e557bc4SThomas Raoux };
180f9dca103SArtur Bialas 
181f9dca103SArtur Bialas struct VectorExtractElementOpConvert final
1827c3ae48fSLei Zhang     : public OpConversionPattern<vector::ExtractElementOp> {
1837c3ae48fSLei Zhang   using OpConversionPattern::OpConversionPattern;
1847c3ae48fSLei Zhang 
185f9dca103SArtur Bialas   LogicalResult
matchAndRewrite__anon0bb484910111::VectorExtractElementOpConvert186b54c724bSRiver Riddle   matchAndRewrite(vector::ExtractElementOp extractElementOp, OpAdaptor adaptor,
187f9dca103SArtur Bialas                   ConversionPatternRewriter &rewriter) const override {
188f9dca103SArtur Bialas     if (!spirv::CompositeType::isValid(extractElementOp.getVectorType()))
189f9dca103SArtur Bialas       return failure();
1903a56a966SLei Zhang     rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
1917c38fd60SJacques Pienaar         extractElementOp, extractElementOp.getType(), adaptor.getVector(),
1927c38fd60SJacques Pienaar         extractElementOp.getPosition());
193f9dca103SArtur Bialas     return success();
194f9dca103SArtur Bialas   }
195f9dca103SArtur Bialas };
196f9dca103SArtur Bialas 
1973035e676SArtur Bialas struct VectorInsertElementOpConvert final
1987c3ae48fSLei Zhang     : public OpConversionPattern<vector::InsertElementOp> {
1997c3ae48fSLei Zhang   using OpConversionPattern::OpConversionPattern;
2007c3ae48fSLei Zhang 
2013035e676SArtur Bialas   LogicalResult
matchAndRewrite__anon0bb484910111::VectorInsertElementOpConvert202b54c724bSRiver Riddle   matchAndRewrite(vector::InsertElementOp insertElementOp, OpAdaptor adaptor,
2033035e676SArtur Bialas                   ConversionPatternRewriter &rewriter) const override {
2043035e676SArtur Bialas     if (!spirv::CompositeType::isValid(insertElementOp.getDestVectorType()))
2053035e676SArtur Bialas       return failure();
2063a56a966SLei Zhang     rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
2077c38fd60SJacques Pienaar         insertElementOp, insertElementOp.getType(), insertElementOp.getDest(),
2087c38fd60SJacques Pienaar         adaptor.getSource(), insertElementOp.getPosition());
2093035e676SArtur Bialas     return success();
2103035e676SArtur Bialas   }
2113035e676SArtur Bialas };
2123035e676SArtur Bialas 
2139f622b3dSLei Zhang struct VectorInsertStridedSliceOpConvert final
2149f622b3dSLei Zhang     : public OpConversionPattern<vector::InsertStridedSliceOp> {
2159f622b3dSLei Zhang   using OpConversionPattern::OpConversionPattern;
2169f622b3dSLei Zhang 
2179f622b3dSLei Zhang   LogicalResult
matchAndRewrite__anon0bb484910111::VectorInsertStridedSliceOpConvert218b54c724bSRiver Riddle   matchAndRewrite(vector::InsertStridedSliceOp insertOp, OpAdaptor adaptor,
2199f622b3dSLei Zhang                   ConversionPatternRewriter &rewriter) const override {
220b54c724bSRiver Riddle     Value srcVector = adaptor.getOperands().front();
221b54c724bSRiver Riddle     Value dstVector = adaptor.getOperands().back();
2229f622b3dSLei Zhang 
2237c38fd60SJacques Pienaar     uint64_t stride = getFirstIntValue(insertOp.getStrides());
2249f622b3dSLei Zhang     if (stride != 1)
2259f622b3dSLei Zhang       return failure();
2267c38fd60SJacques Pienaar     uint64_t offset = getFirstIntValue(insertOp.getOffsets());
22747107508SLei Zhang 
22847107508SLei Zhang     if (srcVector.getType().isa<spirv::ScalarType>()) {
22947107508SLei Zhang       assert(!dstVector.getType().isa<spirv::ScalarType>());
23047107508SLei Zhang       rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
23147107508SLei Zhang           insertOp, dstVector.getType(), srcVector, dstVector,
23247107508SLei Zhang           rewriter.getI32ArrayAttr(offset));
23347107508SLei Zhang       return success();
23447107508SLei Zhang     }
2359f622b3dSLei Zhang 
2369f622b3dSLei Zhang     uint64_t totalSize =
2379f622b3dSLei Zhang         dstVector.getType().cast<VectorType>().getNumElements();
2389f622b3dSLei Zhang     uint64_t insertSize =
2399f622b3dSLei Zhang         srcVector.getType().cast<VectorType>().getNumElements();
2409f622b3dSLei Zhang 
2419f622b3dSLei Zhang     SmallVector<int32_t, 2> indices(totalSize);
2429f622b3dSLei Zhang     std::iota(indices.begin(), indices.end(), 0);
2439f622b3dSLei Zhang     std::iota(indices.begin() + offset, indices.begin() + offset + insertSize,
2449f622b3dSLei Zhang               totalSize);
2459f622b3dSLei Zhang 
2469f622b3dSLei Zhang     rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
2479f622b3dSLei Zhang         insertOp, dstVector.getType(), dstVector, srcVector,
2489f622b3dSLei Zhang         rewriter.getI32ArrayAttr(indices));
2499f622b3dSLei Zhang 
2509f622b3dSLei Zhang     return success();
2519f622b3dSLei Zhang   }
2529f622b3dSLei Zhang };
2539f622b3dSLei Zhang 
254d137c05fSLei Zhang struct VectorReductionPattern final
255d137c05fSLei Zhang     : public OpConversionPattern<vector::ReductionOp> {
256d137c05fSLei Zhang   using OpConversionPattern::OpConversionPattern;
257d137c05fSLei Zhang 
258d137c05fSLei Zhang   LogicalResult
matchAndRewrite__anon0bb484910111::VectorReductionPattern259d137c05fSLei Zhang   matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
260d137c05fSLei Zhang                   ConversionPatternRewriter &rewriter) const override {
261d137c05fSLei Zhang     Type resultType = typeConverter->convertType(reduceOp.getType());
262d137c05fSLei Zhang     if (!resultType)
263d137c05fSLei Zhang       return failure();
264d137c05fSLei Zhang 
265d137c05fSLei Zhang     auto srcVectorType = adaptor.getVector().getType().dyn_cast<VectorType>();
266d137c05fSLei Zhang     if (!srcVectorType || srcVectorType.getRank() != 1)
267d137c05fSLei Zhang       return rewriter.notifyMatchFailure(reduceOp, "not 1-D vector source");
268d137c05fSLei Zhang 
269d137c05fSLei Zhang     // Extract all elements.
270d137c05fSLei Zhang     int numElements = srcVectorType.getDimSize(0);
271d137c05fSLei Zhang     SmallVector<Value, 4> values;
272d137c05fSLei Zhang     values.reserve(numElements + (adaptor.getAcc() != nullptr));
273d137c05fSLei Zhang     Location loc = reduceOp.getLoc();
274d137c05fSLei Zhang     for (int i = 0; i < numElements; ++i) {
275d137c05fSLei Zhang       values.push_back(rewriter.create<spirv::CompositeExtractOp>(
276d137c05fSLei Zhang           loc, srcVectorType.getElementType(), adaptor.getVector(),
277d137c05fSLei Zhang           rewriter.getI32ArrayAttr({i})));
278d137c05fSLei Zhang     }
279d137c05fSLei Zhang     if (Value acc = adaptor.getAcc())
280d137c05fSLei Zhang       values.push_back(acc);
281d137c05fSLei Zhang 
282d137c05fSLei Zhang     // Reduce them.
283d137c05fSLei Zhang     Value result = values.front();
284d137c05fSLei Zhang     for (Value next : llvm::makeArrayRef(values).drop_front()) {
285d137c05fSLei Zhang       switch (reduceOp.getKind()) {
286d137c05fSLei Zhang #define INT_FLOAT_CASE(kind, iop, fop)                                         \
287d137c05fSLei Zhang   case vector::CombiningKind::kind:                                            \
288d137c05fSLei Zhang     if (resultType.isa<IntegerType>()) {                                       \
289d137c05fSLei Zhang       result = rewriter.create<spirv::iop>(loc, resultType, result, next);     \
290d137c05fSLei Zhang     } else {                                                                   \
291d137c05fSLei Zhang       assert(resultType.isa<FloatType>());                                     \
292d137c05fSLei Zhang       result = rewriter.create<spirv::fop>(loc, resultType, result, next);     \
293d137c05fSLei Zhang     }                                                                          \
294d137c05fSLei Zhang     break
295d137c05fSLei Zhang 
296d137c05fSLei Zhang         INT_FLOAT_CASE(ADD, IAddOp, FAddOp);
297d137c05fSLei Zhang         INT_FLOAT_CASE(MUL, IMulOp, FMulOp);
298d137c05fSLei Zhang 
299d137c05fSLei Zhang       case vector::CombiningKind::MINUI:
300d137c05fSLei Zhang       case vector::CombiningKind::MINSI:
301d137c05fSLei Zhang       case vector::CombiningKind::MINF:
302d137c05fSLei Zhang       case vector::CombiningKind::MAXUI:
303d137c05fSLei Zhang       case vector::CombiningKind::MAXSI:
304d137c05fSLei Zhang       case vector::CombiningKind::MAXF:
305d137c05fSLei Zhang       case vector::CombiningKind::AND:
306d137c05fSLei Zhang       case vector::CombiningKind::OR:
307d137c05fSLei Zhang       case vector::CombiningKind::XOR:
308d137c05fSLei Zhang         return rewriter.notifyMatchFailure(reduceOp, "unimplemented");
309d137c05fSLei Zhang       }
310d137c05fSLei Zhang     }
311d137c05fSLei Zhang 
312d137c05fSLei Zhang     rewriter.replaceOp(reduceOp, result);
313d137c05fSLei Zhang     return success();
314d137c05fSLei Zhang   }
315d137c05fSLei Zhang };
316d137c05fSLei Zhang 
3176a8ba318SRiver Riddle class VectorSplatPattern final : public OpConversionPattern<vector::SplatOp> {
3186a8ba318SRiver Riddle public:
3196a8ba318SRiver Riddle   using OpConversionPattern<vector::SplatOp>::OpConversionPattern;
3206a8ba318SRiver Riddle 
3216a8ba318SRiver Riddle   LogicalResult
matchAndRewrite(vector::SplatOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const3226a8ba318SRiver Riddle   matchAndRewrite(vector::SplatOp op, OpAdaptor adaptor,
3236a8ba318SRiver Riddle                   ConversionPatternRewriter &rewriter) const override {
324a4360efbSLei Zhang     Type dstType = getTypeConverter()->convertType(op.getType());
325a4360efbSLei Zhang     if (!dstType)
3266a8ba318SRiver Riddle       return failure();
327a4360efbSLei Zhang     if (dstType.isa<spirv::ScalarType>()) {
328a4360efbSLei Zhang       rewriter.replaceOp(op, adaptor.getInput());
329a4360efbSLei Zhang     } else {
330a4360efbSLei Zhang       auto dstVecType = dstType.cast<VectorType>();
3317c38fd60SJacques Pienaar       SmallVector<Value, 4> source(dstVecType.getNumElements(),
3327c38fd60SJacques Pienaar                                    adaptor.getInput());
333a4360efbSLei Zhang       rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, dstType,
3346a8ba318SRiver Riddle                                                                source);
335a4360efbSLei Zhang     }
3366a8ba318SRiver Riddle     return success();
3376a8ba318SRiver Riddle   }
3386a8ba318SRiver Riddle };
3396a8ba318SRiver Riddle 
34036e68c11SLei Zhang struct VectorShuffleOpConvert final
34136e68c11SLei Zhang     : public OpConversionPattern<vector::ShuffleOp> {
34236e68c11SLei Zhang   using OpConversionPattern::OpConversionPattern;
34336e68c11SLei Zhang 
34436e68c11SLei Zhang   LogicalResult
matchAndRewrite__anon0bb484910111::VectorShuffleOpConvert34536e68c11SLei Zhang   matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
34636e68c11SLei Zhang                   ConversionPatternRewriter &rewriter) const override {
34736e68c11SLei Zhang     auto oldResultType = shuffleOp.getVectorType();
34836e68c11SLei Zhang     if (!spirv::CompositeType::isValid(oldResultType))
34936e68c11SLei Zhang       return failure();
35036e68c11SLei Zhang     auto newResultType = getTypeConverter()->convertType(oldResultType);
35136e68c11SLei Zhang 
35236e68c11SLei Zhang     auto oldSourceType = shuffleOp.getV1VectorType();
35336e68c11SLei Zhang     if (oldSourceType.getNumElements() > 1) {
35436e68c11SLei Zhang       SmallVector<int32_t, 4> components = llvm::to_vector<4>(
3557c38fd60SJacques Pienaar           llvm::map_range(shuffleOp.getMask(), [](Attribute attr) -> int32_t {
35636e68c11SLei Zhang             return attr.cast<IntegerAttr>().getValue().getZExtValue();
35736e68c11SLei Zhang           }));
35836e68c11SLei Zhang       rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
3597c38fd60SJacques Pienaar           shuffleOp, newResultType, adaptor.getV1(), adaptor.getV2(),
36036e68c11SLei Zhang           rewriter.getI32ArrayAttr(components));
36136e68c11SLei Zhang       return success();
36236e68c11SLei Zhang     }
36336e68c11SLei Zhang 
3647c38fd60SJacques Pienaar     SmallVector<Value, 2> oldOperands = {adaptor.getV1(), adaptor.getV2()};
36536e68c11SLei Zhang     SmallVector<Value, 4> newOperands;
36636e68c11SLei Zhang     newOperands.reserve(oldResultType.getNumElements());
3677c38fd60SJacques Pienaar     for (const APInt &i : shuffleOp.getMask().getAsValueRange<IntegerAttr>()) {
36836e68c11SLei Zhang       newOperands.push_back(oldOperands[i.getZExtValue()]);
36936e68c11SLei Zhang     }
37036e68c11SLei Zhang     rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
37136e68c11SLei Zhang         shuffleOp, newResultType, newOperands);
37236e68c11SLei Zhang 
37336e68c11SLei Zhang     return success();
37436e68c11SLei Zhang   }
37536e68c11SLei Zhang };
37636e68c11SLei Zhang 
3776e557bc4SThomas Raoux } // namespace
3786e557bc4SThomas Raoux 
populateVectorToSPIRVPatterns(SPIRVTypeConverter & typeConverter,RewritePatternSet & patterns)3793a506b31SChris Lattner void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
380dc4e913bSChris Lattner                                          RewritePatternSet &patterns) {
381dc4e913bSChris Lattner   patterns.add<VectorBitcastConvert, VectorBroadcastConvert,
3829f622b3dSLei Zhang                VectorExtractElementOpConvert, VectorExtractOpConvert,
3839f622b3dSLei Zhang                VectorExtractStridedSliceOpConvert, VectorFmaOpConvert,
3849f622b3dSLei Zhang                VectorInsertElementOpConvert, VectorInsertOpConvert,
385d137c05fSLei Zhang                VectorReductionPattern, VectorInsertStridedSliceOpConvert,
386d137c05fSLei Zhang                VectorShuffleOpConvert, VectorSplatPattern>(
387d137c05fSLei Zhang       typeConverter, patterns.getContext());
3886e557bc4SThomas Raoux }
389