1 //===- SCFToSPIRVPass.cpp - SCF to SPIR-V Passes --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to convert SCF dialect into SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Conversion/SCFToSPIRV/SCFToSPIRVPass.h"
14 
15 #include "../PassDetail.h"
16 #include "mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h"
17 #include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h"
18 #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h"
19 #include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"
20 #include "mlir/Dialect/SCF/IR/SCF.h"
21 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
22 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
23 
24 using namespace mlir;
25 
26 namespace {
27 struct SCFToSPIRVPass : public SCFToSPIRVBase<SCFToSPIRVPass> {
28   void runOnOperation() override;
29 };
30 } // namespace
31 
runOnOperation()32 void SCFToSPIRVPass::runOnOperation() {
33   MLIRContext *context = &getContext();
34   ModuleOp module = getOperation();
35 
36   auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
37   std::unique_ptr<ConversionTarget> target =
38       SPIRVConversionTarget::get(targetAttr);
39 
40   SPIRVTypeConverter typeConverter(targetAttr);
41   ScfToSPIRVContext scfContext;
42   RewritePatternSet patterns(context);
43   populateSCFToSPIRVPatterns(typeConverter, scfContext, patterns);
44 
45   // TODO: Change SPIR-V conversion to be progressive and remove the following
46   // patterns.
47   mlir::arith::populateArithmeticToSPIRVPatterns(typeConverter, patterns);
48   populateFuncToSPIRVPatterns(typeConverter, patterns);
49   populateMemRefToSPIRVPatterns(typeConverter, patterns);
50   populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
51 
52   if (failed(applyPartialConversion(module, *target, std::move(patterns))))
53     return signalPassFailure();
54 }
55 
createConvertSCFToSPIRVPass()56 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertSCFToSPIRVPass() {
57   return std::make_unique<SCFToSPIRVPass>();
58 }
59