//===- GPUToSPIRV.cpp - MLIR SPIR-V lowering passes -----------------------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 //
// This file implements a pass to convert kernel functions in the GPU dialect
// into spv.module operations.
20 //
21 //===----------------------------------------------------------------------===//
22 #include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h"
23 #include "mlir/Dialect/GPU/GPUDialect.h"
24 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
25 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
26 #include "mlir/Pass/Pass.h"
27 
28 using namespace mlir;
29 
30 namespace {
31 
32 /// Pattern to convert a kernel function in GPU dialect (a FuncOp with the
33 /// attribute gpu.kernel) within a spv.module.
34 class KernelFnConversion final : public SPIRVOpLowering<FuncOp> {
35 public:
36   using SPIRVOpLowering<FuncOp>::SPIRVOpLowering;
37 
38   PatternMatchResult
39   matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
40                   ConversionPatternRewriter &rewriter) const override;
41 };
42 } // namespace
43 
44 PatternMatchResult
45 KernelFnConversion::matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
46                                     ConversionPatternRewriter &rewriter) const {
47   auto funcOp = cast<FuncOp>(op);
48   FuncOp newFuncOp;
49   if (!gpu::GPUDialect::isKernel(funcOp)) {
50     return succeeded(lowerFunction(funcOp, operands, &typeConverter, rewriter,
51                                    newFuncOp))
52                ? matchSuccess()
53                : matchFailure();
54   }
55 
56   if (failed(lowerAsEntryFunction(funcOp, operands, &typeConverter, rewriter,
57                                   newFuncOp))) {
58     return matchFailure();
59   }
60   newFuncOp.getOperation()->removeAttr(Identifier::get(
61       gpu::GPUDialect::getKernelFuncAttrName(), op->getContext()));
62   return matchSuccess();
63 }
64 
65 namespace {
66 /// Pass to lower GPU Dialect to SPIR-V. The pass only converts those functions
67 /// that have the "gpu.kernel" attribute, i.e. those functions that are
68 /// referenced in gpu::LaunchKernelOp operations. For each such function
69 ///
70 /// 1) Create a spirv::ModuleOp, and clone the function into spirv::ModuleOp
71 /// (the original function is still needed by the gpu::LaunchKernelOp, so cannot
72 /// replace it).
73 ///
74 /// 2) Lower the body of the spirv::ModuleOp.
75 class GPUToSPIRVPass : public ModulePass<GPUToSPIRVPass> {
76   void runOnModule() override;
77 };
78 } // namespace
79 
80 void GPUToSPIRVPass::runOnModule() {
81   auto context = &getContext();
82   auto module = getModule();
83 
84   SmallVector<Operation *, 4> spirvModules;
85   for (auto funcOp : module.getOps<FuncOp>()) {
86     if (gpu::GPUDialect::isKernel(funcOp)) {
87       OpBuilder builder(module.getBodyRegion());
88       // Create a new spirv::ModuleOp for this function, and clone the
89       // function into it.
90       // TODO : Generalize this to account for different extensions,
91       // capabilities, extended_instruction_sets, other addressing models
92       // and memory models.
93       auto spvModule = builder.create<spirv::ModuleOp>(
94           funcOp.getLoc(),
95           builder.getI32IntegerAttr(
96               static_cast<int32_t>(spirv::AddressingModel::Logical)),
97           builder.getI32IntegerAttr(
98               static_cast<int32_t>(spirv::MemoryModel::VulkanKHR)));
99       OpBuilder moduleBuilder(spvModule.getOperation()->getRegion(0));
100       moduleBuilder.clone(*funcOp.getOperation());
101       spirvModules.push_back(spvModule);
102     }
103   }
104 
105   /// Dialect conversion to lower the functions with the spirv::ModuleOps.
106   SPIRVBasicTypeConverter basicTypeConverter(context);
107   SPIRVTypeConverter typeConverter(&basicTypeConverter);
108   OwningRewritePatternList patterns;
109   patterns.insert<KernelFnConversion>(context, typeConverter);
110   populateStandardToSPIRVPatterns(context, patterns);
111 
112   ConversionTarget target(*context);
113   target.addLegalDialect<spirv::SPIRVDialect>();
114   target.addDynamicallyLegalOp<FuncOp>([&](FuncOp Op) {
115     return basicTypeConverter.isSignatureLegal(Op.getType());
116   });
117 
118   if (failed(applyFullConversion(spirvModules, target, patterns,
119                                  &typeConverter))) {
120     return signalPassFailure();
121   }
122 }
123 
124 ModulePassBase *createGPUToSPIRVPass() { return new GPUToSPIRVPass(); }
125 
126 static PassRegistration<GPUToSPIRVPass>
127     pass("convert-gpu-to-spirv", "Convert GPU dialect to SPIR-V dialect");
128