1 //===- VectorToSPIRV.cpp - Vector to SPIR-V Patterns ----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements patterns to convert Vector dialect to SPIRV dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h" 14 15 #include "../PassDetail.h" 16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 17 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 18 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" 19 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 20 #include "mlir/Dialect/Vector/VectorOps.h" 21 #include "mlir/Transforms/DialectConversion.h" 22 23 using namespace mlir; 24 25 namespace { 26 struct VectorBroadcastConvert final 27 : public SPIRVOpLowering<vector::BroadcastOp> { 28 using SPIRVOpLowering<vector::BroadcastOp>::SPIRVOpLowering; 29 LogicalResult 30 matchAndRewrite(vector::BroadcastOp broadcastOp, ArrayRef<Value> operands, 31 ConversionPatternRewriter &rewriter) const override { 32 if (broadcastOp.source().getType().isa<VectorType>() || 33 !spirv::CompositeType::isValid(broadcastOp.getVectorType())) 34 return failure(); 35 vector::BroadcastOp::Adaptor adaptor(operands); 36 SmallVector<Value, 4> source(broadcastOp.getVectorType().getNumElements(), 37 adaptor.source()); 38 Value construct = rewriter.create<spirv::CompositeConstructOp>( 39 broadcastOp.getLoc(), broadcastOp.getVectorType(), source); 40 rewriter.replaceOp(broadcastOp, construct); 41 return success(); 42 } 43 }; 44 45 struct VectorExtractOpConvert final 46 : public SPIRVOpLowering<vector::ExtractOp> { 47 using SPIRVOpLowering<vector::ExtractOp>::SPIRVOpLowering; 48 LogicalResult 49 matchAndRewrite(vector::ExtractOp extractOp, ArrayRef<Value> operands, 50 ConversionPatternRewriter &rewriter) const override { 51 if (extractOp.getType().isa<VectorType>() || 52 !spirv::CompositeType::isValid(extractOp.getVectorType())) 53 return failure(); 54 vector::ExtractOp::Adaptor adaptor(operands); 55 int32_t id = extractOp.position().begin()->cast<IntegerAttr>().getInt(); 56 Value newExtract = rewriter.create<spirv::CompositeExtractOp>( 57 extractOp.getLoc(), adaptor.vector(), id); 58 rewriter.replaceOp(extractOp, newExtract); 59 return success(); 60 } 61 }; 62 63 struct VectorInsertOpConvert final : public SPIRVOpLowering<vector::InsertOp> { 64 using SPIRVOpLowering<vector::InsertOp>::SPIRVOpLowering; 65 LogicalResult 66 matchAndRewrite(vector::InsertOp insertOp, ArrayRef<Value> operands, 67 ConversionPatternRewriter &rewriter) const override { 68 if (insertOp.getSourceType().isa<VectorType>() || 69 !spirv::CompositeType::isValid(insertOp.getDestVectorType())) 70 return failure(); 71 vector::InsertOp::Adaptor adaptor(operands); 72 int32_t id = insertOp.position().begin()->cast<IntegerAttr>().getInt(); 73 Value newInsert = rewriter.create<spirv::CompositeInsertOp>( 74 insertOp.getLoc(), adaptor.source(), adaptor.dest(), id); 75 rewriter.replaceOp(insertOp, newInsert); 76 return success(); 77 } 78 }; 79 80 struct VectorExtractElementOpConvert final 81 : public SPIRVOpLowering<vector::ExtractElementOp> { 82 using SPIRVOpLowering<vector::ExtractElementOp>::SPIRVOpLowering; 83 LogicalResult 84 matchAndRewrite(vector::ExtractElementOp extractElementOp, 85 ArrayRef<Value> operands, 86 ConversionPatternRewriter &rewriter) const override { 87 if (!spirv::CompositeType::isValid(extractElementOp.getVectorType())) 88 return failure(); 89 vector::ExtractElementOp::Adaptor adaptor(operands); 90 Value newExtractElement = rewriter.create<spirv::VectorExtractDynamicOp>( 91 extractElementOp.getLoc(), extractElementOp.getType(), adaptor.vector(), 92 extractElementOp.position()); 93 rewriter.replaceOp(extractElementOp, newExtractElement); 94 return success(); 95 } 96 }; 97 98 struct VectorInsertElementOpConvert final 99 : public SPIRVOpLowering<vector::InsertElementOp> { 100 using SPIRVOpLowering<vector::InsertElementOp>::SPIRVOpLowering; 101 LogicalResult 102 matchAndRewrite(vector::InsertElementOp insertElementOp, 103 ArrayRef<Value> operands, 104 ConversionPatternRewriter &rewriter) const override { 105 if (!spirv::CompositeType::isValid(insertElementOp.getDestVectorType())) 106 return failure(); 107 vector::InsertElementOp::Adaptor adaptor(operands); 108 Value newInsertElement = rewriter.create<spirv::VectorInsertDynamicOp>( 109 insertElementOp.getLoc(), insertElementOp.getType(), 110 insertElementOp.dest(), adaptor.source(), insertElementOp.position()); 111 rewriter.replaceOp(insertElementOp, newInsertElement); 112 return success(); 113 } 114 }; 115 116 } // namespace 117 118 void mlir::populateVectorToSPIRVPatterns(MLIRContext *context, 119 SPIRVTypeConverter &typeConverter, 120 OwningRewritePatternList &patterns) { 121 patterns.insert<VectorBroadcastConvert, VectorExtractOpConvert, 122 VectorInsertOpConvert, VectorExtractElementOpConvert, 123 VectorInsertElementOpConvert>(context, typeConverter); 124 } 125