1 //===- VectorToSPIRV.cpp - Vector to SPIR-V Patterns ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert Vector dialect to SPIRV dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
14 
15 #include "../PassDetail.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
18 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
19 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
20 #include "mlir/Dialect/Vector/IR/VectorOps.h"
21 #include "mlir/Transforms/DialectConversion.h"
22 #include <numeric>
23 
24 using namespace mlir;
25 
26 /// Gets the first integer value from `attr`, assuming it is an integer array
27 /// attribute.
28 static uint64_t getFirstIntValue(ArrayAttr attr) {
29   return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
30 }
31 
32 namespace {
33 
34 struct VectorBitcastConvert final
35     : public OpConversionPattern<vector::BitCastOp> {
36   using OpConversionPattern::OpConversionPattern;
37 
38   LogicalResult
39   matchAndRewrite(vector::BitCastOp bitcastOp, OpAdaptor adaptor,
40                   ConversionPatternRewriter &rewriter) const override {
41     auto dstType = getTypeConverter()->convertType(bitcastOp.getType());
42     if (!dstType)
43       return failure();
44 
45     if (dstType == adaptor.source().getType())
46       rewriter.replaceOp(bitcastOp, adaptor.source());
47     else
48       rewriter.replaceOpWithNewOp<spirv::BitcastOp>(bitcastOp, dstType,
49                                                     adaptor.source());
50 
51     return success();
52   }
53 };
54 
55 struct VectorBroadcastConvert final
56     : public OpConversionPattern<vector::BroadcastOp> {
57   using OpConversionPattern::OpConversionPattern;
58 
59   LogicalResult
60   matchAndRewrite(vector::BroadcastOp broadcastOp, OpAdaptor adaptor,
61                   ConversionPatternRewriter &rewriter) const override {
62     if (broadcastOp.source().getType().isa<VectorType>() ||
63         !spirv::CompositeType::isValid(broadcastOp.getVectorType()))
64       return failure();
65     SmallVector<Value, 4> source(broadcastOp.getVectorType().getNumElements(),
66                                  adaptor.source());
67     rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
68         broadcastOp, broadcastOp.getVectorType(), source);
69     return success();
70   }
71 };
72 
73 struct VectorExtractOpConvert final
74     : public OpConversionPattern<vector::ExtractOp> {
75   using OpConversionPattern::OpConversionPattern;
76 
77   LogicalResult
78   matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
79                   ConversionPatternRewriter &rewriter) const override {
80     // Only support extracting a scalar value now.
81     VectorType resultVectorType = extractOp.getType().dyn_cast<VectorType>();
82     if (resultVectorType && resultVectorType.getNumElements() > 1)
83       return failure();
84 
85     auto dstType = getTypeConverter()->convertType(extractOp.getType());
86     if (!dstType)
87       return failure();
88 
89     if (adaptor.vector().getType().isa<spirv::ScalarType>()) {
90       rewriter.replaceOp(extractOp, adaptor.vector());
91       return success();
92     }
93 
94     int32_t id = getFirstIntValue(extractOp.position());
95     rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
96         extractOp, adaptor.vector(), id);
97     return success();
98   }
99 };
100 
101 struct VectorExtractStridedSliceOpConvert final
102     : public OpConversionPattern<vector::ExtractStridedSliceOp> {
103   using OpConversionPattern::OpConversionPattern;
104 
105   LogicalResult
106   matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
107                   ConversionPatternRewriter &rewriter) const override {
108     auto dstType = getTypeConverter()->convertType(extractOp.getType());
109     if (!dstType)
110       return failure();
111 
112 
113     uint64_t offset = getFirstIntValue(extractOp.offsets());
114     uint64_t size = getFirstIntValue(extractOp.sizes());
115     uint64_t stride = getFirstIntValue(extractOp.strides());
116     if (stride != 1)
117       return failure();
118 
119     Value srcVector = adaptor.getOperands().front();
120 
121     // Extract vector<1xT> case.
122     if (dstType.isa<spirv::ScalarType>()) {
123       rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(extractOp,
124                                                              srcVector, offset);
125       return success();
126     }
127 
128     SmallVector<int32_t, 2> indices(size);
129     std::iota(indices.begin(), indices.end(), offset);
130 
131     rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
132         extractOp, dstType, srcVector, srcVector,
133         rewriter.getI32ArrayAttr(indices));
134 
135     return success();
136   }
137 };
138 
139 struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> {
140   using OpConversionPattern::OpConversionPattern;
141 
142   LogicalResult
143   matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
144                   ConversionPatternRewriter &rewriter) const override {
145     if (!spirv::CompositeType::isValid(fmaOp.getVectorType()))
146       return failure();
147     rewriter.replaceOpWithNewOp<spirv::GLSLFmaOp>(
148         fmaOp, fmaOp.getType(), adaptor.lhs(), adaptor.rhs(), adaptor.acc());
149     return success();
150   }
151 };
152 
153 struct VectorInsertOpConvert final
154     : public OpConversionPattern<vector::InsertOp> {
155   using OpConversionPattern::OpConversionPattern;
156 
157   LogicalResult
158   matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
159                   ConversionPatternRewriter &rewriter) const override {
160     // Special case for inserting scalar values into size-1 vectors.
161     if (insertOp.getSourceType().isIntOrFloat() &&
162         insertOp.getDestVectorType().getNumElements() == 1) {
163       rewriter.replaceOp(insertOp, adaptor.source());
164       return success();
165     }
166 
167     if (insertOp.getSourceType().isa<VectorType>() ||
168         !spirv::CompositeType::isValid(insertOp.getDestVectorType()))
169       return failure();
170     int32_t id = getFirstIntValue(insertOp.position());
171     rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
172         insertOp, adaptor.source(), adaptor.dest(), id);
173     return success();
174   }
175 };
176 
177 struct VectorExtractElementOpConvert final
178     : public OpConversionPattern<vector::ExtractElementOp> {
179   using OpConversionPattern::OpConversionPattern;
180 
181   LogicalResult
182   matchAndRewrite(vector::ExtractElementOp extractElementOp, OpAdaptor adaptor,
183                   ConversionPatternRewriter &rewriter) const override {
184     if (!spirv::CompositeType::isValid(extractElementOp.getVectorType()))
185       return failure();
186     rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
187         extractElementOp, extractElementOp.getType(), adaptor.vector(),
188         extractElementOp.position());
189     return success();
190   }
191 };
192 
193 struct VectorInsertElementOpConvert final
194     : public OpConversionPattern<vector::InsertElementOp> {
195   using OpConversionPattern::OpConversionPattern;
196 
197   LogicalResult
198   matchAndRewrite(vector::InsertElementOp insertElementOp, OpAdaptor adaptor,
199                   ConversionPatternRewriter &rewriter) const override {
200     if (!spirv::CompositeType::isValid(insertElementOp.getDestVectorType()))
201       return failure();
202     rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
203         insertElementOp, insertElementOp.getType(), insertElementOp.dest(),
204         adaptor.source(), insertElementOp.position());
205     return success();
206   }
207 };
208 
209 struct VectorInsertStridedSliceOpConvert final
210     : public OpConversionPattern<vector::InsertStridedSliceOp> {
211   using OpConversionPattern::OpConversionPattern;
212 
213   LogicalResult
214   matchAndRewrite(vector::InsertStridedSliceOp insertOp, OpAdaptor adaptor,
215                   ConversionPatternRewriter &rewriter) const override {
216     Value srcVector = adaptor.getOperands().front();
217     Value dstVector = adaptor.getOperands().back();
218 
219     uint64_t stride = getFirstIntValue(insertOp.strides());
220     if (stride != 1)
221       return failure();
222     uint64_t offset = getFirstIntValue(insertOp.offsets());
223 
224     if (srcVector.getType().isa<spirv::ScalarType>()) {
225       assert(!dstVector.getType().isa<spirv::ScalarType>());
226       rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
227           insertOp, dstVector.getType(), srcVector, dstVector,
228           rewriter.getI32ArrayAttr(offset));
229       return success();
230     }
231 
232     uint64_t totalSize =
233         dstVector.getType().cast<VectorType>().getNumElements();
234     uint64_t insertSize =
235         srcVector.getType().cast<VectorType>().getNumElements();
236 
237     SmallVector<int32_t, 2> indices(totalSize);
238     std::iota(indices.begin(), indices.end(), 0);
239     std::iota(indices.begin() + offset, indices.begin() + offset + insertSize,
240               totalSize);
241 
242     rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
243         insertOp, dstVector.getType(), dstVector, srcVector,
244         rewriter.getI32ArrayAttr(indices));
245 
246     return success();
247   }
248 };
249 
250 class VectorSplatPattern final : public OpConversionPattern<vector::SplatOp> {
251 public:
252   using OpConversionPattern<vector::SplatOp>::OpConversionPattern;
253 
254   LogicalResult
255   matchAndRewrite(vector::SplatOp op, OpAdaptor adaptor,
256                   ConversionPatternRewriter &rewriter) const override {
257     VectorType dstVecType = op.getType();
258     if (!spirv::CompositeType::isValid(dstVecType))
259       return failure();
260     SmallVector<Value, 4> source(dstVecType.getNumElements(), adaptor.input());
261     rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, dstVecType,
262                                                              source);
263     return success();
264   }
265 };
266 
267 } // namespace
268 
269 void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
270                                          RewritePatternSet &patterns) {
271   patterns.add<VectorBitcastConvert, VectorBroadcastConvert,
272                VectorExtractElementOpConvert, VectorExtractOpConvert,
273                VectorExtractStridedSliceOpConvert, VectorFmaOpConvert,
274                VectorInsertElementOpConvert, VectorInsertOpConvert,
275                VectorInsertStridedSliceOpConvert, VectorSplatPattern>(
276       typeConverter, patterns.getContext());
277 }
278