1 //===- VectorToSPIRV.cpp - Vector to SPIR-V Patterns ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements patterns to convert the Vector dialect to the SPIR-V
// dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
14 
15 #include "../PassDetail.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
18 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
19 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
20 #include "mlir/Dialect/Vector/VectorOps.h"
21 #include "mlir/Transforms/DialectConversion.h"
22 #include <numeric>
23 
24 using namespace mlir;
25 
26 /// Gets the first integer value from `attr`, assuming it is an integer array
27 /// attribute.
28 static uint64_t getFirstIntValue(ArrayAttr attr) {
29   return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
30 }
31 
32 namespace {
33 
34 struct VectorBitcastConvert final
35     : public OpConversionPattern<vector::BitCastOp> {
36   using OpConversionPattern::OpConversionPattern;
37 
38   LogicalResult
39   matchAndRewrite(vector::BitCastOp bitcastOp, OpAdaptor adaptor,
40                   ConversionPatternRewriter &rewriter) const override {
41     auto dstType = getTypeConverter()->convertType(bitcastOp.getType());
42     if (!dstType)
43       return failure();
44 
45     if (dstType == adaptor.source().getType())
46       rewriter.replaceOp(bitcastOp, adaptor.source());
47     else
48       rewriter.replaceOpWithNewOp<spirv::BitcastOp>(bitcastOp, dstType,
49                                                     adaptor.source());
50 
51     return success();
52   }
53 };
54 
55 struct VectorBroadcastConvert final
56     : public OpConversionPattern<vector::BroadcastOp> {
57   using OpConversionPattern::OpConversionPattern;
58 
59   LogicalResult
60   matchAndRewrite(vector::BroadcastOp broadcastOp, OpAdaptor adaptor,
61                   ConversionPatternRewriter &rewriter) const override {
62     if (broadcastOp.source().getType().isa<VectorType>() ||
63         !spirv::CompositeType::isValid(broadcastOp.getVectorType()))
64       return failure();
65     SmallVector<Value, 4> source(broadcastOp.getVectorType().getNumElements(),
66                                  adaptor.source());
67     rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
68         broadcastOp, broadcastOp.getVectorType(), source);
69     return success();
70   }
71 };
72 
73 struct VectorExtractOpConvert final
74     : public OpConversionPattern<vector::ExtractOp> {
75   using OpConversionPattern::OpConversionPattern;
76 
77   LogicalResult
78   matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
79                   ConversionPatternRewriter &rewriter) const override {
80     // Only support extracting a scalar value now.
81     VectorType resultVectorType = extractOp.getType().dyn_cast<VectorType>();
82     if (resultVectorType && resultVectorType.getNumElements() > 1)
83       return failure();
84 
85     auto dstType = getTypeConverter()->convertType(extractOp.getType());
86     if (!dstType)
87       return failure();
88 
89     if (adaptor.vector().getType().isa<spirv::ScalarType>()) {
90       rewriter.replaceOp(extractOp, adaptor.vector());
91       return success();
92     }
93 
94     int32_t id = getFirstIntValue(extractOp.position());
95     rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
96         extractOp, adaptor.vector(), id);
97     return success();
98   }
99 };
100 
101 struct VectorExtractStridedSliceOpConvert final
102     : public OpConversionPattern<vector::ExtractStridedSliceOp> {
103   using OpConversionPattern::OpConversionPattern;
104 
105   LogicalResult
106   matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
107                   ConversionPatternRewriter &rewriter) const override {
108     auto dstType = getTypeConverter()->convertType(extractOp.getType());
109     if (!dstType)
110       return failure();
111 
112 
113     uint64_t offset = getFirstIntValue(extractOp.offsets());
114     uint64_t size = getFirstIntValue(extractOp.sizes());
115     uint64_t stride = getFirstIntValue(extractOp.strides());
116     if (stride != 1)
117       return failure();
118 
119     Value srcVector = adaptor.getOperands().front();
120 
121     // Extract vector<1xT> case.
122     if (dstType.isa<spirv::ScalarType>()) {
123       rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(extractOp,
124                                                              srcVector, offset);
125       return success();
126     }
127 
128     SmallVector<int32_t, 2> indices(size);
129     std::iota(indices.begin(), indices.end(), offset);
130 
131     rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
132         extractOp, dstType, srcVector, srcVector,
133         rewriter.getI32ArrayAttr(indices));
134 
135     return success();
136   }
137 };
138 
139 struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> {
140   using OpConversionPattern::OpConversionPattern;
141 
142   LogicalResult
143   matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
144                   ConversionPatternRewriter &rewriter) const override {
145     if (!spirv::CompositeType::isValid(fmaOp.getVectorType()))
146       return failure();
147     rewriter.replaceOpWithNewOp<spirv::GLSLFmaOp>(
148         fmaOp, fmaOp.getType(), adaptor.lhs(), adaptor.rhs(), adaptor.acc());
149     return success();
150   }
151 };
152 
153 struct VectorInsertOpConvert final
154     : public OpConversionPattern<vector::InsertOp> {
155   using OpConversionPattern::OpConversionPattern;
156 
157   LogicalResult
158   matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
159                   ConversionPatternRewriter &rewriter) const override {
160     if (insertOp.getSourceType().isa<VectorType>() ||
161         !spirv::CompositeType::isValid(insertOp.getDestVectorType()))
162       return failure();
163     int32_t id = getFirstIntValue(insertOp.position());
164     rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
165         insertOp, adaptor.source(), adaptor.dest(), id);
166     return success();
167   }
168 };
169 
170 struct VectorExtractElementOpConvert final
171     : public OpConversionPattern<vector::ExtractElementOp> {
172   using OpConversionPattern::OpConversionPattern;
173 
174   LogicalResult
175   matchAndRewrite(vector::ExtractElementOp extractElementOp, OpAdaptor adaptor,
176                   ConversionPatternRewriter &rewriter) const override {
177     if (!spirv::CompositeType::isValid(extractElementOp.getVectorType()))
178       return failure();
179     rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
180         extractElementOp, extractElementOp.getType(), adaptor.vector(),
181         extractElementOp.position());
182     return success();
183   }
184 };
185 
186 struct VectorInsertElementOpConvert final
187     : public OpConversionPattern<vector::InsertElementOp> {
188   using OpConversionPattern::OpConversionPattern;
189 
190   LogicalResult
191   matchAndRewrite(vector::InsertElementOp insertElementOp, OpAdaptor adaptor,
192                   ConversionPatternRewriter &rewriter) const override {
193     if (!spirv::CompositeType::isValid(insertElementOp.getDestVectorType()))
194       return failure();
195     rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
196         insertElementOp, insertElementOp.getType(), insertElementOp.dest(),
197         adaptor.source(), insertElementOp.position());
198     return success();
199   }
200 };
201 
202 struct VectorInsertStridedSliceOpConvert final
203     : public OpConversionPattern<vector::InsertStridedSliceOp> {
204   using OpConversionPattern::OpConversionPattern;
205 
206   LogicalResult
207   matchAndRewrite(vector::InsertStridedSliceOp insertOp, OpAdaptor adaptor,
208                   ConversionPatternRewriter &rewriter) const override {
209     Value srcVector = adaptor.getOperands().front();
210     Value dstVector = adaptor.getOperands().back();
211 
212     // Insert scalar values not supported yet.
213     if (srcVector.getType().isa<spirv::ScalarType>() ||
214         dstVector.getType().isa<spirv::ScalarType>())
215       return failure();
216 
217     uint64_t stride = getFirstIntValue(insertOp.strides());
218     if (stride != 1)
219       return failure();
220 
221     uint64_t totalSize =
222         dstVector.getType().cast<VectorType>().getNumElements();
223     uint64_t insertSize =
224         srcVector.getType().cast<VectorType>().getNumElements();
225     uint64_t offset = getFirstIntValue(insertOp.offsets());
226 
227     SmallVector<int32_t, 2> indices(totalSize);
228     std::iota(indices.begin(), indices.end(), 0);
229     std::iota(indices.begin() + offset, indices.begin() + offset + insertSize,
230               totalSize);
231 
232     rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
233         insertOp, dstVector.getType(), dstVector, srcVector,
234         rewriter.getI32ArrayAttr(indices));
235 
236     return success();
237   }
238 };
239 
240 } // namespace
241 
242 void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
243                                          RewritePatternSet &patterns) {
244   patterns.add<VectorBitcastConvert, VectorBroadcastConvert,
245                VectorExtractElementOpConvert, VectorExtractOpConvert,
246                VectorExtractStridedSliceOpConvert, VectorFmaOpConvert,
247                VectorInsertElementOpConvert, VectorInsertOpConvert,
248                VectorInsertStridedSliceOpConvert>(typeConverter,
249                                                   patterns.getContext());
250 }
251