//===- GPUToSPIRV.cpp - MLIR SPIR-V lowering passes -----------------------===// // // Copyright 2019 The MLIR Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= // // This file implements a pass to convert a kernel function in the GPU Dialect // into a spv.module operation // //===----------------------------------------------------------------------===// #include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h" #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/SPIRV/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/SPIRVOps.h" #include "mlir/Pass/Pass.h" using namespace mlir; namespace { /// Pattern lowering GPU block/thread size/id to loading SPIR-V invocation /// builin variables. template class LaunchConfigConversion : public SPIRVOpLowering { public: using SPIRVOpLowering::SPIRVOpLowering; PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; /// Pattern to convert a kernel function in GPU dialect (a FuncOp with the /// attribute gpu.kernel) within a spv.module. class KernelFnConversion final : public SPIRVOpLowering { public: using SPIRVOpLowering::SPIRVOpLowering; PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; } // namespace template PatternMatchResult LaunchConfigConversion::matchAndRewrite( Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const { auto dimAttr = op->getAttrOfType("dimension"); if (!dimAttr) { return this->matchFailure(); } int32_t index = 0; if (dimAttr.getValue() == "x") { index = 0; } else if (dimAttr.getValue() == "y") { index = 1; } else if (dimAttr.getValue() == "z") { index = 2; } else { return this->matchFailure(); } // SPIR-V invocation builtin variables are a vector of type <3xi32> auto spirvBuiltin = this->loadFromBuiltinVariable(op, builtin, rewriter); rewriter.replaceOpWithNewOp( op, rewriter.getIntegerType(32), spirvBuiltin, rewriter.getI32ArrayAttr({index})); return this->matchSuccess(); } PatternMatchResult KernelFnConversion::matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const { auto funcOp = cast(op); FuncOp newFuncOp; if (!gpu::GPUDialect::isKernel(funcOp)) { return succeeded(lowerFunction(funcOp, operands, &typeConverter, rewriter, newFuncOp)) ? matchSuccess() : matchFailure(); } if (failed(lowerAsEntryFunction(funcOp, operands, &typeConverter, rewriter, newFuncOp))) { return matchFailure(); } newFuncOp.getOperation()->removeAttr(Identifier::get( gpu::GPUDialect::getKernelFuncAttrName(), op->getContext())); return matchSuccess(); } namespace { /// Pass to lower GPU Dialect to SPIR-V. The pass only converts those functions /// that have the "gpu.kernel" attribute, i.e. those functions that are /// referenced in gpu::LaunchKernelOp operations. For each such function /// /// 1) Create a spirv::ModuleOp, and clone the function into spirv::ModuleOp /// (the original function is still needed by the gpu::LaunchKernelOp, so cannot /// replace it). /// /// 2) Lower the body of the spirv::ModuleOp. class GPUToSPIRVPass : public ModulePass { void runOnModule() override; }; } // namespace void GPUToSPIRVPass::runOnModule() { auto context = &getContext(); auto module = getModule(); SmallVector spirvModules; for (auto funcOp : module.getOps()) { if (gpu::GPUDialect::isKernel(funcOp)) { OpBuilder builder(module.getBodyRegion()); // Create a new spirv::ModuleOp for this function, and clone the // function into it. // TODO : Generalize this to account for different extensions, // capabilities, extended_instruction_sets, other addressing models // and memory models. auto spvModule = builder.create( funcOp.getLoc(), builder.getI32IntegerAttr( static_cast(spirv::AddressingModel::Logical)), builder.getI32IntegerAttr( static_cast(spirv::MemoryModel::VulkanKHR))); OpBuilder moduleBuilder(spvModule.getOperation()->getRegion(0)); moduleBuilder.clone(*funcOp.getOperation()); spirvModules.push_back(spvModule); } } /// Dialect conversion to lower the functions with the spirv::ModuleOps. SPIRVBasicTypeConverter basicTypeConverter(context); SPIRVTypeConverter typeConverter(&basicTypeConverter); OwningRewritePatternList patterns; patterns.insert< KernelFnConversion, LaunchConfigConversion, LaunchConfigConversion, LaunchConfigConversion, LaunchConfigConversion>( context, typeConverter); populateStandardToSPIRVPatterns(context, patterns); ConversionTarget target(*context); target.addLegalDialect(); target.addDynamicallyLegalOp([&](FuncOp Op) { return basicTypeConverter.isSignatureLegal(Op.getType()); }); if (failed(applyFullConversion(spirvModules, target, patterns, &typeConverter))) { return signalPassFailure(); } } ModulePassBase *createGPUToSPIRVPass() { return new GPUToSPIRVPass(); } static PassRegistration pass("convert-gpu-to-spirv", "Convert GPU dialect to SPIR-V dialect");