1 //===- VectorToSPIRV.cpp - Vector to SPIR-V Patterns ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert Vector dialect to SPIRV dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
14 
15 #include "../PassDetail.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
18 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
19 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
20 #include "mlir/Dialect/Vector/IR/VectorOps.h"
21 #include "mlir/IR/BuiltinAttributes.h"
22 #include "mlir/Transforms/DialectConversion.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include <numeric>
25 
26 using namespace mlir;
27 
28 /// Gets the first integer value from `attr`, assuming it is an integer array
29 /// attribute.
30 static uint64_t getFirstIntValue(ArrayAttr attr) {
31   return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
32 }
33 
34 namespace {
35 
36 struct VectorBitcastConvert final
37     : public OpConversionPattern<vector::BitCastOp> {
38   using OpConversionPattern::OpConversionPattern;
39 
40   LogicalResult
41   matchAndRewrite(vector::BitCastOp bitcastOp, OpAdaptor adaptor,
42                   ConversionPatternRewriter &rewriter) const override {
43     auto dstType = getTypeConverter()->convertType(bitcastOp.getType());
44     if (!dstType)
45       return failure();
46 
47     if (dstType == adaptor.getSource().getType())
48       rewriter.replaceOp(bitcastOp, adaptor.getSource());
49     else
50       rewriter.replaceOpWithNewOp<spirv::BitcastOp>(bitcastOp, dstType,
51                                                     adaptor.getSource());
52 
53     return success();
54   }
55 };
56 
57 struct VectorBroadcastConvert final
58     : public OpConversionPattern<vector::BroadcastOp> {
59   using OpConversionPattern::OpConversionPattern;
60 
61   LogicalResult
62   matchAndRewrite(vector::BroadcastOp broadcastOp, OpAdaptor adaptor,
63                   ConversionPatternRewriter &rewriter) const override {
64     if (broadcastOp.getSource().getType().isa<VectorType>() ||
65         !spirv::CompositeType::isValid(broadcastOp.getVectorType()))
66       return failure();
67     SmallVector<Value, 4> source(broadcastOp.getVectorType().getNumElements(),
68                                  adaptor.getSource());
69     rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
70         broadcastOp, broadcastOp.getVectorType(), source);
71     return success();
72   }
73 };
74 
75 struct VectorExtractOpConvert final
76     : public OpConversionPattern<vector::ExtractOp> {
77   using OpConversionPattern::OpConversionPattern;
78 
79   LogicalResult
80   matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
81                   ConversionPatternRewriter &rewriter) const override {
82     // Only support extracting a scalar value now.
83     VectorType resultVectorType = extractOp.getType().dyn_cast<VectorType>();
84     if (resultVectorType && resultVectorType.getNumElements() > 1)
85       return failure();
86 
87     auto dstType = getTypeConverter()->convertType(extractOp.getType());
88     if (!dstType)
89       return failure();
90 
91     if (adaptor.getVector().getType().isa<spirv::ScalarType>()) {
92       rewriter.replaceOp(extractOp, adaptor.getVector());
93       return success();
94     }
95 
96     int32_t id = getFirstIntValue(extractOp.getPosition());
97     rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
98         extractOp, adaptor.getVector(), id);
99     return success();
100   }
101 };
102 
103 struct VectorExtractStridedSliceOpConvert final
104     : public OpConversionPattern<vector::ExtractStridedSliceOp> {
105   using OpConversionPattern::OpConversionPattern;
106 
107   LogicalResult
108   matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
109                   ConversionPatternRewriter &rewriter) const override {
110     auto dstType = getTypeConverter()->convertType(extractOp.getType());
111     if (!dstType)
112       return failure();
113 
114     uint64_t offset = getFirstIntValue(extractOp.getOffsets());
115     uint64_t size = getFirstIntValue(extractOp.getSizes());
116     uint64_t stride = getFirstIntValue(extractOp.getStrides());
117     if (stride != 1)
118       return failure();
119 
120     Value srcVector = adaptor.getOperands().front();
121 
122     // Extract vector<1xT> case.
123     if (dstType.isa<spirv::ScalarType>()) {
124       rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(extractOp,
125                                                              srcVector, offset);
126       return success();
127     }
128 
129     SmallVector<int32_t, 2> indices(size);
130     std::iota(indices.begin(), indices.end(), offset);
131 
132     rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
133         extractOp, dstType, srcVector, srcVector,
134         rewriter.getI32ArrayAttr(indices));
135 
136     return success();
137   }
138 };
139 
140 struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> {
141   using OpConversionPattern::OpConversionPattern;
142 
143   LogicalResult
144   matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
145                   ConversionPatternRewriter &rewriter) const override {
146     if (!spirv::CompositeType::isValid(fmaOp.getVectorType()))
147       return failure();
148     rewriter.replaceOpWithNewOp<spirv::GLSLFmaOp>(
149         fmaOp, fmaOp.getType(), adaptor.getLhs(), adaptor.getRhs(),
150         adaptor.getAcc());
151     return success();
152   }
153 };
154 
155 struct VectorInsertOpConvert final
156     : public OpConversionPattern<vector::InsertOp> {
157   using OpConversionPattern::OpConversionPattern;
158 
159   LogicalResult
160   matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
161                   ConversionPatternRewriter &rewriter) const override {
162     // Special case for inserting scalar values into size-1 vectors.
163     if (insertOp.getSourceType().isIntOrFloat() &&
164         insertOp.getDestVectorType().getNumElements() == 1) {
165       rewriter.replaceOp(insertOp, adaptor.getSource());
166       return success();
167     }
168 
169     if (insertOp.getSourceType().isa<VectorType>() ||
170         !spirv::CompositeType::isValid(insertOp.getDestVectorType()))
171       return failure();
172     int32_t id = getFirstIntValue(insertOp.getPosition());
173     rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
174         insertOp, adaptor.getSource(), adaptor.getDest(), id);
175     return success();
176   }
177 };
178 
179 struct VectorExtractElementOpConvert final
180     : public OpConversionPattern<vector::ExtractElementOp> {
181   using OpConversionPattern::OpConversionPattern;
182 
183   LogicalResult
184   matchAndRewrite(vector::ExtractElementOp extractElementOp, OpAdaptor adaptor,
185                   ConversionPatternRewriter &rewriter) const override {
186     if (!spirv::CompositeType::isValid(extractElementOp.getVectorType()))
187       return failure();
188     rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
189         extractElementOp, extractElementOp.getType(), adaptor.getVector(),
190         extractElementOp.getPosition());
191     return success();
192   }
193 };
194 
195 struct VectorInsertElementOpConvert final
196     : public OpConversionPattern<vector::InsertElementOp> {
197   using OpConversionPattern::OpConversionPattern;
198 
199   LogicalResult
200   matchAndRewrite(vector::InsertElementOp insertElementOp, OpAdaptor adaptor,
201                   ConversionPatternRewriter &rewriter) const override {
202     if (!spirv::CompositeType::isValid(insertElementOp.getDestVectorType()))
203       return failure();
204     rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
205         insertElementOp, insertElementOp.getType(), insertElementOp.getDest(),
206         adaptor.getSource(), insertElementOp.getPosition());
207     return success();
208   }
209 };
210 
211 struct VectorInsertStridedSliceOpConvert final
212     : public OpConversionPattern<vector::InsertStridedSliceOp> {
213   using OpConversionPattern::OpConversionPattern;
214 
215   LogicalResult
216   matchAndRewrite(vector::InsertStridedSliceOp insertOp, OpAdaptor adaptor,
217                   ConversionPatternRewriter &rewriter) const override {
218     Value srcVector = adaptor.getOperands().front();
219     Value dstVector = adaptor.getOperands().back();
220 
221     uint64_t stride = getFirstIntValue(insertOp.getStrides());
222     if (stride != 1)
223       return failure();
224     uint64_t offset = getFirstIntValue(insertOp.getOffsets());
225 
226     if (srcVector.getType().isa<spirv::ScalarType>()) {
227       assert(!dstVector.getType().isa<spirv::ScalarType>());
228       rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
229           insertOp, dstVector.getType(), srcVector, dstVector,
230           rewriter.getI32ArrayAttr(offset));
231       return success();
232     }
233 
234     uint64_t totalSize =
235         dstVector.getType().cast<VectorType>().getNumElements();
236     uint64_t insertSize =
237         srcVector.getType().cast<VectorType>().getNumElements();
238 
239     SmallVector<int32_t, 2> indices(totalSize);
240     std::iota(indices.begin(), indices.end(), 0);
241     std::iota(indices.begin() + offset, indices.begin() + offset + insertSize,
242               totalSize);
243 
244     rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
245         insertOp, dstVector.getType(), dstVector, srcVector,
246         rewriter.getI32ArrayAttr(indices));
247 
248     return success();
249   }
250 };
251 
252 class VectorSplatPattern final : public OpConversionPattern<vector::SplatOp> {
253 public:
254   using OpConversionPattern<vector::SplatOp>::OpConversionPattern;
255 
256   LogicalResult
257   matchAndRewrite(vector::SplatOp op, OpAdaptor adaptor,
258                   ConversionPatternRewriter &rewriter) const override {
259     VectorType dstVecType = op.getType();
260     if (!spirv::CompositeType::isValid(dstVecType))
261       return failure();
262     SmallVector<Value, 4> source(dstVecType.getNumElements(),
263                                  adaptor.getInput());
264     rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, dstVecType,
265                                                              source);
266     return success();
267   }
268 };
269 
270 struct VectorShuffleOpConvert final
271     : public OpConversionPattern<vector::ShuffleOp> {
272   using OpConversionPattern::OpConversionPattern;
273 
274   LogicalResult
275   matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
276                   ConversionPatternRewriter &rewriter) const override {
277     auto oldResultType = shuffleOp.getVectorType();
278     if (!spirv::CompositeType::isValid(oldResultType))
279       return failure();
280     auto newResultType = getTypeConverter()->convertType(oldResultType);
281 
282     auto oldSourceType = shuffleOp.getV1VectorType();
283     if (oldSourceType.getNumElements() > 1) {
284       SmallVector<int32_t, 4> components = llvm::to_vector<4>(
285           llvm::map_range(shuffleOp.getMask(), [](Attribute attr) -> int32_t {
286             return attr.cast<IntegerAttr>().getValue().getZExtValue();
287           }));
288       rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
289           shuffleOp, newResultType, adaptor.getV1(), adaptor.getV2(),
290           rewriter.getI32ArrayAttr(components));
291       return success();
292     }
293 
294     SmallVector<Value, 2> oldOperands = {adaptor.getV1(), adaptor.getV2()};
295     SmallVector<Value, 4> newOperands;
296     newOperands.reserve(oldResultType.getNumElements());
297     for (const APInt &i : shuffleOp.getMask().getAsValueRange<IntegerAttr>()) {
298       newOperands.push_back(oldOperands[i.getZExtValue()]);
299     }
300     rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
301         shuffleOp, newResultType, newOperands);
302 
303     return success();
304   }
305 };
306 
307 } // namespace
308 
309 void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
310                                          RewritePatternSet &patterns) {
311   patterns.add<VectorBitcastConvert, VectorBroadcastConvert,
312                VectorExtractElementOpConvert, VectorExtractOpConvert,
313                VectorExtractStridedSliceOpConvert, VectorFmaOpConvert,
314                VectorInsertElementOpConvert, VectorInsertOpConvert,
315                VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
316                VectorSplatPattern>(typeConverter, patterns.getContext());
317 }
318