1 //===- VectorToSPIRV.cpp - Vector to SPIR-V Patterns ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert Vector dialect to SPIRV dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
14 
15 #include "../PassDetail.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
18 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
19 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
20 #include "mlir/Dialect/Vector/VectorOps.h"
21 #include "mlir/Transforms/DialectConversion.h"
22 
23 using namespace mlir;
24 
25 namespace {
26 struct VectorBroadcastConvert final
27     : public OpConversionPattern<vector::BroadcastOp> {
28   using OpConversionPattern::OpConversionPattern;
29 
30   LogicalResult
31   matchAndRewrite(vector::BroadcastOp broadcastOp, ArrayRef<Value> operands,
32                   ConversionPatternRewriter &rewriter) const override {
33     if (broadcastOp.source().getType().isa<VectorType>() ||
34         !spirv::CompositeType::isValid(broadcastOp.getVectorType()))
35       return failure();
36     vector::BroadcastOp::Adaptor adaptor(operands);
37     SmallVector<Value, 4> source(broadcastOp.getVectorType().getNumElements(),
38                                  adaptor.source());
39     Value construct = rewriter.create<spirv::CompositeConstructOp>(
40         broadcastOp.getLoc(), broadcastOp.getVectorType(), source);
41     rewriter.replaceOp(broadcastOp, construct);
42     return success();
43   }
44 };
45 
46 struct VectorExtractOpConvert final
47     : public OpConversionPattern<vector::ExtractOp> {
48   using OpConversionPattern::OpConversionPattern;
49 
50   LogicalResult
51   matchAndRewrite(vector::ExtractOp extractOp, ArrayRef<Value> operands,
52                   ConversionPatternRewriter &rewriter) const override {
53     if (extractOp.getType().isa<VectorType>() ||
54         !spirv::CompositeType::isValid(extractOp.getVectorType()))
55       return failure();
56     vector::ExtractOp::Adaptor adaptor(operands);
57     int32_t id = extractOp.position().begin()->cast<IntegerAttr>().getInt();
58     Value newExtract = rewriter.create<spirv::CompositeExtractOp>(
59         extractOp.getLoc(), adaptor.vector(), id);
60     rewriter.replaceOp(extractOp, newExtract);
61     return success();
62   }
63 };
64 
65 struct VectorInsertOpConvert final
66     : public OpConversionPattern<vector::InsertOp> {
67   using OpConversionPattern::OpConversionPattern;
68 
69   LogicalResult
70   matchAndRewrite(vector::InsertOp insertOp, ArrayRef<Value> operands,
71                   ConversionPatternRewriter &rewriter) const override {
72     if (insertOp.getSourceType().isa<VectorType>() ||
73         !spirv::CompositeType::isValid(insertOp.getDestVectorType()))
74       return failure();
75     vector::InsertOp::Adaptor adaptor(operands);
76     int32_t id = insertOp.position().begin()->cast<IntegerAttr>().getInt();
77     Value newInsert = rewriter.create<spirv::CompositeInsertOp>(
78         insertOp.getLoc(), adaptor.source(), adaptor.dest(), id);
79     rewriter.replaceOp(insertOp, newInsert);
80     return success();
81   }
82 };
83 
84 struct VectorExtractElementOpConvert final
85     : public OpConversionPattern<vector::ExtractElementOp> {
86   using OpConversionPattern::OpConversionPattern;
87 
88   LogicalResult
89   matchAndRewrite(vector::ExtractElementOp extractElementOp,
90                   ArrayRef<Value> operands,
91                   ConversionPatternRewriter &rewriter) const override {
92     if (!spirv::CompositeType::isValid(extractElementOp.getVectorType()))
93       return failure();
94     vector::ExtractElementOp::Adaptor adaptor(operands);
95     Value newExtractElement = rewriter.create<spirv::VectorExtractDynamicOp>(
96         extractElementOp.getLoc(), extractElementOp.getType(), adaptor.vector(),
97         extractElementOp.position());
98     rewriter.replaceOp(extractElementOp, newExtractElement);
99     return success();
100   }
101 };
102 
103 struct VectorInsertElementOpConvert final
104     : public OpConversionPattern<vector::InsertElementOp> {
105   using OpConversionPattern::OpConversionPattern;
106 
107   LogicalResult
108   matchAndRewrite(vector::InsertElementOp insertElementOp,
109                   ArrayRef<Value> operands,
110                   ConversionPatternRewriter &rewriter) const override {
111     if (!spirv::CompositeType::isValid(insertElementOp.getDestVectorType()))
112       return failure();
113     vector::InsertElementOp::Adaptor adaptor(operands);
114     Value newInsertElement = rewriter.create<spirv::VectorInsertDynamicOp>(
115         insertElementOp.getLoc(), insertElementOp.getType(),
116         insertElementOp.dest(), adaptor.source(), insertElementOp.position());
117     rewriter.replaceOp(insertElementOp, newInsertElement);
118     return success();
119   }
120 };
121 
122 } // namespace
123 
124 void mlir::populateVectorToSPIRVPatterns(MLIRContext *context,
125                                          SPIRVTypeConverter &typeConverter,
126                                          OwningRewritePatternList &patterns) {
127   patterns.insert<VectorBroadcastConvert, VectorExtractOpConvert,
128                   VectorInsertOpConvert, VectorExtractElementOpConvert,
129                   VectorInsertElementOpConvert>(typeConverter, context);
130 }
131