1 //===- VectorToSPIRV.cpp - Vector to SPIR-V Patterns ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements patterns to convert the Vector dialect to the SPIR-V
// dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
14 
15 #include "../PassDetail.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
18 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
19 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
20 #include "mlir/Dialect/Vector/VectorOps.h"
21 #include "mlir/Transforms/DialectConversion.h"
22 #include <numeric>
23 
24 using namespace mlir;
25 
26 /// Gets the first integer value from `attr`, assuming it is an integer array
27 /// attribute.
28 static uint64_t getFirstIntValue(ArrayAttr attr) {
29   return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
30 }
31 
32 namespace {
33 
34 struct VectorBitcastConvert final
35     : public OpConversionPattern<vector::BitCastOp> {
36   using OpConversionPattern::OpConversionPattern;
37 
38   LogicalResult
39   matchAndRewrite(vector::BitCastOp bitcastOp, ArrayRef<Value> operands,
40                   ConversionPatternRewriter &rewriter) const override {
41     auto dstType = getTypeConverter()->convertType(bitcastOp.getType());
42     if (!dstType)
43       return failure();
44 
45     vector::BitCastOp::Adaptor adaptor(operands);
46     if (dstType == adaptor.source().getType())
47       rewriter.replaceOp(bitcastOp, adaptor.source());
48     else
49       rewriter.replaceOpWithNewOp<spirv::BitcastOp>(bitcastOp, dstType,
50                                                     adaptor.source());
51 
52     return success();
53   }
54 };
55 
56 struct VectorBroadcastConvert final
57     : public OpConversionPattern<vector::BroadcastOp> {
58   using OpConversionPattern::OpConversionPattern;
59 
60   LogicalResult
61   matchAndRewrite(vector::BroadcastOp broadcastOp, ArrayRef<Value> operands,
62                   ConversionPatternRewriter &rewriter) const override {
63     if (broadcastOp.source().getType().isa<VectorType>() ||
64         !spirv::CompositeType::isValid(broadcastOp.getVectorType()))
65       return failure();
66     vector::BroadcastOp::Adaptor adaptor(operands);
67     SmallVector<Value, 4> source(broadcastOp.getVectorType().getNumElements(),
68                                  adaptor.source());
69     rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
70         broadcastOp, broadcastOp.getVectorType(), source);
71     return success();
72   }
73 };
74 
75 struct VectorExtractOpConvert final
76     : public OpConversionPattern<vector::ExtractOp> {
77   using OpConversionPattern::OpConversionPattern;
78 
79   LogicalResult
80   matchAndRewrite(vector::ExtractOp extractOp, ArrayRef<Value> operands,
81                   ConversionPatternRewriter &rewriter) const override {
82     // Only support extracting a scalar value now.
83     VectorType resultVectorType = extractOp.getType().dyn_cast<VectorType>();
84     if (resultVectorType && resultVectorType.getNumElements() > 1)
85       return failure();
86 
87     auto dstType = getTypeConverter()->convertType(extractOp.getType());
88     if (!dstType)
89       return failure();
90 
91     vector::ExtractOp::Adaptor adaptor(operands);
92     if (adaptor.vector().getType().isa<spirv::ScalarType>()) {
93       rewriter.replaceOp(extractOp, adaptor.vector());
94       return success();
95     }
96 
97     int32_t id = getFirstIntValue(extractOp.position());
98     rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
99         extractOp, adaptor.vector(), id);
100     return success();
101   }
102 };
103 
104 struct VectorExtractStridedSliceOpConvert final
105     : public OpConversionPattern<vector::ExtractStridedSliceOp> {
106   using OpConversionPattern::OpConversionPattern;
107 
108   LogicalResult
109   matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
110                   ArrayRef<Value> operands,
111                   ConversionPatternRewriter &rewriter) const override {
112     auto dstType = getTypeConverter()->convertType(extractOp.getType());
113     if (!dstType)
114       return failure();
115 
116 
117     uint64_t offset = getFirstIntValue(extractOp.offsets());
118     uint64_t size = getFirstIntValue(extractOp.sizes());
119     uint64_t stride = getFirstIntValue(extractOp.strides());
120     if (stride != 1)
121       return failure();
122 
123     Value srcVector = operands.front();
124 
125     // Extract vector<1xT> case.
126     if (dstType.isa<spirv::ScalarType>()) {
127       rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(extractOp,
128                                                              srcVector, offset);
129       return success();
130     }
131 
132     SmallVector<int32_t, 2> indices(size);
133     std::iota(indices.begin(), indices.end(), offset);
134 
135     rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
136         extractOp, dstType, srcVector, srcVector,
137         rewriter.getI32ArrayAttr(indices));
138 
139     return success();
140   }
141 };
142 
143 struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> {
144   using OpConversionPattern::OpConversionPattern;
145 
146   LogicalResult
147   matchAndRewrite(vector::FMAOp fmaOp, ArrayRef<Value> operands,
148                   ConversionPatternRewriter &rewriter) const override {
149     if (!spirv::CompositeType::isValid(fmaOp.getVectorType()))
150       return failure();
151     vector::FMAOp::Adaptor adaptor(operands);
152     rewriter.replaceOpWithNewOp<spirv::GLSLFmaOp>(
153         fmaOp, fmaOp.getType(), adaptor.lhs(), adaptor.rhs(), adaptor.acc());
154     return success();
155   }
156 };
157 
158 struct VectorInsertOpConvert final
159     : public OpConversionPattern<vector::InsertOp> {
160   using OpConversionPattern::OpConversionPattern;
161 
162   LogicalResult
163   matchAndRewrite(vector::InsertOp insertOp, ArrayRef<Value> operands,
164                   ConversionPatternRewriter &rewriter) const override {
165     if (insertOp.getSourceType().isa<VectorType>() ||
166         !spirv::CompositeType::isValid(insertOp.getDestVectorType()))
167       return failure();
168     vector::InsertOp::Adaptor adaptor(operands);
169     int32_t id = getFirstIntValue(insertOp.position());
170     rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
171         insertOp, adaptor.source(), adaptor.dest(), id);
172     return success();
173   }
174 };
175 
176 struct VectorExtractElementOpConvert final
177     : public OpConversionPattern<vector::ExtractElementOp> {
178   using OpConversionPattern::OpConversionPattern;
179 
180   LogicalResult
181   matchAndRewrite(vector::ExtractElementOp extractElementOp,
182                   ArrayRef<Value> operands,
183                   ConversionPatternRewriter &rewriter) const override {
184     if (!spirv::CompositeType::isValid(extractElementOp.getVectorType()))
185       return failure();
186     vector::ExtractElementOp::Adaptor adaptor(operands);
187     rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
188         extractElementOp, extractElementOp.getType(), adaptor.vector(),
189         extractElementOp.position());
190     return success();
191   }
192 };
193 
194 struct VectorInsertElementOpConvert final
195     : public OpConversionPattern<vector::InsertElementOp> {
196   using OpConversionPattern::OpConversionPattern;
197 
198   LogicalResult
199   matchAndRewrite(vector::InsertElementOp insertElementOp,
200                   ArrayRef<Value> operands,
201                   ConversionPatternRewriter &rewriter) const override {
202     if (!spirv::CompositeType::isValid(insertElementOp.getDestVectorType()))
203       return failure();
204     vector::InsertElementOp::Adaptor adaptor(operands);
205     rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
206         insertElementOp, insertElementOp.getType(), insertElementOp.dest(),
207         adaptor.source(), insertElementOp.position());
208     return success();
209   }
210 };
211 
212 struct VectorInsertStridedSliceOpConvert final
213     : public OpConversionPattern<vector::InsertStridedSliceOp> {
214   using OpConversionPattern::OpConversionPattern;
215 
216   LogicalResult
217   matchAndRewrite(vector::InsertStridedSliceOp insertOp,
218                   ArrayRef<Value> operands,
219                   ConversionPatternRewriter &rewriter) const override {
220     Value srcVector = operands.front();
221     Value dstVector = operands.back();
222 
223     // Insert scalar values not supported yet.
224     if (srcVector.getType().isa<spirv::ScalarType>() ||
225         dstVector.getType().isa<spirv::ScalarType>())
226       return failure();
227 
228     uint64_t stride = getFirstIntValue(insertOp.strides());
229     if (stride != 1)
230       return failure();
231 
232     uint64_t totalSize =
233         dstVector.getType().cast<VectorType>().getNumElements();
234     uint64_t insertSize =
235         srcVector.getType().cast<VectorType>().getNumElements();
236     uint64_t offset = getFirstIntValue(insertOp.offsets());
237 
238     SmallVector<int32_t, 2> indices(totalSize);
239     std::iota(indices.begin(), indices.end(), 0);
240     std::iota(indices.begin() + offset, indices.begin() + offset + insertSize,
241               totalSize);
242 
243     rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
244         insertOp, dstVector.getType(), dstVector, srcVector,
245         rewriter.getI32ArrayAttr(indices));
246 
247     return success();
248   }
249 };
250 
251 } // namespace
252 
253 void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
254                                          RewritePatternSet &patterns) {
255   patterns.add<VectorBitcastConvert, VectorBroadcastConvert,
256                VectorExtractElementOpConvert, VectorExtractOpConvert,
257                VectorExtractStridedSliceOpConvert, VectorFmaOpConvert,
258                VectorInsertElementOpConvert, VectorInsertOpConvert,
259                VectorInsertStridedSliceOpConvert>(typeConverter,
260                                                   patterns.getContext());
261 }
262