1 //===- VectorToSPIRV.cpp - Vector to SPIR-V Patterns ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert Vector dialect to SPIRV dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
14 
15 #include "../PassDetail.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
18 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
19 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
20 #include "mlir/Dialect/Vector/VectorOps.h"
21 #include "mlir/Transforms/DialectConversion.h"
22 #include <numeric>
23 
24 using namespace mlir;
25 
26 /// Gets the first integer value from `attr`, assuming it is an integer array
27 /// attribute.
28 static uint64_t getFirstIntValue(ArrayAttr attr) {
29   return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
30 }
31 
32 namespace {
33 
34 struct VectorBitcastConvert final
35     : public OpConversionPattern<vector::BitCastOp> {
36   using OpConversionPattern::OpConversionPattern;
37 
38   LogicalResult
39   matchAndRewrite(vector::BitCastOp bitcastOp, ArrayRef<Value> operands,
40                   ConversionPatternRewriter &rewriter) const override {
41     auto dstType = getTypeConverter()->convertType(bitcastOp.getType());
42     if (!dstType)
43       return failure();
44 
45     vector::BitCastOp::Adaptor adaptor(operands);
46     if (dstType == adaptor.source().getType())
47       rewriter.replaceOp(bitcastOp, adaptor.source());
48     else
49       rewriter.replaceOpWithNewOp<spirv::BitcastOp>(bitcastOp, dstType,
50                                                     adaptor.source());
51 
52     return success();
53   }
54 };
55 
56 struct VectorBroadcastConvert final
57     : public OpConversionPattern<vector::BroadcastOp> {
58   using OpConversionPattern::OpConversionPattern;
59 
60   LogicalResult
61   matchAndRewrite(vector::BroadcastOp broadcastOp, ArrayRef<Value> operands,
62                   ConversionPatternRewriter &rewriter) const override {
63     if (broadcastOp.source().getType().isa<VectorType>() ||
64         !spirv::CompositeType::isValid(broadcastOp.getVectorType()))
65       return failure();
66     vector::BroadcastOp::Adaptor adaptor(operands);
67     SmallVector<Value, 4> source(broadcastOp.getVectorType().getNumElements(),
68                                  adaptor.source());
69     rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
70         broadcastOp, broadcastOp.getVectorType(), source);
71     return success();
72   }
73 };
74 
75 struct VectorExtractOpConvert final
76     : public OpConversionPattern<vector::ExtractOp> {
77   using OpConversionPattern::OpConversionPattern;
78 
79   LogicalResult
80   matchAndRewrite(vector::ExtractOp extractOp, ArrayRef<Value> operands,
81                   ConversionPatternRewriter &rewriter) const override {
82     // Only support extracting a scalar value now.
83     VectorType resultVectorType = extractOp.getType().dyn_cast<VectorType>();
84     if (resultVectorType && resultVectorType.getNumElements() > 1)
85       return failure();
86 
87     auto dstType = getTypeConverter()->convertType(extractOp.getType());
88     if (!dstType)
89       return failure();
90 
91     vector::ExtractOp::Adaptor adaptor(operands);
92     if (adaptor.vector().getType().isa<spirv::ScalarType>()) {
93       rewriter.replaceOp(extractOp, adaptor.vector());
94       return success();
95     }
96 
97     int32_t id = getFirstIntValue(extractOp.position());
98     rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
99         extractOp, adaptor.vector(), id);
100     return success();
101   }
102 };
103 
104 struct VectorExtractStridedSliceOpConvert final
105     : public OpConversionPattern<vector::ExtractStridedSliceOp> {
106   using OpConversionPattern::OpConversionPattern;
107 
108   LogicalResult
109   matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
110                   ArrayRef<Value> operands,
111                   ConversionPatternRewriter &rewriter) const override {
112     auto dstType = getTypeConverter()->convertType(extractOp.getType());
113     if (!dstType)
114       return failure();
115 
116     // Extract vector<1xT> not supported yet.
117     if (dstType.isa<spirv::ScalarType>())
118       return failure();
119 
120     uint64_t offset = getFirstIntValue(extractOp.offsets());
121     uint64_t size = getFirstIntValue(extractOp.sizes());
122     uint64_t stride = getFirstIntValue(extractOp.strides());
123     if (stride != 1)
124       return failure();
125 
126     Value srcVector = operands.front();
127 
128     SmallVector<int32_t, 2> indices(size);
129     std::iota(indices.begin(), indices.end(), offset);
130 
131     rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
132         extractOp, dstType, srcVector, srcVector,
133         rewriter.getI32ArrayAttr(indices));
134 
135     return success();
136   }
137 };
138 
139 struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> {
140   using OpConversionPattern::OpConversionPattern;
141 
142   LogicalResult
143   matchAndRewrite(vector::FMAOp fmaOp, ArrayRef<Value> operands,
144                   ConversionPatternRewriter &rewriter) const override {
145     if (!spirv::CompositeType::isValid(fmaOp.getVectorType()))
146       return failure();
147     vector::FMAOp::Adaptor adaptor(operands);
148     rewriter.replaceOpWithNewOp<spirv::GLSLFmaOp>(
149         fmaOp, fmaOp.getType(), adaptor.lhs(), adaptor.rhs(), adaptor.acc());
150     return success();
151   }
152 };
153 
154 struct VectorInsertOpConvert final
155     : public OpConversionPattern<vector::InsertOp> {
156   using OpConversionPattern::OpConversionPattern;
157 
158   LogicalResult
159   matchAndRewrite(vector::InsertOp insertOp, ArrayRef<Value> operands,
160                   ConversionPatternRewriter &rewriter) const override {
161     if (insertOp.getSourceType().isa<VectorType>() ||
162         !spirv::CompositeType::isValid(insertOp.getDestVectorType()))
163       return failure();
164     vector::InsertOp::Adaptor adaptor(operands);
165     int32_t id = getFirstIntValue(insertOp.position());
166     rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
167         insertOp, adaptor.source(), adaptor.dest(), id);
168     return success();
169   }
170 };
171 
172 struct VectorExtractElementOpConvert final
173     : public OpConversionPattern<vector::ExtractElementOp> {
174   using OpConversionPattern::OpConversionPattern;
175 
176   LogicalResult
177   matchAndRewrite(vector::ExtractElementOp extractElementOp,
178                   ArrayRef<Value> operands,
179                   ConversionPatternRewriter &rewriter) const override {
180     if (!spirv::CompositeType::isValid(extractElementOp.getVectorType()))
181       return failure();
182     vector::ExtractElementOp::Adaptor adaptor(operands);
183     rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
184         extractElementOp, extractElementOp.getType(), adaptor.vector(),
185         extractElementOp.position());
186     return success();
187   }
188 };
189 
190 struct VectorInsertElementOpConvert final
191     : public OpConversionPattern<vector::InsertElementOp> {
192   using OpConversionPattern::OpConversionPattern;
193 
194   LogicalResult
195   matchAndRewrite(vector::InsertElementOp insertElementOp,
196                   ArrayRef<Value> operands,
197                   ConversionPatternRewriter &rewriter) const override {
198     if (!spirv::CompositeType::isValid(insertElementOp.getDestVectorType()))
199       return failure();
200     vector::InsertElementOp::Adaptor adaptor(operands);
201     rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
202         insertElementOp, insertElementOp.getType(), insertElementOp.dest(),
203         adaptor.source(), insertElementOp.position());
204     return success();
205   }
206 };
207 
208 struct VectorInsertStridedSliceOpConvert final
209     : public OpConversionPattern<vector::InsertStridedSliceOp> {
210   using OpConversionPattern::OpConversionPattern;
211 
212   LogicalResult
213   matchAndRewrite(vector::InsertStridedSliceOp insertOp,
214                   ArrayRef<Value> operands,
215                   ConversionPatternRewriter &rewriter) const override {
216     Value srcVector = operands.front();
217     Value dstVector = operands.back();
218 
219     // Insert scalar values not supported yet.
220     if (srcVector.getType().isa<spirv::ScalarType>() ||
221         dstVector.getType().isa<spirv::ScalarType>())
222       return failure();
223 
224     uint64_t stride = getFirstIntValue(insertOp.strides());
225     if (stride != 1)
226       return failure();
227 
228     uint64_t totalSize =
229         dstVector.getType().cast<VectorType>().getNumElements();
230     uint64_t insertSize =
231         srcVector.getType().cast<VectorType>().getNumElements();
232     uint64_t offset = getFirstIntValue(insertOp.offsets());
233 
234     SmallVector<int32_t, 2> indices(totalSize);
235     std::iota(indices.begin(), indices.end(), 0);
236     std::iota(indices.begin() + offset, indices.begin() + offset + insertSize,
237               totalSize);
238 
239     rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
240         insertOp, dstVector.getType(), dstVector, srcVector,
241         rewriter.getI32ArrayAttr(indices));
242 
243     return success();
244   }
245 };
246 
247 } // namespace
248 
249 void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
250                                          RewritePatternSet &patterns) {
251   patterns.add<VectorBitcastConvert, VectorBroadcastConvert,
252                VectorExtractElementOpConvert, VectorExtractOpConvert,
253                VectorExtractStridedSliceOpConvert, VectorFmaOpConvert,
254                VectorInsertElementOpConvert, VectorInsertOpConvert,
255                VectorInsertStridedSliceOpConvert>(typeConverter,
256                                                   patterns.getContext());
257 }
258