1 //===- VectorToSPIRVPass.cpp - Vector to SPIR-V Passes --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements a pass to convert the Vector dialect to the SPIR-V
// dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRVPass.h"
14 
15 #include "../PassDetail.h"
16 #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
18 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
19 #include "mlir/Pass/Pass.h"
20 #include "mlir/Transforms/DialectConversion.h"
21 
22 using namespace mlir;
23 
24 namespace {
25 struct LowerVectorToSPIRVPass
26     : public ConvertVectorToSPIRVBase<LowerVectorToSPIRVPass> {
27   void runOnOperation() override;
28 };
29 } // namespace
30 
31 void LowerVectorToSPIRVPass::runOnOperation() {
32   MLIRContext *context = &getContext();
33   ModuleOp module = getOperation();
34 
35   auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
36   std::unique_ptr<ConversionTarget> target =
37       SPIRVConversionTarget::get(targetAttr);
38 
39   SPIRVTypeConverter typeConverter(targetAttr);
40   RewritePatternSet patterns(context);
41   populateVectorToSPIRVPatterns(typeConverter, patterns);
42 
43   target->addLegalOp<ModuleOp>();
44   target->addLegalOp<FuncOp>();
45 
46   if (failed(applyFullConversion(module, *target, std::move(patterns))))
47     return signalPassFailure();
48 }
49 
50 std::unique_ptr<OperationPass<ModuleOp>>
51 mlir::createConvertVectorToSPIRVPass() {
52   return std::make_unique<LowerVectorToSPIRVPass>();
53 }
54