1 //===- VectorToSPIRV.cpp - Vector to SPIR-V Patterns ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert Vector dialect to SPIRV dialect.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
14
15 #include "../PassDetail.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
18 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
19 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
20 #include "mlir/Dialect/Vector/IR/VectorOps.h"
21 #include "mlir/IR/BuiltinAttributes.h"
22 #include "mlir/IR/BuiltinTypes.h"
23 #include "mlir/Transforms/DialectConversion.h"
24 #include "llvm/ADT/ArrayRef.h"
25 #include "llvm/ADT/STLExtras.h"
26 #include <numeric>
27
28 using namespace mlir;
29
30 /// Gets the first integer value from `attr`, assuming it is an integer array
31 /// attribute.
getFirstIntValue(ArrayAttr attr)32 static uint64_t getFirstIntValue(ArrayAttr attr) {
33 return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
34 }
35
36 namespace {
37
38 struct VectorBitcastConvert final
39 : public OpConversionPattern<vector::BitCastOp> {
40 using OpConversionPattern::OpConversionPattern;
41
42 LogicalResult
matchAndRewrite__anon0bb484910111::VectorBitcastConvert43 matchAndRewrite(vector::BitCastOp bitcastOp, OpAdaptor adaptor,
44 ConversionPatternRewriter &rewriter) const override {
45 auto dstType = getTypeConverter()->convertType(bitcastOp.getType());
46 if (!dstType)
47 return failure();
48
49 if (dstType == adaptor.getSource().getType())
50 rewriter.replaceOp(bitcastOp, adaptor.getSource());
51 else
52 rewriter.replaceOpWithNewOp<spirv::BitcastOp>(bitcastOp, dstType,
53 adaptor.getSource());
54
55 return success();
56 }
57 };
58
59 struct VectorBroadcastConvert final
60 : public OpConversionPattern<vector::BroadcastOp> {
61 using OpConversionPattern::OpConversionPattern;
62
63 LogicalResult
matchAndRewrite__anon0bb484910111::VectorBroadcastConvert64 matchAndRewrite(vector::BroadcastOp broadcastOp, OpAdaptor adaptor,
65 ConversionPatternRewriter &rewriter) const override {
66 if (broadcastOp.getSource().getType().isa<VectorType>() ||
67 !spirv::CompositeType::isValid(broadcastOp.getVectorType()))
68 return failure();
69 SmallVector<Value, 4> source(broadcastOp.getVectorType().getNumElements(),
70 adaptor.getSource());
71 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
72 broadcastOp, broadcastOp.getVectorType(), source);
73 return success();
74 }
75 };
76
77 struct VectorExtractOpConvert final
78 : public OpConversionPattern<vector::ExtractOp> {
79 using OpConversionPattern::OpConversionPattern;
80
81 LogicalResult
matchAndRewrite__anon0bb484910111::VectorExtractOpConvert82 matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
83 ConversionPatternRewriter &rewriter) const override {
84 // Only support extracting a scalar value now.
85 VectorType resultVectorType = extractOp.getType().dyn_cast<VectorType>();
86 if (resultVectorType && resultVectorType.getNumElements() > 1)
87 return failure();
88
89 auto dstType = getTypeConverter()->convertType(extractOp.getType());
90 if (!dstType)
91 return failure();
92
93 if (adaptor.getVector().getType().isa<spirv::ScalarType>()) {
94 rewriter.replaceOp(extractOp, adaptor.getVector());
95 return success();
96 }
97
98 int32_t id = getFirstIntValue(extractOp.getPosition());
99 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
100 extractOp, adaptor.getVector(), id);
101 return success();
102 }
103 };
104
105 struct VectorExtractStridedSliceOpConvert final
106 : public OpConversionPattern<vector::ExtractStridedSliceOp> {
107 using OpConversionPattern::OpConversionPattern;
108
109 LogicalResult
matchAndRewrite__anon0bb484910111::VectorExtractStridedSliceOpConvert110 matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
111 ConversionPatternRewriter &rewriter) const override {
112 auto dstType = getTypeConverter()->convertType(extractOp.getType());
113 if (!dstType)
114 return failure();
115
116 uint64_t offset = getFirstIntValue(extractOp.getOffsets());
117 uint64_t size = getFirstIntValue(extractOp.getSizes());
118 uint64_t stride = getFirstIntValue(extractOp.getStrides());
119 if (stride != 1)
120 return failure();
121
122 Value srcVector = adaptor.getOperands().front();
123
124 // Extract vector<1xT> case.
125 if (dstType.isa<spirv::ScalarType>()) {
126 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(extractOp,
127 srcVector, offset);
128 return success();
129 }
130
131 SmallVector<int32_t, 2> indices(size);
132 std::iota(indices.begin(), indices.end(), offset);
133
134 rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
135 extractOp, dstType, srcVector, srcVector,
136 rewriter.getI32ArrayAttr(indices));
137
138 return success();
139 }
140 };
141
142 struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> {
143 using OpConversionPattern::OpConversionPattern;
144
145 LogicalResult
matchAndRewrite__anon0bb484910111::VectorFmaOpConvert146 matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
147 ConversionPatternRewriter &rewriter) const override {
148 Type dstType = getTypeConverter()->convertType(fmaOp.getType());
149 if (!dstType)
150 return failure();
151 rewriter.replaceOpWithNewOp<spirv::GLFmaOp>(
152 fmaOp, dstType, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
153 return success();
154 }
155 };
156
157 struct VectorInsertOpConvert final
158 : public OpConversionPattern<vector::InsertOp> {
159 using OpConversionPattern::OpConversionPattern;
160
161 LogicalResult
matchAndRewrite__anon0bb484910111::VectorInsertOpConvert162 matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
163 ConversionPatternRewriter &rewriter) const override {
164 // Special case for inserting scalar values into size-1 vectors.
165 if (insertOp.getSourceType().isIntOrFloat() &&
166 insertOp.getDestVectorType().getNumElements() == 1) {
167 rewriter.replaceOp(insertOp, adaptor.getSource());
168 return success();
169 }
170
171 if (insertOp.getSourceType().isa<VectorType>() ||
172 !spirv::CompositeType::isValid(insertOp.getDestVectorType()))
173 return failure();
174 int32_t id = getFirstIntValue(insertOp.getPosition());
175 rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
176 insertOp, adaptor.getSource(), adaptor.getDest(), id);
177 return success();
178 }
179 };
180
181 struct VectorExtractElementOpConvert final
182 : public OpConversionPattern<vector::ExtractElementOp> {
183 using OpConversionPattern::OpConversionPattern;
184
185 LogicalResult
matchAndRewrite__anon0bb484910111::VectorExtractElementOpConvert186 matchAndRewrite(vector::ExtractElementOp extractElementOp, OpAdaptor adaptor,
187 ConversionPatternRewriter &rewriter) const override {
188 if (!spirv::CompositeType::isValid(extractElementOp.getVectorType()))
189 return failure();
190 rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
191 extractElementOp, extractElementOp.getType(), adaptor.getVector(),
192 extractElementOp.getPosition());
193 return success();
194 }
195 };
196
197 struct VectorInsertElementOpConvert final
198 : public OpConversionPattern<vector::InsertElementOp> {
199 using OpConversionPattern::OpConversionPattern;
200
201 LogicalResult
matchAndRewrite__anon0bb484910111::VectorInsertElementOpConvert202 matchAndRewrite(vector::InsertElementOp insertElementOp, OpAdaptor adaptor,
203 ConversionPatternRewriter &rewriter) const override {
204 if (!spirv::CompositeType::isValid(insertElementOp.getDestVectorType()))
205 return failure();
206 rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
207 insertElementOp, insertElementOp.getType(), insertElementOp.getDest(),
208 adaptor.getSource(), insertElementOp.getPosition());
209 return success();
210 }
211 };
212
213 struct VectorInsertStridedSliceOpConvert final
214 : public OpConversionPattern<vector::InsertStridedSliceOp> {
215 using OpConversionPattern::OpConversionPattern;
216
217 LogicalResult
matchAndRewrite__anon0bb484910111::VectorInsertStridedSliceOpConvert218 matchAndRewrite(vector::InsertStridedSliceOp insertOp, OpAdaptor adaptor,
219 ConversionPatternRewriter &rewriter) const override {
220 Value srcVector = adaptor.getOperands().front();
221 Value dstVector = adaptor.getOperands().back();
222
223 uint64_t stride = getFirstIntValue(insertOp.getStrides());
224 if (stride != 1)
225 return failure();
226 uint64_t offset = getFirstIntValue(insertOp.getOffsets());
227
228 if (srcVector.getType().isa<spirv::ScalarType>()) {
229 assert(!dstVector.getType().isa<spirv::ScalarType>());
230 rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
231 insertOp, dstVector.getType(), srcVector, dstVector,
232 rewriter.getI32ArrayAttr(offset));
233 return success();
234 }
235
236 uint64_t totalSize =
237 dstVector.getType().cast<VectorType>().getNumElements();
238 uint64_t insertSize =
239 srcVector.getType().cast<VectorType>().getNumElements();
240
241 SmallVector<int32_t, 2> indices(totalSize);
242 std::iota(indices.begin(), indices.end(), 0);
243 std::iota(indices.begin() + offset, indices.begin() + offset + insertSize,
244 totalSize);
245
246 rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
247 insertOp, dstVector.getType(), dstVector, srcVector,
248 rewriter.getI32ArrayAttr(indices));
249
250 return success();
251 }
252 };
253
254 struct VectorReductionPattern final
255 : public OpConversionPattern<vector::ReductionOp> {
256 using OpConversionPattern::OpConversionPattern;
257
258 LogicalResult
matchAndRewrite__anon0bb484910111::VectorReductionPattern259 matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
260 ConversionPatternRewriter &rewriter) const override {
261 Type resultType = typeConverter->convertType(reduceOp.getType());
262 if (!resultType)
263 return failure();
264
265 auto srcVectorType = adaptor.getVector().getType().dyn_cast<VectorType>();
266 if (!srcVectorType || srcVectorType.getRank() != 1)
267 return rewriter.notifyMatchFailure(reduceOp, "not 1-D vector source");
268
269 // Extract all elements.
270 int numElements = srcVectorType.getDimSize(0);
271 SmallVector<Value, 4> values;
272 values.reserve(numElements + (adaptor.getAcc() != nullptr));
273 Location loc = reduceOp.getLoc();
274 for (int i = 0; i < numElements; ++i) {
275 values.push_back(rewriter.create<spirv::CompositeExtractOp>(
276 loc, srcVectorType.getElementType(), adaptor.getVector(),
277 rewriter.getI32ArrayAttr({i})));
278 }
279 if (Value acc = adaptor.getAcc())
280 values.push_back(acc);
281
282 // Reduce them.
283 Value result = values.front();
284 for (Value next : llvm::makeArrayRef(values).drop_front()) {
285 switch (reduceOp.getKind()) {
286 #define INT_FLOAT_CASE(kind, iop, fop) \
287 case vector::CombiningKind::kind: \
288 if (resultType.isa<IntegerType>()) { \
289 result = rewriter.create<spirv::iop>(loc, resultType, result, next); \
290 } else { \
291 assert(resultType.isa<FloatType>()); \
292 result = rewriter.create<spirv::fop>(loc, resultType, result, next); \
293 } \
294 break
295
296 INT_FLOAT_CASE(ADD, IAddOp, FAddOp);
297 INT_FLOAT_CASE(MUL, IMulOp, FMulOp);
298
299 case vector::CombiningKind::MINUI:
300 case vector::CombiningKind::MINSI:
301 case vector::CombiningKind::MINF:
302 case vector::CombiningKind::MAXUI:
303 case vector::CombiningKind::MAXSI:
304 case vector::CombiningKind::MAXF:
305 case vector::CombiningKind::AND:
306 case vector::CombiningKind::OR:
307 case vector::CombiningKind::XOR:
308 return rewriter.notifyMatchFailure(reduceOp, "unimplemented");
309 }
310 }
311
312 rewriter.replaceOp(reduceOp, result);
313 return success();
314 }
315 };
316
317 class VectorSplatPattern final : public OpConversionPattern<vector::SplatOp> {
318 public:
319 using OpConversionPattern<vector::SplatOp>::OpConversionPattern;
320
321 LogicalResult
matchAndRewrite(vector::SplatOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const322 matchAndRewrite(vector::SplatOp op, OpAdaptor adaptor,
323 ConversionPatternRewriter &rewriter) const override {
324 Type dstType = getTypeConverter()->convertType(op.getType());
325 if (!dstType)
326 return failure();
327 if (dstType.isa<spirv::ScalarType>()) {
328 rewriter.replaceOp(op, adaptor.getInput());
329 } else {
330 auto dstVecType = dstType.cast<VectorType>();
331 SmallVector<Value, 4> source(dstVecType.getNumElements(),
332 adaptor.getInput());
333 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, dstType,
334 source);
335 }
336 return success();
337 }
338 };
339
340 struct VectorShuffleOpConvert final
341 : public OpConversionPattern<vector::ShuffleOp> {
342 using OpConversionPattern::OpConversionPattern;
343
344 LogicalResult
matchAndRewrite__anon0bb484910111::VectorShuffleOpConvert345 matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
346 ConversionPatternRewriter &rewriter) const override {
347 auto oldResultType = shuffleOp.getVectorType();
348 if (!spirv::CompositeType::isValid(oldResultType))
349 return failure();
350 auto newResultType = getTypeConverter()->convertType(oldResultType);
351
352 auto oldSourceType = shuffleOp.getV1VectorType();
353 if (oldSourceType.getNumElements() > 1) {
354 SmallVector<int32_t, 4> components = llvm::to_vector<4>(
355 llvm::map_range(shuffleOp.getMask(), [](Attribute attr) -> int32_t {
356 return attr.cast<IntegerAttr>().getValue().getZExtValue();
357 }));
358 rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
359 shuffleOp, newResultType, adaptor.getV1(), adaptor.getV2(),
360 rewriter.getI32ArrayAttr(components));
361 return success();
362 }
363
364 SmallVector<Value, 2> oldOperands = {adaptor.getV1(), adaptor.getV2()};
365 SmallVector<Value, 4> newOperands;
366 newOperands.reserve(oldResultType.getNumElements());
367 for (const APInt &i : shuffleOp.getMask().getAsValueRange<IntegerAttr>()) {
368 newOperands.push_back(oldOperands[i.getZExtValue()]);
369 }
370 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
371 shuffleOp, newResultType, newOperands);
372
373 return success();
374 }
375 };
376
377 } // namespace
378
populateVectorToSPIRVPatterns(SPIRVTypeConverter & typeConverter,RewritePatternSet & patterns)379 void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
380 RewritePatternSet &patterns) {
381 patterns.add<VectorBitcastConvert, VectorBroadcastConvert,
382 VectorExtractElementOpConvert, VectorExtractOpConvert,
383 VectorExtractStridedSliceOpConvert, VectorFmaOpConvert,
384 VectorInsertElementOpConvert, VectorInsertOpConvert,
385 VectorReductionPattern, VectorInsertStridedSliceOpConvert,
386 VectorShuffleOpConvert, VectorSplatPattern>(
387 typeConverter, patterns.getContext());
388 }
389