1 //===- VectorToSPIRV.cpp - Vector to SPIR-V Patterns ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements patterns to convert the Vector dialect to the SPIR-V
// dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
14 
15 #include "../PassDetail.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
18 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
19 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
20 #include "mlir/Dialect/Vector/VectorOps.h"
21 #include "mlir/Transforms/DialectConversion.h"
22 #include <numeric>
23 
24 using namespace mlir;
25 
26 /// Gets the first integer value from `attr`, assuming it is an integer array
27 /// attribute.
28 static uint64_t getFirstIntValue(ArrayAttr attr) {
29   return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
30 }
31 
32 namespace {
33 
34 struct VectorBitcastConvert final
35     : public OpConversionPattern<vector::BitCastOp> {
36   using OpConversionPattern::OpConversionPattern;
37 
38   LogicalResult
39   matchAndRewrite(vector::BitCastOp bitcastOp, ArrayRef<Value> operands,
40                   ConversionPatternRewriter &rewriter) const override {
41     auto dstType = getTypeConverter()->convertType(bitcastOp.getType());
42     if (!dstType)
43       return failure();
44 
45     vector::BitCastOp::Adaptor adaptor(operands);
46     if (dstType == adaptor.source().getType())
47       rewriter.replaceOp(bitcastOp, adaptor.source());
48     else
49       rewriter.replaceOpWithNewOp<spirv::BitcastOp>(bitcastOp, dstType,
50                                                     adaptor.source());
51 
52     return success();
53   }
54 };
55 
56 struct VectorBroadcastConvert final
57     : public OpConversionPattern<vector::BroadcastOp> {
58   using OpConversionPattern::OpConversionPattern;
59 
60   LogicalResult
61   matchAndRewrite(vector::BroadcastOp broadcastOp, ArrayRef<Value> operands,
62                   ConversionPatternRewriter &rewriter) const override {
63     if (broadcastOp.source().getType().isa<VectorType>() ||
64         !spirv::CompositeType::isValid(broadcastOp.getVectorType()))
65       return failure();
66     vector::BroadcastOp::Adaptor adaptor(operands);
67     SmallVector<Value, 4> source(broadcastOp.getVectorType().getNumElements(),
68                                  adaptor.source());
69     rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
70         broadcastOp, broadcastOp.getVectorType(), source);
71     return success();
72   }
73 };
74 
75 struct VectorExtractOpConvert final
76     : public OpConversionPattern<vector::ExtractOp> {
77   using OpConversionPattern::OpConversionPattern;
78 
79   LogicalResult
80   matchAndRewrite(vector::ExtractOp extractOp, ArrayRef<Value> operands,
81                   ConversionPatternRewriter &rewriter) const override {
82     // Only support extracting a scalar value now.
83     VectorType resultVectorType = extractOp.getType().dyn_cast<VectorType>();
84     if (resultVectorType && resultVectorType.getNumElements() > 1)
85       return failure();
86 
87     auto dstType = getTypeConverter()->convertType(extractOp.getType());
88     if (!dstType)
89       return failure();
90 
91     vector::ExtractOp::Adaptor adaptor(operands);
92     int32_t id = getFirstIntValue(extractOp.position());
93     rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
94         extractOp, adaptor.vector(), id);
95     return success();
96   }
97 };
98 
99 struct VectorExtractStridedSliceOpConvert final
100     : public OpConversionPattern<vector::ExtractStridedSliceOp> {
101   using OpConversionPattern::OpConversionPattern;
102 
103   LogicalResult
104   matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
105                   ArrayRef<Value> operands,
106                   ConversionPatternRewriter &rewriter) const override {
107     auto dstType = getTypeConverter()->convertType(extractOp.getType());
108     if (!dstType)
109       return failure();
110 
111     // Extract vector<1xT> not supported yet.
112     if (dstType.isa<spirv::ScalarType>())
113       return failure();
114 
115     uint64_t offset = getFirstIntValue(extractOp.offsets());
116     uint64_t size = getFirstIntValue(extractOp.sizes());
117     uint64_t stride = getFirstIntValue(extractOp.strides());
118     if (stride != 1)
119       return failure();
120 
121     Value srcVector = operands.front();
122 
123     SmallVector<int32_t, 2> indices(size);
124     std::iota(indices.begin(), indices.end(), offset);
125 
126     rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
127         extractOp, dstType, srcVector, srcVector,
128         rewriter.getI32ArrayAttr(indices));
129 
130     return success();
131   }
132 };
133 
134 struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> {
135   using OpConversionPattern::OpConversionPattern;
136 
137   LogicalResult
138   matchAndRewrite(vector::FMAOp fmaOp, ArrayRef<Value> operands,
139                   ConversionPatternRewriter &rewriter) const override {
140     if (!spirv::CompositeType::isValid(fmaOp.getVectorType()))
141       return failure();
142     vector::FMAOp::Adaptor adaptor(operands);
143     rewriter.replaceOpWithNewOp<spirv::GLSLFmaOp>(
144         fmaOp, fmaOp.getType(), adaptor.lhs(), adaptor.rhs(), adaptor.acc());
145     return success();
146   }
147 };
148 
149 struct VectorInsertOpConvert final
150     : public OpConversionPattern<vector::InsertOp> {
151   using OpConversionPattern::OpConversionPattern;
152 
153   LogicalResult
154   matchAndRewrite(vector::InsertOp insertOp, ArrayRef<Value> operands,
155                   ConversionPatternRewriter &rewriter) const override {
156     if (insertOp.getSourceType().isa<VectorType>() ||
157         !spirv::CompositeType::isValid(insertOp.getDestVectorType()))
158       return failure();
159     vector::InsertOp::Adaptor adaptor(operands);
160     int32_t id = getFirstIntValue(insertOp.position());
161     rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
162         insertOp, adaptor.source(), adaptor.dest(), id);
163     return success();
164   }
165 };
166 
167 struct VectorExtractElementOpConvert final
168     : public OpConversionPattern<vector::ExtractElementOp> {
169   using OpConversionPattern::OpConversionPattern;
170 
171   LogicalResult
172   matchAndRewrite(vector::ExtractElementOp extractElementOp,
173                   ArrayRef<Value> operands,
174                   ConversionPatternRewriter &rewriter) const override {
175     if (!spirv::CompositeType::isValid(extractElementOp.getVectorType()))
176       return failure();
177     vector::ExtractElementOp::Adaptor adaptor(operands);
178     rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
179         extractElementOp, extractElementOp.getType(), adaptor.vector(),
180         extractElementOp.position());
181     return success();
182   }
183 };
184 
185 struct VectorInsertElementOpConvert final
186     : public OpConversionPattern<vector::InsertElementOp> {
187   using OpConversionPattern::OpConversionPattern;
188 
189   LogicalResult
190   matchAndRewrite(vector::InsertElementOp insertElementOp,
191                   ArrayRef<Value> operands,
192                   ConversionPatternRewriter &rewriter) const override {
193     if (!spirv::CompositeType::isValid(insertElementOp.getDestVectorType()))
194       return failure();
195     vector::InsertElementOp::Adaptor adaptor(operands);
196     rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
197         insertElementOp, insertElementOp.getType(), insertElementOp.dest(),
198         adaptor.source(), insertElementOp.position());
199     return success();
200   }
201 };
202 
203 struct VectorInsertStridedSliceOpConvert final
204     : public OpConversionPattern<vector::InsertStridedSliceOp> {
205   using OpConversionPattern::OpConversionPattern;
206 
207   LogicalResult
208   matchAndRewrite(vector::InsertStridedSliceOp insertOp,
209                   ArrayRef<Value> operands,
210                   ConversionPatternRewriter &rewriter) const override {
211     Value srcVector = operands.front();
212     Value dstVector = operands.back();
213 
214     // Insert scalar values not supported yet.
215     if (srcVector.getType().isa<spirv::ScalarType>() ||
216         dstVector.getType().isa<spirv::ScalarType>())
217       return failure();
218 
219     uint64_t stride = getFirstIntValue(insertOp.strides());
220     if (stride != 1)
221       return failure();
222 
223     uint64_t totalSize =
224         dstVector.getType().cast<VectorType>().getNumElements();
225     uint64_t insertSize =
226         srcVector.getType().cast<VectorType>().getNumElements();
227     uint64_t offset = getFirstIntValue(insertOp.offsets());
228 
229     SmallVector<int32_t, 2> indices(totalSize);
230     std::iota(indices.begin(), indices.end(), 0);
231     std::iota(indices.begin() + offset, indices.begin() + offset + insertSize,
232               totalSize);
233 
234     rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
235         insertOp, dstVector.getType(), dstVector, srcVector,
236         rewriter.getI32ArrayAttr(indices));
237 
238     return success();
239   }
240 };
241 
242 } // namespace
243 
244 void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
245                                          RewritePatternSet &patterns) {
246   patterns.add<VectorBitcastConvert, VectorBroadcastConvert,
247                VectorExtractElementOpConvert, VectorExtractOpConvert,
248                VectorExtractStridedSliceOpConvert, VectorFmaOpConvert,
249                VectorInsertElementOpConvert, VectorInsertOpConvert,
250                VectorInsertStridedSliceOpConvert>(typeConverter,
251                                                   patterns.getContext());
252 }
253