1 //===- VectorToSPIRV.cpp - Vector to SPIR-V Patterns ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert Vector dialect to SPIRV dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
14 
15 #include "../PassDetail.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
18 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
19 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
20 #include "mlir/Dialect/Vector/IR/VectorOps.h"
21 #include "mlir/IR/BuiltinAttributes.h"
22 #include "mlir/IR/BuiltinTypes.h"
23 #include "mlir/Transforms/DialectConversion.h"
24 #include "llvm/ADT/ArrayRef.h"
25 #include "llvm/ADT/STLExtras.h"
26 #include <numeric>
27 
28 using namespace mlir;
29 
30 /// Gets the first integer value from `attr`, assuming it is an integer array
31 /// attribute.
getFirstIntValue(ArrayAttr attr)32 static uint64_t getFirstIntValue(ArrayAttr attr) {
33   return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
34 }
35 
36 namespace {
37 
38 struct VectorBitcastConvert final
39     : public OpConversionPattern<vector::BitCastOp> {
40   using OpConversionPattern::OpConversionPattern;
41 
42   LogicalResult
matchAndRewrite__anon0bb484910111::VectorBitcastConvert43   matchAndRewrite(vector::BitCastOp bitcastOp, OpAdaptor adaptor,
44                   ConversionPatternRewriter &rewriter) const override {
45     auto dstType = getTypeConverter()->convertType(bitcastOp.getType());
46     if (!dstType)
47       return failure();
48 
49     if (dstType == adaptor.getSource().getType())
50       rewriter.replaceOp(bitcastOp, adaptor.getSource());
51     else
52       rewriter.replaceOpWithNewOp<spirv::BitcastOp>(bitcastOp, dstType,
53                                                     adaptor.getSource());
54 
55     return success();
56   }
57 };
58 
59 struct VectorBroadcastConvert final
60     : public OpConversionPattern<vector::BroadcastOp> {
61   using OpConversionPattern::OpConversionPattern;
62 
63   LogicalResult
matchAndRewrite__anon0bb484910111::VectorBroadcastConvert64   matchAndRewrite(vector::BroadcastOp broadcastOp, OpAdaptor adaptor,
65                   ConversionPatternRewriter &rewriter) const override {
66     if (broadcastOp.getSource().getType().isa<VectorType>() ||
67         !spirv::CompositeType::isValid(broadcastOp.getVectorType()))
68       return failure();
69     SmallVector<Value, 4> source(broadcastOp.getVectorType().getNumElements(),
70                                  adaptor.getSource());
71     rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
72         broadcastOp, broadcastOp.getVectorType(), source);
73     return success();
74   }
75 };
76 
77 struct VectorExtractOpConvert final
78     : public OpConversionPattern<vector::ExtractOp> {
79   using OpConversionPattern::OpConversionPattern;
80 
81   LogicalResult
matchAndRewrite__anon0bb484910111::VectorExtractOpConvert82   matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
83                   ConversionPatternRewriter &rewriter) const override {
84     // Only support extracting a scalar value now.
85     VectorType resultVectorType = extractOp.getType().dyn_cast<VectorType>();
86     if (resultVectorType && resultVectorType.getNumElements() > 1)
87       return failure();
88 
89     auto dstType = getTypeConverter()->convertType(extractOp.getType());
90     if (!dstType)
91       return failure();
92 
93     if (adaptor.getVector().getType().isa<spirv::ScalarType>()) {
94       rewriter.replaceOp(extractOp, adaptor.getVector());
95       return success();
96     }
97 
98     int32_t id = getFirstIntValue(extractOp.getPosition());
99     rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
100         extractOp, adaptor.getVector(), id);
101     return success();
102   }
103 };
104 
105 struct VectorExtractStridedSliceOpConvert final
106     : public OpConversionPattern<vector::ExtractStridedSliceOp> {
107   using OpConversionPattern::OpConversionPattern;
108 
109   LogicalResult
matchAndRewrite__anon0bb484910111::VectorExtractStridedSliceOpConvert110   matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
111                   ConversionPatternRewriter &rewriter) const override {
112     auto dstType = getTypeConverter()->convertType(extractOp.getType());
113     if (!dstType)
114       return failure();
115 
116     uint64_t offset = getFirstIntValue(extractOp.getOffsets());
117     uint64_t size = getFirstIntValue(extractOp.getSizes());
118     uint64_t stride = getFirstIntValue(extractOp.getStrides());
119     if (stride != 1)
120       return failure();
121 
122     Value srcVector = adaptor.getOperands().front();
123 
124     // Extract vector<1xT> case.
125     if (dstType.isa<spirv::ScalarType>()) {
126       rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(extractOp,
127                                                              srcVector, offset);
128       return success();
129     }
130 
131     SmallVector<int32_t, 2> indices(size);
132     std::iota(indices.begin(), indices.end(), offset);
133 
134     rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
135         extractOp, dstType, srcVector, srcVector,
136         rewriter.getI32ArrayAttr(indices));
137 
138     return success();
139   }
140 };
141 
142 struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> {
143   using OpConversionPattern::OpConversionPattern;
144 
145   LogicalResult
matchAndRewrite__anon0bb484910111::VectorFmaOpConvert146   matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
147                   ConversionPatternRewriter &rewriter) const override {
148     Type dstType = getTypeConverter()->convertType(fmaOp.getType());
149     if (!dstType)
150       return failure();
151     rewriter.replaceOpWithNewOp<spirv::GLFmaOp>(
152         fmaOp, dstType, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
153     return success();
154   }
155 };
156 
157 struct VectorInsertOpConvert final
158     : public OpConversionPattern<vector::InsertOp> {
159   using OpConversionPattern::OpConversionPattern;
160 
161   LogicalResult
matchAndRewrite__anon0bb484910111::VectorInsertOpConvert162   matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
163                   ConversionPatternRewriter &rewriter) const override {
164     // Special case for inserting scalar values into size-1 vectors.
165     if (insertOp.getSourceType().isIntOrFloat() &&
166         insertOp.getDestVectorType().getNumElements() == 1) {
167       rewriter.replaceOp(insertOp, adaptor.getSource());
168       return success();
169     }
170 
171     if (insertOp.getSourceType().isa<VectorType>() ||
172         !spirv::CompositeType::isValid(insertOp.getDestVectorType()))
173       return failure();
174     int32_t id = getFirstIntValue(insertOp.getPosition());
175     rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
176         insertOp, adaptor.getSource(), adaptor.getDest(), id);
177     return success();
178   }
179 };
180 
181 struct VectorExtractElementOpConvert final
182     : public OpConversionPattern<vector::ExtractElementOp> {
183   using OpConversionPattern::OpConversionPattern;
184 
185   LogicalResult
matchAndRewrite__anon0bb484910111::VectorExtractElementOpConvert186   matchAndRewrite(vector::ExtractElementOp extractElementOp, OpAdaptor adaptor,
187                   ConversionPatternRewriter &rewriter) const override {
188     if (!spirv::CompositeType::isValid(extractElementOp.getVectorType()))
189       return failure();
190     rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
191         extractElementOp, extractElementOp.getType(), adaptor.getVector(),
192         extractElementOp.getPosition());
193     return success();
194   }
195 };
196 
197 struct VectorInsertElementOpConvert final
198     : public OpConversionPattern<vector::InsertElementOp> {
199   using OpConversionPattern::OpConversionPattern;
200 
201   LogicalResult
matchAndRewrite__anon0bb484910111::VectorInsertElementOpConvert202   matchAndRewrite(vector::InsertElementOp insertElementOp, OpAdaptor adaptor,
203                   ConversionPatternRewriter &rewriter) const override {
204     if (!spirv::CompositeType::isValid(insertElementOp.getDestVectorType()))
205       return failure();
206     rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
207         insertElementOp, insertElementOp.getType(), insertElementOp.getDest(),
208         adaptor.getSource(), insertElementOp.getPosition());
209     return success();
210   }
211 };
212 
213 struct VectorInsertStridedSliceOpConvert final
214     : public OpConversionPattern<vector::InsertStridedSliceOp> {
215   using OpConversionPattern::OpConversionPattern;
216 
217   LogicalResult
matchAndRewrite__anon0bb484910111::VectorInsertStridedSliceOpConvert218   matchAndRewrite(vector::InsertStridedSliceOp insertOp, OpAdaptor adaptor,
219                   ConversionPatternRewriter &rewriter) const override {
220     Value srcVector = adaptor.getOperands().front();
221     Value dstVector = adaptor.getOperands().back();
222 
223     uint64_t stride = getFirstIntValue(insertOp.getStrides());
224     if (stride != 1)
225       return failure();
226     uint64_t offset = getFirstIntValue(insertOp.getOffsets());
227 
228     if (srcVector.getType().isa<spirv::ScalarType>()) {
229       assert(!dstVector.getType().isa<spirv::ScalarType>());
230       rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
231           insertOp, dstVector.getType(), srcVector, dstVector,
232           rewriter.getI32ArrayAttr(offset));
233       return success();
234     }
235 
236     uint64_t totalSize =
237         dstVector.getType().cast<VectorType>().getNumElements();
238     uint64_t insertSize =
239         srcVector.getType().cast<VectorType>().getNumElements();
240 
241     SmallVector<int32_t, 2> indices(totalSize);
242     std::iota(indices.begin(), indices.end(), 0);
243     std::iota(indices.begin() + offset, indices.begin() + offset + insertSize,
244               totalSize);
245 
246     rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
247         insertOp, dstVector.getType(), dstVector, srcVector,
248         rewriter.getI32ArrayAttr(indices));
249 
250     return success();
251   }
252 };
253 
254 struct VectorReductionPattern final
255     : public OpConversionPattern<vector::ReductionOp> {
256   using OpConversionPattern::OpConversionPattern;
257 
258   LogicalResult
matchAndRewrite__anon0bb484910111::VectorReductionPattern259   matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
260                   ConversionPatternRewriter &rewriter) const override {
261     Type resultType = typeConverter->convertType(reduceOp.getType());
262     if (!resultType)
263       return failure();
264 
265     auto srcVectorType = adaptor.getVector().getType().dyn_cast<VectorType>();
266     if (!srcVectorType || srcVectorType.getRank() != 1)
267       return rewriter.notifyMatchFailure(reduceOp, "not 1-D vector source");
268 
269     // Extract all elements.
270     int numElements = srcVectorType.getDimSize(0);
271     SmallVector<Value, 4> values;
272     values.reserve(numElements + (adaptor.getAcc() != nullptr));
273     Location loc = reduceOp.getLoc();
274     for (int i = 0; i < numElements; ++i) {
275       values.push_back(rewriter.create<spirv::CompositeExtractOp>(
276           loc, srcVectorType.getElementType(), adaptor.getVector(),
277           rewriter.getI32ArrayAttr({i})));
278     }
279     if (Value acc = adaptor.getAcc())
280       values.push_back(acc);
281 
282     // Reduce them.
283     Value result = values.front();
284     for (Value next : llvm::makeArrayRef(values).drop_front()) {
285       switch (reduceOp.getKind()) {
286 #define INT_FLOAT_CASE(kind, iop, fop)                                         \
287   case vector::CombiningKind::kind:                                            \
288     if (resultType.isa<IntegerType>()) {                                       \
289       result = rewriter.create<spirv::iop>(loc, resultType, result, next);     \
290     } else {                                                                   \
291       assert(resultType.isa<FloatType>());                                     \
292       result = rewriter.create<spirv::fop>(loc, resultType, result, next);     \
293     }                                                                          \
294     break
295 
296         INT_FLOAT_CASE(ADD, IAddOp, FAddOp);
297         INT_FLOAT_CASE(MUL, IMulOp, FMulOp);
298 
299       case vector::CombiningKind::MINUI:
300       case vector::CombiningKind::MINSI:
301       case vector::CombiningKind::MINF:
302       case vector::CombiningKind::MAXUI:
303       case vector::CombiningKind::MAXSI:
304       case vector::CombiningKind::MAXF:
305       case vector::CombiningKind::AND:
306       case vector::CombiningKind::OR:
307       case vector::CombiningKind::XOR:
308         return rewriter.notifyMatchFailure(reduceOp, "unimplemented");
309       }
310     }
311 
312     rewriter.replaceOp(reduceOp, result);
313     return success();
314   }
315 };
316 
317 class VectorSplatPattern final : public OpConversionPattern<vector::SplatOp> {
318 public:
319   using OpConversionPattern<vector::SplatOp>::OpConversionPattern;
320 
321   LogicalResult
matchAndRewrite(vector::SplatOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const322   matchAndRewrite(vector::SplatOp op, OpAdaptor adaptor,
323                   ConversionPatternRewriter &rewriter) const override {
324     Type dstType = getTypeConverter()->convertType(op.getType());
325     if (!dstType)
326       return failure();
327     if (dstType.isa<spirv::ScalarType>()) {
328       rewriter.replaceOp(op, adaptor.getInput());
329     } else {
330       auto dstVecType = dstType.cast<VectorType>();
331       SmallVector<Value, 4> source(dstVecType.getNumElements(),
332                                    adaptor.getInput());
333       rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, dstType,
334                                                                source);
335     }
336     return success();
337   }
338 };
339 
340 struct VectorShuffleOpConvert final
341     : public OpConversionPattern<vector::ShuffleOp> {
342   using OpConversionPattern::OpConversionPattern;
343 
344   LogicalResult
matchAndRewrite__anon0bb484910111::VectorShuffleOpConvert345   matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
346                   ConversionPatternRewriter &rewriter) const override {
347     auto oldResultType = shuffleOp.getVectorType();
348     if (!spirv::CompositeType::isValid(oldResultType))
349       return failure();
350     auto newResultType = getTypeConverter()->convertType(oldResultType);
351 
352     auto oldSourceType = shuffleOp.getV1VectorType();
353     if (oldSourceType.getNumElements() > 1) {
354       SmallVector<int32_t, 4> components = llvm::to_vector<4>(
355           llvm::map_range(shuffleOp.getMask(), [](Attribute attr) -> int32_t {
356             return attr.cast<IntegerAttr>().getValue().getZExtValue();
357           }));
358       rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
359           shuffleOp, newResultType, adaptor.getV1(), adaptor.getV2(),
360           rewriter.getI32ArrayAttr(components));
361       return success();
362     }
363 
364     SmallVector<Value, 2> oldOperands = {adaptor.getV1(), adaptor.getV2()};
365     SmallVector<Value, 4> newOperands;
366     newOperands.reserve(oldResultType.getNumElements());
367     for (const APInt &i : shuffleOp.getMask().getAsValueRange<IntegerAttr>()) {
368       newOperands.push_back(oldOperands[i.getZExtValue()]);
369     }
370     rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
371         shuffleOp, newResultType, newOperands);
372 
373     return success();
374   }
375 };
376 
377 } // namespace
378 
populateVectorToSPIRVPatterns(SPIRVTypeConverter & typeConverter,RewritePatternSet & patterns)379 void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
380                                          RewritePatternSet &patterns) {
381   patterns.add<VectorBitcastConvert, VectorBroadcastConvert,
382                VectorExtractElementOpConvert, VectorExtractOpConvert,
383                VectorExtractStridedSliceOpConvert, VectorFmaOpConvert,
384                VectorInsertElementOpConvert, VectorInsertOpConvert,
385                VectorReductionPattern, VectorInsertStridedSliceOpConvert,
386                VectorShuffleOpConvert, VectorSplatPattern>(
387       typeConverter, patterns.getContext());
388 }
389