1 //===- GPUToSPIRV.cpp - MLIR SPIR-V lowering passes -----------------------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 //
// This file implements a pass to convert a kernel function in the GPU dialect
// into a spv.module operation.
20 //
21 //===----------------------------------------------------------------------===//
22 #include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h"
23 #include "mlir/Dialect/GPU/GPUDialect.h"
24 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
25 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
26 #include "mlir/Pass/Pass.h"
27 
28 using namespace mlir;
29 
30 namespace {
31 
32 /// Pattern lowering GPU block/thread size/id to loading SPIR-V invocation
33 /// builin variables.
34 template <typename OpTy, spirv::BuiltIn builtin>
35 class LaunchConfigConversion : public SPIRVOpLowering<OpTy> {
36 public:
37   using SPIRVOpLowering<OpTy>::SPIRVOpLowering;
38 
39   PatternMatchResult
40   matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
41                   ConversionPatternRewriter &rewriter) const override;
42 };
43 
44 /// Pattern to convert a kernel function in GPU dialect (a FuncOp with the
45 /// attribute gpu.kernel) within a spv.module.
46 class KernelFnConversion final : public SPIRVOpLowering<FuncOp> {
47 public:
48   using SPIRVOpLowering<FuncOp>::SPIRVOpLowering;
49 
50   PatternMatchResult
51   matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
52                   ConversionPatternRewriter &rewriter) const override;
53 };
54 } // namespace
55 
56 template <typename OpTy, spirv::BuiltIn builtin>
57 PatternMatchResult LaunchConfigConversion<OpTy, builtin>::matchAndRewrite(
58     Operation *op, ArrayRef<Value *> operands,
59     ConversionPatternRewriter &rewriter) const {
60   auto dimAttr = op->getAttrOfType<StringAttr>("dimension");
61   if (!dimAttr) {
62     return this->matchFailure();
63   }
64   int32_t index = 0;
65   if (dimAttr.getValue() == "x") {
66     index = 0;
67   } else if (dimAttr.getValue() == "y") {
68     index = 1;
69   } else if (dimAttr.getValue() == "z") {
70     index = 2;
71   } else {
72     return this->matchFailure();
73   }
74 
75   // SPIR-V invocation builtin variables are a vector of type <3xi32>
76   auto spirvBuiltin = this->loadFromBuiltinVariable(op, builtin, rewriter);
77   rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
78       op, rewriter.getIntegerType(32), spirvBuiltin,
79       rewriter.getI32ArrayAttr({index}));
80   return this->matchSuccess();
81 }
82 
83 PatternMatchResult
84 KernelFnConversion::matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
85                                     ConversionPatternRewriter &rewriter) const {
86   auto funcOp = cast<FuncOp>(op);
87   FuncOp newFuncOp;
88   if (!gpu::GPUDialect::isKernel(funcOp)) {
89     return succeeded(lowerFunction(funcOp, operands, &typeConverter, rewriter,
90                                    newFuncOp))
91                ? matchSuccess()
92                : matchFailure();
93   }
94 
95   if (failed(lowerAsEntryFunction(funcOp, operands, &typeConverter, rewriter,
96                                   newFuncOp))) {
97     return matchFailure();
98   }
99   newFuncOp.getOperation()->removeAttr(Identifier::get(
100       gpu::GPUDialect::getKernelFuncAttrName(), op->getContext()));
101   return matchSuccess();
102 }
103 
104 namespace {
105 /// Pass to lower GPU Dialect to SPIR-V. The pass only converts those functions
106 /// that have the "gpu.kernel" attribute, i.e. those functions that are
107 /// referenced in gpu::LaunchKernelOp operations. For each such function
108 ///
109 /// 1) Create a spirv::ModuleOp, and clone the function into spirv::ModuleOp
110 /// (the original function is still needed by the gpu::LaunchKernelOp, so cannot
111 /// replace it).
112 ///
113 /// 2) Lower the body of the spirv::ModuleOp.
114 class GPUToSPIRVPass : public ModulePass<GPUToSPIRVPass> {
115   void runOnModule() override;
116 };
117 } // namespace
118 
119 void GPUToSPIRVPass::runOnModule() {
120   auto context = &getContext();
121   auto module = getModule();
122 
123   SmallVector<Operation *, 4> spirvModules;
124   for (auto funcOp : module.getOps<FuncOp>()) {
125     if (gpu::GPUDialect::isKernel(funcOp)) {
126       OpBuilder builder(module.getBodyRegion());
127       // Create a new spirv::ModuleOp for this function, and clone the
128       // function into it.
129       // TODO : Generalize this to account for different extensions,
130       // capabilities, extended_instruction_sets, other addressing models
131       // and memory models.
132       auto spvModule = builder.create<spirv::ModuleOp>(
133           funcOp.getLoc(),
134           builder.getI32IntegerAttr(
135               static_cast<int32_t>(spirv::AddressingModel::Logical)),
136           builder.getI32IntegerAttr(
137               static_cast<int32_t>(spirv::MemoryModel::VulkanKHR)));
138       OpBuilder moduleBuilder(spvModule.getOperation()->getRegion(0));
139       moduleBuilder.clone(*funcOp.getOperation());
140       spirvModules.push_back(spvModule);
141     }
142   }
143 
144   /// Dialect conversion to lower the functions with the spirv::ModuleOps.
145   SPIRVBasicTypeConverter basicTypeConverter(context);
146   SPIRVTypeConverter typeConverter(&basicTypeConverter);
147   OwningRewritePatternList patterns;
148   patterns.insert<
149       KernelFnConversion,
150       LaunchConfigConversion<gpu::BlockDim, spirv::BuiltIn::WorkgroupSize>,
151       LaunchConfigConversion<gpu::BlockId, spirv::BuiltIn::WorkgroupId>,
152       LaunchConfigConversion<gpu::GridDim, spirv::BuiltIn::NumWorkgroups>,
153       LaunchConfigConversion<gpu::ThreadId, spirv::BuiltIn::LocalInvocationId>>(
154       context, typeConverter);
155   populateStandardToSPIRVPatterns(context, patterns);
156 
157   ConversionTarget target(*context);
158   target.addLegalDialect<spirv::SPIRVDialect>();
159   target.addDynamicallyLegalOp<FuncOp>([&](FuncOp Op) {
160     return basicTypeConverter.isSignatureLegal(Op.getType());
161   });
162 
163   if (failed(applyFullConversion(spirvModules, target, patterns,
164                                  &typeConverter))) {
165     return signalPassFailure();
166   }
167 }
168 
169 ModulePassBase *createGPUToSPIRVPass() { return new GPUToSPIRVPass(); }
170 
171 static PassRegistration<GPUToSPIRVPass>
172     pass("convert-gpu-to-spirv", "Convert GPU dialect to SPIR-V dialect");
173