1 //===- VectorToSPIRV.cpp - Vector to SPIR-V Patterns ----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements patterns to convert Vector dialect to SPIRV dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h" 14 15 #include "../PassDetail.h" 16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 17 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 18 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" 19 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 20 #include "mlir/Dialect/Vector/VectorOps.h" 21 #include "mlir/Transforms/DialectConversion.h" 22 23 using namespace mlir; 24 25 namespace { 26 struct VectorBroadcastConvert final 27 : public OpConversionPattern<vector::BroadcastOp> { 28 using OpConversionPattern::OpConversionPattern; 29 30 LogicalResult 31 matchAndRewrite(vector::BroadcastOp broadcastOp, ArrayRef<Value> operands, 32 ConversionPatternRewriter &rewriter) const override { 33 if (broadcastOp.source().getType().isa<VectorType>() || 34 !spirv::CompositeType::isValid(broadcastOp.getVectorType())) 35 return failure(); 36 vector::BroadcastOp::Adaptor adaptor(operands); 37 SmallVector<Value, 4> source(broadcastOp.getVectorType().getNumElements(), 38 adaptor.source()); 39 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>( 40 broadcastOp, broadcastOp.getVectorType(), source); 41 return success(); 42 } 43 }; 44 45 struct VectorExtractOpConvert final 46 : public OpConversionPattern<vector::ExtractOp> { 47 using OpConversionPattern::OpConversionPattern; 48 49 LogicalResult 50 matchAndRewrite(vector::ExtractOp extractOp, ArrayRef<Value> operands, 51 ConversionPatternRewriter &rewriter) const override { 52 if (extractOp.getType().isa<VectorType>() || 53 !spirv::CompositeType::isValid(extractOp.getVectorType())) 54 return failure(); 55 vector::ExtractOp::Adaptor adaptor(operands); 56 int32_t id = extractOp.position().begin()->cast<IntegerAttr>().getInt(); 57 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>( 58 extractOp, adaptor.vector(), id); 59 return success(); 60 } 61 }; 62 63 struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> { 64 using OpConversionPattern::OpConversionPattern; 65 66 LogicalResult 67 matchAndRewrite(vector::FMAOp fmaOp, ArrayRef<Value> operands, 68 ConversionPatternRewriter &rewriter) const override { 69 if (!spirv::CompositeType::isValid(fmaOp.getVectorType())) 70 return failure(); 71 vector::FMAOp::Adaptor adaptor(operands); 72 rewriter.replaceOpWithNewOp<spirv::GLSLFmaOp>( 73 fmaOp, fmaOp.getType(), adaptor.lhs(), adaptor.rhs(), adaptor.acc()); 74 return success(); 75 } 76 }; 77 78 struct VectorInsertOpConvert final 79 : public OpConversionPattern<vector::InsertOp> { 80 using OpConversionPattern::OpConversionPattern; 81 82 LogicalResult 83 matchAndRewrite(vector::InsertOp insertOp, ArrayRef<Value> operands, 84 ConversionPatternRewriter &rewriter) const override { 85 if (insertOp.getSourceType().isa<VectorType>() || 86 !spirv::CompositeType::isValid(insertOp.getDestVectorType())) 87 return failure(); 88 vector::InsertOp::Adaptor adaptor(operands); 89 int32_t id = insertOp.position().begin()->cast<IntegerAttr>().getInt(); 90 rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>( 91 insertOp, adaptor.source(), adaptor.dest(), id); 92 return success(); 93 } 94 }; 95 96 struct VectorExtractElementOpConvert final 97 : public OpConversionPattern<vector::ExtractElementOp> { 98 using OpConversionPattern::OpConversionPattern; 99 100 LogicalResult 101 matchAndRewrite(vector::ExtractElementOp extractElementOp, 102 ArrayRef<Value> operands, 103 ConversionPatternRewriter &rewriter) const override { 104 if (!spirv::CompositeType::isValid(extractElementOp.getVectorType())) 105 return failure(); 106 vector::ExtractElementOp::Adaptor adaptor(operands); 107 rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>( 108 extractElementOp, extractElementOp.getType(), adaptor.vector(), 109 extractElementOp.position()); 110 return success(); 111 } 112 }; 113 114 struct VectorInsertElementOpConvert final 115 : public OpConversionPattern<vector::InsertElementOp> { 116 using OpConversionPattern::OpConversionPattern; 117 118 LogicalResult 119 matchAndRewrite(vector::InsertElementOp insertElementOp, 120 ArrayRef<Value> operands, 121 ConversionPatternRewriter &rewriter) const override { 122 if (!spirv::CompositeType::isValid(insertElementOp.getDestVectorType())) 123 return failure(); 124 vector::InsertElementOp::Adaptor adaptor(operands); 125 rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>( 126 insertElementOp, insertElementOp.getType(), insertElementOp.dest(), 127 adaptor.source(), insertElementOp.position()); 128 return success(); 129 } 130 }; 131 132 } // namespace 133 134 void mlir::populateVectorToSPIRVPatterns(MLIRContext *context, 135 SPIRVTypeConverter &typeConverter, 136 OwningRewritePatternList &patterns) { 137 patterns.insert<VectorBroadcastConvert, VectorExtractElementOpConvert, 138 VectorExtractOpConvert, VectorFmaOpConvert, 139 VectorInsertOpConvert, VectorInsertElementOpConvert>( 140 typeConverter, context); 141 } 142