1 //===- VectorToSPIRV.cpp - Vector to SPIR-V Patterns ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert Vector dialect to SPIRV dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
14 
15 #include "../PassDetail.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
18 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
19 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
20 #include "mlir/Dialect/Vector/VectorOps.h"
21 #include "mlir/Transforms/DialectConversion.h"
22 
23 using namespace mlir;
24 
25 namespace {
26 struct VectorBroadcastConvert final
27     : public SPIRVOpLowering<vector::BroadcastOp> {
28   using SPIRVOpLowering<vector::BroadcastOp>::SPIRVOpLowering;
29   LogicalResult
30   matchAndRewrite(vector::BroadcastOp broadcastOp, ArrayRef<Value> operands,
31                   ConversionPatternRewriter &rewriter) const override {
32     if (broadcastOp.source().getType().isa<VectorType>() ||
33         !spirv::CompositeType::isValid(broadcastOp.getVectorType()))
34       return failure();
35     vector::BroadcastOp::Adaptor adaptor(operands);
36     SmallVector<Value, 4> source(broadcastOp.getVectorType().getNumElements(),
37                                  adaptor.source());
38     Value construct = rewriter.create<spirv::CompositeConstructOp>(
39         broadcastOp.getLoc(), broadcastOp.getVectorType(), source);
40     rewriter.replaceOp(broadcastOp, construct);
41     return success();
42   }
43 };
44 
45 struct VectorExtractOpConvert final
46     : public SPIRVOpLowering<vector::ExtractOp> {
47   using SPIRVOpLowering<vector::ExtractOp>::SPIRVOpLowering;
48   LogicalResult
49   matchAndRewrite(vector::ExtractOp extractOp, ArrayRef<Value> operands,
50                   ConversionPatternRewriter &rewriter) const override {
51     if (extractOp.getType().isa<VectorType>() ||
52         !spirv::CompositeType::isValid(extractOp.getVectorType()))
53       return failure();
54     vector::ExtractOp::Adaptor adaptor(operands);
55     int32_t id = extractOp.position().begin()->cast<IntegerAttr>().getInt();
56     Value newExtract = rewriter.create<spirv::CompositeExtractOp>(
57         extractOp.getLoc(), adaptor.vector(), id);
58     rewriter.replaceOp(extractOp, newExtract);
59     return success();
60   }
61 };
62 
63 struct VectorInsertOpConvert final : public SPIRVOpLowering<vector::InsertOp> {
64   using SPIRVOpLowering<vector::InsertOp>::SPIRVOpLowering;
65   LogicalResult
66   matchAndRewrite(vector::InsertOp insertOp, ArrayRef<Value> operands,
67                   ConversionPatternRewriter &rewriter) const override {
68     if (insertOp.getSourceType().isa<VectorType>() ||
69         !spirv::CompositeType::isValid(insertOp.getDestVectorType()))
70       return failure();
71     vector::InsertOp::Adaptor adaptor(operands);
72     int32_t id = insertOp.position().begin()->cast<IntegerAttr>().getInt();
73     Value newInsert = rewriter.create<spirv::CompositeInsertOp>(
74         insertOp.getLoc(), adaptor.source(), adaptor.dest(), id);
75     rewriter.replaceOp(insertOp, newInsert);
76     return success();
77   }
78 };
79 
80 struct VectorExtractElementOpConvert final
81     : public SPIRVOpLowering<vector::ExtractElementOp> {
82   using SPIRVOpLowering<vector::ExtractElementOp>::SPIRVOpLowering;
83   LogicalResult
84   matchAndRewrite(vector::ExtractElementOp extractElementOp,
85                   ArrayRef<Value> operands,
86                   ConversionPatternRewriter &rewriter) const override {
87     if (!spirv::CompositeType::isValid(extractElementOp.getVectorType()))
88       return failure();
89     vector::ExtractElementOp::Adaptor adaptor(operands);
90     Value newExtractElement = rewriter.create<spirv::VectorExtractDynamicOp>(
91         extractElementOp.getLoc(), extractElementOp.getType(), adaptor.vector(),
92         extractElementOp.position());
93     rewriter.replaceOp(extractElementOp, newExtractElement);
94     return success();
95   }
96 };
97 
98 struct VectorInsertElementOpConvert final
99     : public SPIRVOpLowering<vector::InsertElementOp> {
100   using SPIRVOpLowering<vector::InsertElementOp>::SPIRVOpLowering;
101   LogicalResult
102   matchAndRewrite(vector::InsertElementOp insertElementOp,
103                   ArrayRef<Value> operands,
104                   ConversionPatternRewriter &rewriter) const override {
105     if (!spirv::CompositeType::isValid(insertElementOp.getDestVectorType()))
106       return failure();
107     vector::InsertElementOp::Adaptor adaptor(operands);
108     Value newInsertElement = rewriter.create<spirv::VectorInsertDynamicOp>(
109         insertElementOp.getLoc(), insertElementOp.getType(),
110         insertElementOp.dest(), adaptor.source(), insertElementOp.position());
111     rewriter.replaceOp(insertElementOp, newInsertElement);
112     return success();
113   }
114 };
115 
116 } // namespace
117 
118 void mlir::populateVectorToSPIRVPatterns(MLIRContext *context,
119                                          SPIRVTypeConverter &typeConverter,
120                                          OwningRewritePatternList &patterns) {
121   patterns.insert<VectorBroadcastConvert, VectorExtractOpConvert,
122                   VectorInsertOpConvert, VectorExtractElementOpConvert,
123                   VectorInsertElementOpConvert>(context, typeConverter);
124 }
125