1 //===- GPUToSPIRV.cp - MLIR SPIR-V lowering passes ------------------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // This file implements a pass to convert a kernel function in the GPU Dialect 19 // into a spv.module operation 20 // 21 //===----------------------------------------------------------------------===// 22 #include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h" 23 #include "mlir/Dialect/GPU/GPUDialect.h" 24 #include "mlir/Dialect/SPIRV/SPIRVDialect.h" 25 #include "mlir/Dialect/SPIRV/SPIRVOps.h" 26 #include "mlir/Pass/Pass.h" 27 28 using namespace mlir; 29 30 namespace { 31 32 /// Pattern to convert a kernel function in GPU dialect (a FuncOp with the 33 /// attribute gpu.kernel) within a spv.module. 34 class KernelFnConversion final : public SPIRVOpLowering<FuncOp> { 35 public: 36 using SPIRVOpLowering<FuncOp>::SPIRVOpLowering; 37 38 PatternMatchResult 39 matchAndRewrite(Operation *op, ArrayRef<Value *> operands, 40 ConversionPatternRewriter &rewriter) const override; 41 }; 42 } // namespace 43 44 PatternMatchResult 45 KernelFnConversion::matchAndRewrite(Operation *op, ArrayRef<Value *> operands, 46 ConversionPatternRewriter &rewriter) const { 47 auto funcOp = cast<FuncOp>(op); 48 FuncOp newFuncOp; 49 if (!gpu::GPUDialect::isKernel(funcOp)) { 50 return succeeded(lowerFunction(funcOp, operands, &typeConverter, rewriter, 51 newFuncOp)) 52 ? matchSuccess() 53 : matchFailure(); 54 } 55 56 if (failed(lowerAsEntryFunction(funcOp, operands, &typeConverter, rewriter, 57 newFuncOp))) { 58 return matchFailure(); 59 } 60 newFuncOp.getOperation()->removeAttr(Identifier::get( 61 gpu::GPUDialect::getKernelFuncAttrName(), op->getContext())); 62 return matchSuccess(); 63 } 64 65 namespace { 66 /// Pass to lower GPU Dialect to SPIR-V. The pass only converts those functions 67 /// that have the "gpu.kernel" attribute, i.e. those functions that are 68 /// referenced in gpu::LaunchKernelOp operations. For each such function 69 /// 70 /// 1) Create a spirv::ModuleOp, and clone the function into spirv::ModuleOp 71 /// (the original function is still needed by the gpu::LaunchKernelOp, so cannot 72 /// replace it). 73 /// 74 /// 2) Lower the body of the spirv::ModuleOp. 75 class GPUToSPIRVPass : public ModulePass<GPUToSPIRVPass> { 76 void runOnModule() override; 77 }; 78 } // namespace 79 80 void GPUToSPIRVPass::runOnModule() { 81 auto context = &getContext(); 82 auto module = getModule(); 83 84 SmallVector<Operation *, 4> spirvModules; 85 for (auto funcOp : module.getOps<FuncOp>()) { 86 if (gpu::GPUDialect::isKernel(funcOp)) { 87 OpBuilder builder(module.getBodyRegion()); 88 // Create a new spirv::ModuleOp for this function, and clone the 89 // function into it. 90 // TODO : Generalize this to account for different extensions, 91 // capabilities, extended_instruction_sets, other addressing models 92 // and memory models. 93 auto spvModule = builder.create<spirv::ModuleOp>( 94 funcOp.getLoc(), 95 builder.getI32IntegerAttr( 96 static_cast<int32_t>(spirv::AddressingModel::Logical)), 97 builder.getI32IntegerAttr( 98 static_cast<int32_t>(spirv::MemoryModel::VulkanKHR))); 99 OpBuilder moduleBuilder(spvModule.getOperation()->getRegion(0)); 100 moduleBuilder.clone(*funcOp.getOperation()); 101 spirvModules.push_back(spvModule); 102 } 103 } 104 105 /// Dialect conversion to lower the functions with the spirv::ModuleOps. 106 SPIRVBasicTypeConverter basicTypeConverter(context); 107 SPIRVTypeConverter typeConverter(&basicTypeConverter); 108 OwningRewritePatternList patterns; 109 patterns.insert<KernelFnConversion>(context, typeConverter); 110 populateStandardToSPIRVPatterns(context, patterns); 111 112 ConversionTarget target(*context); 113 target.addLegalDialect<spirv::SPIRVDialect>(); 114 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp Op) { 115 return basicTypeConverter.isSignatureLegal(Op.getType()); 116 }); 117 118 if (failed(applyFullConversion(spirvModules, target, patterns, 119 &typeConverter))) { 120 return signalPassFailure(); 121 } 122 } 123 124 ModulePassBase *createGPUToSPIRVPass() { return new GPUToSPIRVPass(); } 125 126 static PassRegistration<GPUToSPIRVPass> 127 pass("convert-gpu-to-spirv", "Convert GPU dialect to SPIR-V dialect"); 128