1 //===- VectorToSPIRV.cpp - Vector to SPIR-V Patterns ----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements patterns to convert Vector dialect to SPIRV dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h" 14 15 #include "../PassDetail.h" 16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 17 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 18 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" 19 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 20 #include "mlir/Dialect/Vector/VectorOps.h" 21 #include "mlir/Transforms/DialectConversion.h" 22 23 using namespace mlir; 24 25 namespace { 26 struct VectorBroadcastConvert final 27 : public OpConversionPattern<vector::BroadcastOp> { 28 using OpConversionPattern::OpConversionPattern; 29 30 LogicalResult 31 matchAndRewrite(vector::BroadcastOp broadcastOp, ArrayRef<Value> operands, 32 ConversionPatternRewriter &rewriter) const override { 33 if (broadcastOp.source().getType().isa<VectorType>() || 34 !spirv::CompositeType::isValid(broadcastOp.getVectorType())) 35 return failure(); 36 vector::BroadcastOp::Adaptor adaptor(operands); 37 SmallVector<Value, 4> source(broadcastOp.getVectorType().getNumElements(), 38 adaptor.source()); 39 Value construct = rewriter.create<spirv::CompositeConstructOp>( 40 broadcastOp.getLoc(), broadcastOp.getVectorType(), source); 41 rewriter.replaceOp(broadcastOp, construct); 42 return success(); 43 } 44 }; 45 46 struct VectorExtractOpConvert final 47 : public OpConversionPattern<vector::ExtractOp> { 48 using OpConversionPattern::OpConversionPattern; 49 50 LogicalResult 51 matchAndRewrite(vector::ExtractOp extractOp, ArrayRef<Value> operands, 52 ConversionPatternRewriter &rewriter) const override { 53 if (extractOp.getType().isa<VectorType>() || 54 !spirv::CompositeType::isValid(extractOp.getVectorType())) 55 return failure(); 56 vector::ExtractOp::Adaptor adaptor(operands); 57 int32_t id = extractOp.position().begin()->cast<IntegerAttr>().getInt(); 58 Value newExtract = rewriter.create<spirv::CompositeExtractOp>( 59 extractOp.getLoc(), adaptor.vector(), id); 60 rewriter.replaceOp(extractOp, newExtract); 61 return success(); 62 } 63 }; 64 65 struct VectorInsertOpConvert final 66 : public OpConversionPattern<vector::InsertOp> { 67 using OpConversionPattern::OpConversionPattern; 68 69 LogicalResult 70 matchAndRewrite(vector::InsertOp insertOp, ArrayRef<Value> operands, 71 ConversionPatternRewriter &rewriter) const override { 72 if (insertOp.getSourceType().isa<VectorType>() || 73 !spirv::CompositeType::isValid(insertOp.getDestVectorType())) 74 return failure(); 75 vector::InsertOp::Adaptor adaptor(operands); 76 int32_t id = insertOp.position().begin()->cast<IntegerAttr>().getInt(); 77 Value newInsert = rewriter.create<spirv::CompositeInsertOp>( 78 insertOp.getLoc(), adaptor.source(), adaptor.dest(), id); 79 rewriter.replaceOp(insertOp, newInsert); 80 return success(); 81 } 82 }; 83 84 struct VectorExtractElementOpConvert final 85 : public OpConversionPattern<vector::ExtractElementOp> { 86 using OpConversionPattern::OpConversionPattern; 87 88 LogicalResult 89 matchAndRewrite(vector::ExtractElementOp extractElementOp, 90 ArrayRef<Value> operands, 91 ConversionPatternRewriter &rewriter) const override { 92 if (!spirv::CompositeType::isValid(extractElementOp.getVectorType())) 93 return failure(); 94 vector::ExtractElementOp::Adaptor adaptor(operands); 95 Value newExtractElement = rewriter.create<spirv::VectorExtractDynamicOp>( 96 extractElementOp.getLoc(), extractElementOp.getType(), adaptor.vector(), 97 extractElementOp.position()); 98 rewriter.replaceOp(extractElementOp, newExtractElement); 99 return success(); 100 } 101 }; 102 103 struct VectorInsertElementOpConvert final 104 : public OpConversionPattern<vector::InsertElementOp> { 105 using OpConversionPattern::OpConversionPattern; 106 107 LogicalResult 108 matchAndRewrite(vector::InsertElementOp insertElementOp, 109 ArrayRef<Value> operands, 110 ConversionPatternRewriter &rewriter) const override { 111 if (!spirv::CompositeType::isValid(insertElementOp.getDestVectorType())) 112 return failure(); 113 vector::InsertElementOp::Adaptor adaptor(operands); 114 Value newInsertElement = rewriter.create<spirv::VectorInsertDynamicOp>( 115 insertElementOp.getLoc(), insertElementOp.getType(), 116 insertElementOp.dest(), adaptor.source(), insertElementOp.position()); 117 rewriter.replaceOp(insertElementOp, newInsertElement); 118 return success(); 119 } 120 }; 121 122 } // namespace 123 124 void mlir::populateVectorToSPIRVPatterns(MLIRContext *context, 125 SPIRVTypeConverter &typeConverter, 126 OwningRewritePatternList &patterns) { 127 patterns.insert<VectorBroadcastConvert, VectorExtractOpConvert, 128 VectorInsertOpConvert, VectorExtractElementOpConvert, 129 VectorInsertElementOpConvert>(typeConverter, context); 130 } 131