1 //===------- VectorToSPIRV.cpp - Vector to SPIRV lowering passes ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass to generate SPIRV operations for Vector 10 // operations. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "../PassDetail.h" 15 #include "mlir/Conversion/VectorToSPIRV/ConvertVectorToSPIRV.h" 16 #include "mlir/Conversion/VectorToSPIRV/ConvertVectorToSPIRVPass.h" 17 #include "mlir/Dialect/SPIRV/SPIRVDialect.h" 18 #include "mlir/Dialect/SPIRV/SPIRVLowering.h" 19 #include "mlir/Dialect/SPIRV/SPIRVOps.h" 20 #include "mlir/Dialect/SPIRV/SPIRVTypes.h" 21 #include "mlir/Dialect/Vector/VectorOps.h" 22 #include "mlir/Pass/Pass.h" 23 #include "mlir/Transforms/DialectConversion.h" 24 25 using namespace mlir; 26 27 namespace { 28 struct VectorBroadcastConvert final 29 : public SPIRVOpLowering<vector::BroadcastOp> { 30 using SPIRVOpLowering<vector::BroadcastOp>::SPIRVOpLowering; 31 LogicalResult 32 matchAndRewrite(vector::BroadcastOp broadcastOp, ArrayRef<Value> operands, 33 ConversionPatternRewriter &rewriter) const override { 34 if (broadcastOp.source().getType().isa<VectorType>() || 35 !spirv::CompositeType::isValid(broadcastOp.getVectorType())) 36 return failure(); 37 vector::BroadcastOp::Adaptor adaptor(operands); 38 SmallVector<Value, 4> source(broadcastOp.getVectorType().getNumElements(), 39 adaptor.source()); 40 Value construct = rewriter.create<spirv::CompositeConstructOp>( 41 broadcastOp.getLoc(), broadcastOp.getVectorType(), source); 42 rewriter.replaceOp(broadcastOp, construct); 43 return success(); 44 } 45 }; 46 47 struct VectorExtractOpConvert final 48 : public SPIRVOpLowering<vector::ExtractOp> { 49 using SPIRVOpLowering<vector::ExtractOp>::SPIRVOpLowering; 50 LogicalResult 51 matchAndRewrite(vector::ExtractOp extractOp, ArrayRef<Value> operands, 52 ConversionPatternRewriter &rewriter) const override { 53 if (extractOp.getType().isa<VectorType>() || 54 !spirv::CompositeType::isValid(extractOp.getVectorType())) 55 return failure(); 56 vector::ExtractOp::Adaptor adaptor(operands); 57 int32_t id = extractOp.position().begin()->cast<IntegerAttr>().getInt(); 58 Value newExtract = rewriter.create<spirv::CompositeExtractOp>( 59 extractOp.getLoc(), adaptor.vector(), id); 60 rewriter.replaceOp(extractOp, newExtract); 61 return success(); 62 } 63 }; 64 65 struct VectorInsertOpConvert final : public SPIRVOpLowering<vector::InsertOp> { 66 using SPIRVOpLowering<vector::InsertOp>::SPIRVOpLowering; 67 LogicalResult 68 matchAndRewrite(vector::InsertOp insertOp, ArrayRef<Value> operands, 69 ConversionPatternRewriter &rewriter) const override { 70 if (insertOp.getSourceType().isa<VectorType>() || 71 !spirv::CompositeType::isValid(insertOp.getDestVectorType())) 72 return failure(); 73 vector::InsertOp::Adaptor adaptor(operands); 74 int32_t id = insertOp.position().begin()->cast<IntegerAttr>().getInt(); 75 Value newInsert = rewriter.create<spirv::CompositeInsertOp>( 76 insertOp.getLoc(), adaptor.source(), adaptor.dest(), id); 77 rewriter.replaceOp(insertOp, newInsert); 78 return success(); 79 } 80 }; 81 82 struct VectorExtractElementOpConvert final 83 : public SPIRVOpLowering<vector::ExtractElementOp> { 84 using SPIRVOpLowering<vector::ExtractElementOp>::SPIRVOpLowering; 85 LogicalResult 86 matchAndRewrite(vector::ExtractElementOp extractElementOp, 87 ArrayRef<Value> operands, 88 ConversionPatternRewriter &rewriter) const override { 89 if (!spirv::CompositeType::isValid(extractElementOp.getVectorType())) 90 return failure(); 91 vector::ExtractElementOp::Adaptor adaptor(operands); 92 Value newExtractElement = rewriter.create<spirv::VectorExtractDynamicOp>( 93 extractElementOp.getLoc(), extractElementOp.getType(), adaptor.vector(), 94 extractElementOp.position()); 95 rewriter.replaceOp(extractElementOp, newExtractElement); 96 return success(); 97 } 98 }; 99 100 struct VectorInsertElementOpConvert final 101 : public SPIRVOpLowering<vector::InsertElementOp> { 102 using SPIRVOpLowering<vector::InsertElementOp>::SPIRVOpLowering; 103 LogicalResult 104 matchAndRewrite(vector::InsertElementOp insertElementOp, 105 ArrayRef<Value> operands, 106 ConversionPatternRewriter &rewriter) const override { 107 if (!spirv::CompositeType::isValid(insertElementOp.getDestVectorType())) 108 return failure(); 109 vector::InsertElementOp::Adaptor adaptor(operands); 110 Value newInsertElement = rewriter.create<spirv::VectorInsertDynamicOp>( 111 insertElementOp.getLoc(), insertElementOp.getType(), 112 insertElementOp.dest(), adaptor.source(), insertElementOp.position()); 113 rewriter.replaceOp(insertElementOp, newInsertElement); 114 return success(); 115 } 116 }; 117 118 } // namespace 119 120 void mlir::populateVectorToSPIRVPatterns(MLIRContext *context, 121 SPIRVTypeConverter &typeConverter, 122 OwningRewritePatternList &patterns) { 123 patterns.insert<VectorBroadcastConvert, VectorExtractOpConvert, 124 VectorInsertOpConvert, VectorExtractElementOpConvert, 125 VectorInsertElementOpConvert>(context, typeConverter); 126 } 127 128 namespace { 129 struct LowerVectorToSPIRVPass 130 : public ConvertVectorToSPIRVBase<LowerVectorToSPIRVPass> { 131 void runOnOperation() override; 132 }; 133 } // namespace 134 135 void LowerVectorToSPIRVPass::runOnOperation() { 136 MLIRContext *context = &getContext(); 137 ModuleOp module = getOperation(); 138 139 auto targetAttr = spirv::lookupTargetEnvOrDefault(module); 140 std::unique_ptr<ConversionTarget> target = 141 spirv::SPIRVConversionTarget::get(targetAttr); 142 143 SPIRVTypeConverter typeConverter(targetAttr); 144 OwningRewritePatternList patterns; 145 populateVectorToSPIRVPatterns(context, typeConverter, patterns); 146 147 target->addLegalOp<ModuleOp, ModuleTerminatorOp>(); 148 target->addLegalOp<FuncOp>(); 149 150 if (failed(applyFullConversion(module, *target, std::move(patterns)))) 151 return signalPassFailure(); 152 } 153 154 std::unique_ptr<OperationPass<ModuleOp>> 155 mlir::createConvertVectorToSPIRVPass() { 156 return std::make_unique<LowerVectorToSPIRVPass>(); 157 } 158