1 //===- VectorToSPIRV.cpp - Vector to SPIR-V Patterns ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert Vector dialect to SPIRV dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
14 
15 #include "../PassDetail.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
18 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
19 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
20 #include "mlir/Dialect/Vector/VectorOps.h"
21 #include "mlir/Transforms/DialectConversion.h"
22 
23 using namespace mlir;
24 
25 namespace {
26 struct VectorBroadcastConvert final
27     : public OpConversionPattern<vector::BroadcastOp> {
28   using OpConversionPattern::OpConversionPattern;
29 
30   LogicalResult
31   matchAndRewrite(vector::BroadcastOp broadcastOp, ArrayRef<Value> operands,
32                   ConversionPatternRewriter &rewriter) const override {
33     if (broadcastOp.source().getType().isa<VectorType>() ||
34         !spirv::CompositeType::isValid(broadcastOp.getVectorType()))
35       return failure();
36     vector::BroadcastOp::Adaptor adaptor(operands);
37     SmallVector<Value, 4> source(broadcastOp.getVectorType().getNumElements(),
38                                  adaptor.source());
39     rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
40         broadcastOp, broadcastOp.getVectorType(), source);
41     return success();
42   }
43 };
44 
45 struct VectorExtractOpConvert final
46     : public OpConversionPattern<vector::ExtractOp> {
47   using OpConversionPattern::OpConversionPattern;
48 
49   LogicalResult
50   matchAndRewrite(vector::ExtractOp extractOp, ArrayRef<Value> operands,
51                   ConversionPatternRewriter &rewriter) const override {
52     if (extractOp.getType().isa<VectorType>() ||
53         !spirv::CompositeType::isValid(extractOp.getVectorType()))
54       return failure();
55     vector::ExtractOp::Adaptor adaptor(operands);
56     int32_t id = extractOp.position().begin()->cast<IntegerAttr>().getInt();
57     rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
58         extractOp, adaptor.vector(), id);
59     return success();
60   }
61 };
62 
63 struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> {
64   using OpConversionPattern::OpConversionPattern;
65 
66   LogicalResult
67   matchAndRewrite(vector::FMAOp fmaOp, ArrayRef<Value> operands,
68                   ConversionPatternRewriter &rewriter) const override {
69     if (!spirv::CompositeType::isValid(fmaOp.getVectorType()))
70       return failure();
71     vector::FMAOp::Adaptor adaptor(operands);
72     rewriter.replaceOpWithNewOp<spirv::GLSLFmaOp>(
73         fmaOp, fmaOp.getType(), adaptor.lhs(), adaptor.rhs(), adaptor.acc());
74     return success();
75   }
76 };
77 
78 struct VectorInsertOpConvert final
79     : public OpConversionPattern<vector::InsertOp> {
80   using OpConversionPattern::OpConversionPattern;
81 
82   LogicalResult
83   matchAndRewrite(vector::InsertOp insertOp, ArrayRef<Value> operands,
84                   ConversionPatternRewriter &rewriter) const override {
85     if (insertOp.getSourceType().isa<VectorType>() ||
86         !spirv::CompositeType::isValid(insertOp.getDestVectorType()))
87       return failure();
88     vector::InsertOp::Adaptor adaptor(operands);
89     int32_t id = insertOp.position().begin()->cast<IntegerAttr>().getInt();
90     rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
91         insertOp, adaptor.source(), adaptor.dest(), id);
92     return success();
93   }
94 };
95 
96 struct VectorExtractElementOpConvert final
97     : public OpConversionPattern<vector::ExtractElementOp> {
98   using OpConversionPattern::OpConversionPattern;
99 
100   LogicalResult
101   matchAndRewrite(vector::ExtractElementOp extractElementOp,
102                   ArrayRef<Value> operands,
103                   ConversionPatternRewriter &rewriter) const override {
104     if (!spirv::CompositeType::isValid(extractElementOp.getVectorType()))
105       return failure();
106     vector::ExtractElementOp::Adaptor adaptor(operands);
107     rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
108         extractElementOp, extractElementOp.getType(), adaptor.vector(),
109         extractElementOp.position());
110     return success();
111   }
112 };
113 
114 struct VectorInsertElementOpConvert final
115     : public OpConversionPattern<vector::InsertElementOp> {
116   using OpConversionPattern::OpConversionPattern;
117 
118   LogicalResult
119   matchAndRewrite(vector::InsertElementOp insertElementOp,
120                   ArrayRef<Value> operands,
121                   ConversionPatternRewriter &rewriter) const override {
122     if (!spirv::CompositeType::isValid(insertElementOp.getDestVectorType()))
123       return failure();
124     vector::InsertElementOp::Adaptor adaptor(operands);
125     rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
126         insertElementOp, insertElementOp.getType(), insertElementOp.dest(),
127         adaptor.source(), insertElementOp.position());
128     return success();
129   }
130 };
131 
132 } // namespace
133 
134 void mlir::populateVectorToSPIRVPatterns(MLIRContext *context,
135                                          SPIRVTypeConverter &typeConverter,
136                                          OwningRewritePatternList &patterns) {
137   patterns.insert<VectorBroadcastConvert, VectorExtractElementOpConvert,
138                   VectorExtractOpConvert, VectorFmaOpConvert,
139                   VectorInsertOpConvert, VectorInsertElementOpConvert>(
140       typeConverter, context);
141 }
142