1 //===------- VectorToSPIRV.cpp - Vector to SPIRV lowering passes ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass to generate SPIRV operations for Vector 10 // operations. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "../PassDetail.h" 15 #include "mlir/Conversion/VectorToSPIRV/ConvertVectorToSPIRV.h" 16 #include "mlir/Conversion/VectorToSPIRV/ConvertVectorToSPIRVPass.h" 17 #include "mlir/Dialect/SPIRV/SPIRVDialect.h" 18 #include "mlir/Dialect/SPIRV/SPIRVLowering.h" 19 #include "mlir/Dialect/SPIRV/SPIRVOps.h" 20 #include "mlir/Dialect/SPIRV/SPIRVTypes.h" 21 #include "mlir/Dialect/Vector/VectorOps.h" 22 #include "mlir/Pass/Pass.h" 23 #include "mlir/Transforms/DialectConversion.h" 24 25 using namespace mlir; 26 27 namespace { 28 struct VectorBroadcastConvert final 29 : public SPIRVOpLowering<vector::BroadcastOp> { 30 using SPIRVOpLowering<vector::BroadcastOp>::SPIRVOpLowering; 31 LogicalResult 32 matchAndRewrite(vector::BroadcastOp broadcastOp, ArrayRef<Value> operands, 33 ConversionPatternRewriter &rewriter) const override { 34 if (broadcastOp.source().getType().isa<VectorType>() || 35 !spirv::CompositeType::isValid(broadcastOp.getVectorType())) 36 return failure(); 37 vector::BroadcastOp::Adaptor adaptor(operands); 38 SmallVector<Value, 4> source(broadcastOp.getVectorType().getNumElements(), 39 adaptor.source()); 40 Value construct = rewriter.create<spirv::CompositeConstructOp>( 41 broadcastOp.getLoc(), broadcastOp.getVectorType(), source); 42 rewriter.replaceOp(broadcastOp, construct); 43 return success(); 44 } 45 }; 46 47 struct VectorExtractOpConvert final 48 : public SPIRVOpLowering<vector::ExtractOp> { 49 using SPIRVOpLowering<vector::ExtractOp>::SPIRVOpLowering; 50 LogicalResult 51 matchAndRewrite(vector::ExtractOp extractOp, ArrayRef<Value> operands, 52 ConversionPatternRewriter &rewriter) const override { 53 if (extractOp.getType().isa<VectorType>() || 54 !spirv::CompositeType::isValid(extractOp.getVectorType())) 55 return failure(); 56 vector::ExtractOp::Adaptor adaptor(operands); 57 int32_t id = extractOp.position().begin()->cast<IntegerAttr>().getInt(); 58 Value newExtract = rewriter.create<spirv::CompositeExtractOp>( 59 extractOp.getLoc(), adaptor.vector(), id); 60 rewriter.replaceOp(extractOp, newExtract); 61 return success(); 62 } 63 }; 64 65 struct VectorInsertOpConvert final : public SPIRVOpLowering<vector::InsertOp> { 66 using SPIRVOpLowering<vector::InsertOp>::SPIRVOpLowering; 67 LogicalResult 68 matchAndRewrite(vector::InsertOp insertOp, ArrayRef<Value> operands, 69 ConversionPatternRewriter &rewriter) const override { 70 if (insertOp.getSourceType().isa<VectorType>() || 71 !spirv::CompositeType::isValid(insertOp.getDestVectorType())) 72 return failure(); 73 vector::InsertOp::Adaptor adaptor(operands); 74 int32_t id = insertOp.position().begin()->cast<IntegerAttr>().getInt(); 75 Value newInsert = rewriter.create<spirv::CompositeInsertOp>( 76 insertOp.getLoc(), adaptor.source(), adaptor.dest(), id); 77 rewriter.replaceOp(insertOp, newInsert); 78 return success(); 79 } 80 }; 81 } // namespace 82 83 void mlir::populateVectorToSPIRVPatterns(MLIRContext *context, 84 SPIRVTypeConverter &typeConverter, 85 OwningRewritePatternList &patterns) { 86 patterns.insert<VectorBroadcastConvert, VectorExtractOpConvert, 87 VectorInsertOpConvert>(context, typeConverter); 88 } 89 90 namespace { 91 struct LowerVectorToSPIRVPass 92 : public ConvertVectorToSPIRVBase<LowerVectorToSPIRVPass> { 93 void runOnOperation() override; 94 }; 95 } // namespace 96 97 void LowerVectorToSPIRVPass::runOnOperation() { 98 MLIRContext *context = &getContext(); 99 ModuleOp module = getOperation(); 100 101 auto targetAttr = spirv::lookupTargetEnvOrDefault(module); 102 std::unique_ptr<ConversionTarget> target = 103 spirv::SPIRVConversionTarget::get(targetAttr); 104 105 SPIRVTypeConverter typeConverter(targetAttr); 106 OwningRewritePatternList patterns; 107 populateVectorToSPIRVPatterns(context, typeConverter, patterns); 108 109 target->addLegalOp<ModuleOp, ModuleTerminatorOp>(); 110 target->addLegalOp<FuncOp>(); 111 112 if (failed(applyFullConversion(module, *target, std::move(patterns)))) 113 return signalPassFailure(); 114 } 115 116 std::unique_ptr<OperationPass<ModuleOp>> 117 mlir::createConvertVectorToSPIRVPass() { 118 return std::make_unique<LowerVectorToSPIRVPass>(); 119 } 120