1 //===- VectorToSPIRV.cpp - Vector to SPIR-V Patterns ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert Vector dialect to SPIRV dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
14 
15 #include "../PassDetail.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
18 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
19 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
20 #include "mlir/Dialect/Vector/IR/VectorOps.h"
21 #include "mlir/IR/BuiltinAttributes.h"
22 #include "mlir/Transforms/DialectConversion.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include <numeric>
25 
26 using namespace mlir;
27 
28 /// Gets the first integer value from `attr`, assuming it is an integer array
29 /// attribute.
30 static uint64_t getFirstIntValue(ArrayAttr attr) {
31   return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
32 }
33 
34 namespace {
35 
36 struct VectorBitcastConvert final
37     : public OpConversionPattern<vector::BitCastOp> {
38   using OpConversionPattern::OpConversionPattern;
39 
40   LogicalResult
41   matchAndRewrite(vector::BitCastOp bitcastOp, OpAdaptor adaptor,
42                   ConversionPatternRewriter &rewriter) const override {
43     auto dstType = getTypeConverter()->convertType(bitcastOp.getType());
44     if (!dstType)
45       return failure();
46 
47     if (dstType == adaptor.source().getType())
48       rewriter.replaceOp(bitcastOp, adaptor.source());
49     else
50       rewriter.replaceOpWithNewOp<spirv::BitcastOp>(bitcastOp, dstType,
51                                                     adaptor.source());
52 
53     return success();
54   }
55 };
56 
57 struct VectorBroadcastConvert final
58     : public OpConversionPattern<vector::BroadcastOp> {
59   using OpConversionPattern::OpConversionPattern;
60 
61   LogicalResult
62   matchAndRewrite(vector::BroadcastOp broadcastOp, OpAdaptor adaptor,
63                   ConversionPatternRewriter &rewriter) const override {
64     if (broadcastOp.source().getType().isa<VectorType>() ||
65         !spirv::CompositeType::isValid(broadcastOp.getVectorType()))
66       return failure();
67     SmallVector<Value, 4> source(broadcastOp.getVectorType().getNumElements(),
68                                  adaptor.source());
69     rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
70         broadcastOp, broadcastOp.getVectorType(), source);
71     return success();
72   }
73 };
74 
75 struct VectorExtractOpConvert final
76     : public OpConversionPattern<vector::ExtractOp> {
77   using OpConversionPattern::OpConversionPattern;
78 
79   LogicalResult
80   matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
81                   ConversionPatternRewriter &rewriter) const override {
82     // Only support extracting a scalar value now.
83     VectorType resultVectorType = extractOp.getType().dyn_cast<VectorType>();
84     if (resultVectorType && resultVectorType.getNumElements() > 1)
85       return failure();
86 
87     auto dstType = getTypeConverter()->convertType(extractOp.getType());
88     if (!dstType)
89       return failure();
90 
91     if (adaptor.vector().getType().isa<spirv::ScalarType>()) {
92       rewriter.replaceOp(extractOp, adaptor.vector());
93       return success();
94     }
95 
96     int32_t id = getFirstIntValue(extractOp.position());
97     rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
98         extractOp, adaptor.vector(), id);
99     return success();
100   }
101 };
102 
103 struct VectorExtractStridedSliceOpConvert final
104     : public OpConversionPattern<vector::ExtractStridedSliceOp> {
105   using OpConversionPattern::OpConversionPattern;
106 
107   LogicalResult
108   matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
109                   ConversionPatternRewriter &rewriter) const override {
110     auto dstType = getTypeConverter()->convertType(extractOp.getType());
111     if (!dstType)
112       return failure();
113 
114 
115     uint64_t offset = getFirstIntValue(extractOp.offsets());
116     uint64_t size = getFirstIntValue(extractOp.sizes());
117     uint64_t stride = getFirstIntValue(extractOp.strides());
118     if (stride != 1)
119       return failure();
120 
121     Value srcVector = adaptor.getOperands().front();
122 
123     // Extract vector<1xT> case.
124     if (dstType.isa<spirv::ScalarType>()) {
125       rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(extractOp,
126                                                              srcVector, offset);
127       return success();
128     }
129 
130     SmallVector<int32_t, 2> indices(size);
131     std::iota(indices.begin(), indices.end(), offset);
132 
133     rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
134         extractOp, dstType, srcVector, srcVector,
135         rewriter.getI32ArrayAttr(indices));
136 
137     return success();
138   }
139 };
140 
141 struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> {
142   using OpConversionPattern::OpConversionPattern;
143 
144   LogicalResult
145   matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
146                   ConversionPatternRewriter &rewriter) const override {
147     if (!spirv::CompositeType::isValid(fmaOp.getVectorType()))
148       return failure();
149     rewriter.replaceOpWithNewOp<spirv::GLSLFmaOp>(
150         fmaOp, fmaOp.getType(), adaptor.lhs(), adaptor.rhs(), adaptor.acc());
151     return success();
152   }
153 };
154 
155 struct VectorInsertOpConvert final
156     : public OpConversionPattern<vector::InsertOp> {
157   using OpConversionPattern::OpConversionPattern;
158 
159   LogicalResult
160   matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
161                   ConversionPatternRewriter &rewriter) const override {
162     // Special case for inserting scalar values into size-1 vectors.
163     if (insertOp.getSourceType().isIntOrFloat() &&
164         insertOp.getDestVectorType().getNumElements() == 1) {
165       rewriter.replaceOp(insertOp, adaptor.source());
166       return success();
167     }
168 
169     if (insertOp.getSourceType().isa<VectorType>() ||
170         !spirv::CompositeType::isValid(insertOp.getDestVectorType()))
171       return failure();
172     int32_t id = getFirstIntValue(insertOp.position());
173     rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
174         insertOp, adaptor.source(), adaptor.dest(), id);
175     return success();
176   }
177 };
178 
179 struct VectorExtractElementOpConvert final
180     : public OpConversionPattern<vector::ExtractElementOp> {
181   using OpConversionPattern::OpConversionPattern;
182 
183   LogicalResult
184   matchAndRewrite(vector::ExtractElementOp extractElementOp, OpAdaptor adaptor,
185                   ConversionPatternRewriter &rewriter) const override {
186     if (!spirv::CompositeType::isValid(extractElementOp.getVectorType()))
187       return failure();
188     rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
189         extractElementOp, extractElementOp.getType(), adaptor.vector(),
190         extractElementOp.position());
191     return success();
192   }
193 };
194 
195 struct VectorInsertElementOpConvert final
196     : public OpConversionPattern<vector::InsertElementOp> {
197   using OpConversionPattern::OpConversionPattern;
198 
199   LogicalResult
200   matchAndRewrite(vector::InsertElementOp insertElementOp, OpAdaptor adaptor,
201                   ConversionPatternRewriter &rewriter) const override {
202     if (!spirv::CompositeType::isValid(insertElementOp.getDestVectorType()))
203       return failure();
204     rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
205         insertElementOp, insertElementOp.getType(), insertElementOp.dest(),
206         adaptor.source(), insertElementOp.position());
207     return success();
208   }
209 };
210 
211 struct VectorInsertStridedSliceOpConvert final
212     : public OpConversionPattern<vector::InsertStridedSliceOp> {
213   using OpConversionPattern::OpConversionPattern;
214 
215   LogicalResult
216   matchAndRewrite(vector::InsertStridedSliceOp insertOp, OpAdaptor adaptor,
217                   ConversionPatternRewriter &rewriter) const override {
218     Value srcVector = adaptor.getOperands().front();
219     Value dstVector = adaptor.getOperands().back();
220 
221     uint64_t stride = getFirstIntValue(insertOp.strides());
222     if (stride != 1)
223       return failure();
224     uint64_t offset = getFirstIntValue(insertOp.offsets());
225 
226     if (srcVector.getType().isa<spirv::ScalarType>()) {
227       assert(!dstVector.getType().isa<spirv::ScalarType>());
228       rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
229           insertOp, dstVector.getType(), srcVector, dstVector,
230           rewriter.getI32ArrayAttr(offset));
231       return success();
232     }
233 
234     uint64_t totalSize =
235         dstVector.getType().cast<VectorType>().getNumElements();
236     uint64_t insertSize =
237         srcVector.getType().cast<VectorType>().getNumElements();
238 
239     SmallVector<int32_t, 2> indices(totalSize);
240     std::iota(indices.begin(), indices.end(), 0);
241     std::iota(indices.begin() + offset, indices.begin() + offset + insertSize,
242               totalSize);
243 
244     rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
245         insertOp, dstVector.getType(), dstVector, srcVector,
246         rewriter.getI32ArrayAttr(indices));
247 
248     return success();
249   }
250 };
251 
252 class VectorSplatPattern final : public OpConversionPattern<vector::SplatOp> {
253 public:
254   using OpConversionPattern<vector::SplatOp>::OpConversionPattern;
255 
256   LogicalResult
257   matchAndRewrite(vector::SplatOp op, OpAdaptor adaptor,
258                   ConversionPatternRewriter &rewriter) const override {
259     VectorType dstVecType = op.getType();
260     if (!spirv::CompositeType::isValid(dstVecType))
261       return failure();
262     SmallVector<Value, 4> source(dstVecType.getNumElements(), adaptor.input());
263     rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, dstVecType,
264                                                              source);
265     return success();
266   }
267 };
268 
269 struct VectorShuffleOpConvert final
270     : public OpConversionPattern<vector::ShuffleOp> {
271   using OpConversionPattern::OpConversionPattern;
272 
273   LogicalResult
274   matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
275                   ConversionPatternRewriter &rewriter) const override {
276     auto oldResultType = shuffleOp.getVectorType();
277     if (!spirv::CompositeType::isValid(oldResultType))
278       return failure();
279     auto newResultType = getTypeConverter()->convertType(oldResultType);
280 
281     auto oldSourceType = shuffleOp.getV1VectorType();
282     if (oldSourceType.getNumElements() > 1) {
283       SmallVector<int32_t, 4> components = llvm::to_vector<4>(
284           llvm::map_range(shuffleOp.mask(), [](Attribute attr) -> int32_t {
285             return attr.cast<IntegerAttr>().getValue().getZExtValue();
286           }));
287       rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
288           shuffleOp, newResultType, adaptor.v1(), adaptor.v2(),
289           rewriter.getI32ArrayAttr(components));
290       return success();
291     }
292 
293     SmallVector<Value, 2> oldOperands = {adaptor.v1(), adaptor.v2()};
294     SmallVector<Value, 4> newOperands;
295     newOperands.reserve(oldResultType.getNumElements());
296     for (const APInt &i : shuffleOp.mask().getAsValueRange<IntegerAttr>()) {
297       newOperands.push_back(oldOperands[i.getZExtValue()]);
298     }
299     rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
300         shuffleOp, newResultType, newOperands);
301 
302     return success();
303   }
304 };
305 
306 } // namespace
307 
308 void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
309                                          RewritePatternSet &patterns) {
310   patterns.add<VectorBitcastConvert, VectorBroadcastConvert,
311                VectorExtractElementOpConvert, VectorExtractOpConvert,
312                VectorExtractStridedSliceOpConvert, VectorFmaOpConvert,
313                VectorInsertElementOpConvert, VectorInsertOpConvert,
314                VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
315                VectorSplatPattern>(typeConverter, patterns.getContext());
316 }
317