//===- KernelOutlining.cpp - Implementation of GPU kernel outlining -------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements the GPU dialect kernel outlining pass.
1960965b46SAlex Zinenko // 2060965b46SAlex Zinenko //===----------------------------------------------------------------------===// 2160965b46SAlex Zinenko 2260965b46SAlex Zinenko #include "mlir/Dialect/GPU/GPUDialect.h" 2360965b46SAlex Zinenko #include "mlir/Dialect/GPU/Passes.h" 24ba0fa925SRiver Riddle #include "mlir/Dialect/StandardOps/Ops.h" 2560965b46SAlex Zinenko #include "mlir/IR/BlockAndValueMapping.h" 2660965b46SAlex Zinenko #include "mlir/IR/Builders.h" 2760965b46SAlex Zinenko #include "mlir/Pass/Pass.h" 2860965b46SAlex Zinenko 2960965b46SAlex Zinenko using namespace mlir; 3060965b46SAlex Zinenko 3160965b46SAlex Zinenko template <typename OpTy> 3260965b46SAlex Zinenko static void createForAllDimensions(OpBuilder &builder, Location loc, 3360965b46SAlex Zinenko SmallVectorImpl<Value *> &values) { 3460965b46SAlex Zinenko for (StringRef dim : {"x", "y", "z"}) { 3560965b46SAlex Zinenko Value *v = builder.create<OpTy>(loc, builder.getIndexType(), 3660965b46SAlex Zinenko builder.getStringAttr(dim)); 3760965b46SAlex Zinenko values.push_back(v); 3860965b46SAlex Zinenko } 3960965b46SAlex Zinenko } 4060965b46SAlex Zinenko 4160965b46SAlex Zinenko // Add operations generating block/thread ids and gird/block dimensions at the 4260965b46SAlex Zinenko // beginning of `kernelFunc` and replace uses of the respective function args. 
4360965b46SAlex Zinenko static void injectGpuIndexOperations(Location loc, FuncOp kernelFunc) { 4460965b46SAlex Zinenko OpBuilder OpBuilder(kernelFunc.getBody()); 4560965b46SAlex Zinenko SmallVector<Value *, 12> indexOps; 4660965b46SAlex Zinenko createForAllDimensions<gpu::BlockId>(OpBuilder, loc, indexOps); 4760965b46SAlex Zinenko createForAllDimensions<gpu::ThreadId>(OpBuilder, loc, indexOps); 4860965b46SAlex Zinenko createForAllDimensions<gpu::GridDim>(OpBuilder, loc, indexOps); 4960965b46SAlex Zinenko createForAllDimensions<gpu::BlockDim>(OpBuilder, loc, indexOps); 5060965b46SAlex Zinenko // Replace the leading 12 function args with the respective thread/block index 5160965b46SAlex Zinenko // operations. Iterate backwards since args are erased and indices change. 5260965b46SAlex Zinenko for (int i = 11; i >= 0; --i) { 5360965b46SAlex Zinenko auto &firstBlock = kernelFunc.front(); 5460965b46SAlex Zinenko firstBlock.getArgument(i)->replaceAllUsesWith(indexOps[i]); 5560965b46SAlex Zinenko firstBlock.eraseArgument(i); 5660965b46SAlex Zinenko } 5760965b46SAlex Zinenko } 5860965b46SAlex Zinenko 59*dfd06af5SStephan Herhut // Move all constant arguments of the given kernel function into the function, 60*dfd06af5SStephan Herhut // thereby reducing the number of kernel arguments. 
61*dfd06af5SStephan Herhut static gpu::LaunchFuncOp inlineConstants(FuncOp kernelFunc, 62*dfd06af5SStephan Herhut gpu::LaunchFuncOp launch) { 63*dfd06af5SStephan Herhut OpBuilder kernelBuilder(kernelFunc.getBody()); 64*dfd06af5SStephan Herhut auto &firstBlock = kernelFunc.getBody().front(); 65*dfd06af5SStephan Herhut llvm::SmallVector<Value *, 8> newLaunchArgs; 66*dfd06af5SStephan Herhut for (int i = launch.getNumKernelOperands() - 1; i >= 0; --i) { 67*dfd06af5SStephan Herhut auto operandOp = launch.getKernelOperand(i)->getDefiningOp(); 68*dfd06af5SStephan Herhut auto constant = dyn_cast_or_null<ConstantOp>(operandOp); 69*dfd06af5SStephan Herhut if (!constant) { 70*dfd06af5SStephan Herhut newLaunchArgs.push_back(launch.getKernelOperand(i)); 71*dfd06af5SStephan Herhut continue; 72*dfd06af5SStephan Herhut } 73*dfd06af5SStephan Herhut auto newConstant = kernelBuilder.clone(*operandOp); 74*dfd06af5SStephan Herhut firstBlock.getArgument(i)->replaceAllUsesWith(newConstant->getResult(0)); 75*dfd06af5SStephan Herhut firstBlock.eraseArgument(i); 76*dfd06af5SStephan Herhut } 77*dfd06af5SStephan Herhut if (newLaunchArgs.size() != launch.getNumKernelOperands()) { 78*dfd06af5SStephan Herhut std::reverse(newLaunchArgs.begin(), newLaunchArgs.end()); 79*dfd06af5SStephan Herhut OpBuilder LaunchBuilder(launch); 80*dfd06af5SStephan Herhut SmallVector<Type, 8> newArgumentTypes; 81*dfd06af5SStephan Herhut newArgumentTypes.reserve(firstBlock.getNumArguments()); 82*dfd06af5SStephan Herhut for (auto value : firstBlock.getArguments()) { 83*dfd06af5SStephan Herhut newArgumentTypes.push_back(value->getType()); 84*dfd06af5SStephan Herhut } 85*dfd06af5SStephan Herhut kernelFunc.setType(LaunchBuilder.getFunctionType(newArgumentTypes, {})); 86*dfd06af5SStephan Herhut auto newLaunch = LaunchBuilder.create<gpu::LaunchFuncOp>( 87*dfd06af5SStephan Herhut launch.getLoc(), kernelFunc, launch.getGridSizeOperandValues(), 88*dfd06af5SStephan Herhut launch.getBlockSizeOperandValues(), newLaunchArgs); 
89*dfd06af5SStephan Herhut launch.erase(); 90*dfd06af5SStephan Herhut return newLaunch; 91*dfd06af5SStephan Herhut } 92*dfd06af5SStephan Herhut return launch; 93*dfd06af5SStephan Herhut } 94*dfd06af5SStephan Herhut 9560965b46SAlex Zinenko // Outline the `gpu.launch` operation body into a kernel function. Replace 9660965b46SAlex Zinenko // `gpu.return` operations by `std.return` in the generated functions. 9760965b46SAlex Zinenko static FuncOp outlineKernelFunc(gpu::LaunchOp launchOp) { 9860965b46SAlex Zinenko Location loc = launchOp.getLoc(); 9960965b46SAlex Zinenko SmallVector<Type, 4> kernelOperandTypes(launchOp.getKernelOperandTypes()); 10060965b46SAlex Zinenko FunctionType type = 10160965b46SAlex Zinenko FunctionType::get(kernelOperandTypes, {}, launchOp.getContext()); 10260965b46SAlex Zinenko std::string kernelFuncName = 10360965b46SAlex Zinenko Twine(launchOp.getParentOfType<FuncOp>().getName(), "_kernel").str(); 10460965b46SAlex Zinenko FuncOp outlinedFunc = FuncOp::create(loc, kernelFuncName, type); 10560965b46SAlex Zinenko outlinedFunc.getBody().takeBody(launchOp.getBody()); 10660965b46SAlex Zinenko Builder builder(launchOp.getContext()); 10760965b46SAlex Zinenko outlinedFunc.setAttr(gpu::GPUDialect::getKernelFuncAttrName(), 10860965b46SAlex Zinenko builder.getUnitAttr()); 10960965b46SAlex Zinenko injectGpuIndexOperations(loc, outlinedFunc); 1104bfae66dSRiver Riddle outlinedFunc.walk([](mlir::gpu::Return op) { 11160965b46SAlex Zinenko OpBuilder replacer(op); 11260965b46SAlex Zinenko replacer.create<ReturnOp>(op.getLoc()); 11360965b46SAlex Zinenko op.erase(); 11460965b46SAlex Zinenko }); 11560965b46SAlex Zinenko return outlinedFunc; 11660965b46SAlex Zinenko } 11760965b46SAlex Zinenko 11860965b46SAlex Zinenko // Replace `gpu.launch` operations with an `gpu.launch_func` operation launching 119*dfd06af5SStephan Herhut // `kernelFunc`. The kernel func contains the body of the `gpu.launch` with 120*dfd06af5SStephan Herhut // constant region arguments inlined. 
12160965b46SAlex Zinenko static void convertToLaunchFuncOp(gpu::LaunchOp &launchOp, FuncOp kernelFunc) { 12260965b46SAlex Zinenko OpBuilder builder(launchOp); 12360965b46SAlex Zinenko SmallVector<Value *, 4> kernelOperandValues( 12460965b46SAlex Zinenko launchOp.getKernelOperandValues()); 125*dfd06af5SStephan Herhut auto launchFuncOp = builder.create<gpu::LaunchFuncOp>( 12660965b46SAlex Zinenko launchOp.getLoc(), kernelFunc, launchOp.getGridSizeOperandValues(), 12760965b46SAlex Zinenko launchOp.getBlockSizeOperandValues(), kernelOperandValues); 128*dfd06af5SStephan Herhut inlineConstants(kernelFunc, launchFuncOp); 12960965b46SAlex Zinenko launchOp.erase(); 13060965b46SAlex Zinenko } 13160965b46SAlex Zinenko 13260965b46SAlex Zinenko namespace { 13360965b46SAlex Zinenko 13460965b46SAlex Zinenko class GpuKernelOutliningPass : public ModulePass<GpuKernelOutliningPass> { 13560965b46SAlex Zinenko public: 13660965b46SAlex Zinenko void runOnModule() override { 13760965b46SAlex Zinenko ModuleManager moduleManager(getModule()); 13860965b46SAlex Zinenko for (auto func : getModule().getOps<FuncOp>()) { 1394bfae66dSRiver Riddle func.walk([&](mlir::gpu::LaunchOp op) { 14060965b46SAlex Zinenko FuncOp outlinedFunc = outlineKernelFunc(op); 14160965b46SAlex Zinenko moduleManager.insert(outlinedFunc); 14260965b46SAlex Zinenko convertToLaunchFuncOp(op, outlinedFunc); 14360965b46SAlex Zinenko }); 14460965b46SAlex Zinenko } 14560965b46SAlex Zinenko } 14660965b46SAlex Zinenko }; 14760965b46SAlex Zinenko 14860965b46SAlex Zinenko } // namespace 14960965b46SAlex Zinenko 150926fb685SMehdi Amini std::unique_ptr<ModulePassBase> mlir::createGpuKernelOutliningPass() { 15179f53b0cSJacques Pienaar return std::make_unique<GpuKernelOutliningPass>(); 15260965b46SAlex Zinenko } 15360965b46SAlex Zinenko 15460965b46SAlex Zinenko static PassRegistration<GpuKernelOutliningPass> 15560965b46SAlex Zinenko pass("gpu-kernel-outlining", 15660965b46SAlex Zinenko "Outline gpu.launch bodies to kernel functions."); 
157