//===- KernelOutlining.cpp - Implementation of GPU kernel outlining -------===//
260965b46SAlex Zinenko //
360965b46SAlex Zinenko // Copyright 2019 The MLIR Authors.
460965b46SAlex Zinenko //
560965b46SAlex Zinenko // Licensed under the Apache License, Version 2.0 (the "License");
660965b46SAlex Zinenko // you may not use this file except in compliance with the License.
760965b46SAlex Zinenko // You may obtain a copy of the License at
860965b46SAlex Zinenko //
960965b46SAlex Zinenko //   http://www.apache.org/licenses/LICENSE-2.0
1060965b46SAlex Zinenko //
1160965b46SAlex Zinenko // Unless required by applicable law or agreed to in writing, software
1260965b46SAlex Zinenko // distributed under the License is distributed on an "AS IS" BASIS,
1360965b46SAlex Zinenko // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1460965b46SAlex Zinenko // See the License for the specific language governing permissions and
1560965b46SAlex Zinenko // limitations under the License.
1660965b46SAlex Zinenko // =============================================================================
1760965b46SAlex Zinenko //
1860965b46SAlex Zinenko // This file implements the GPU dialect kernel outlining pass.
1960965b46SAlex Zinenko //
2060965b46SAlex Zinenko //===----------------------------------------------------------------------===//
2160965b46SAlex Zinenko 
2260965b46SAlex Zinenko #include "mlir/Dialect/GPU/GPUDialect.h"
2360965b46SAlex Zinenko #include "mlir/Dialect/GPU/Passes.h"
24ba0fa925SRiver Riddle #include "mlir/Dialect/StandardOps/Ops.h"
2560965b46SAlex Zinenko #include "mlir/IR/BlockAndValueMapping.h"
2660965b46SAlex Zinenko #include "mlir/IR/Builders.h"
2760965b46SAlex Zinenko #include "mlir/Pass/Pass.h"
2860965b46SAlex Zinenko 
2960965b46SAlex Zinenko using namespace mlir;
3060965b46SAlex Zinenko 
3160965b46SAlex Zinenko template <typename OpTy>
3260965b46SAlex Zinenko static void createForAllDimensions(OpBuilder &builder, Location loc,
3360965b46SAlex Zinenko                                    SmallVectorImpl<Value *> &values) {
3460965b46SAlex Zinenko   for (StringRef dim : {"x", "y", "z"}) {
3560965b46SAlex Zinenko     Value *v = builder.create<OpTy>(loc, builder.getIndexType(),
3660965b46SAlex Zinenko                                     builder.getStringAttr(dim));
3760965b46SAlex Zinenko     values.push_back(v);
3860965b46SAlex Zinenko   }
3960965b46SAlex Zinenko }
4060965b46SAlex Zinenko 
4160965b46SAlex Zinenko // Add operations generating block/thread ids and gird/block dimensions at the
4260965b46SAlex Zinenko // beginning of `kernelFunc` and replace uses of the respective function args.
4360965b46SAlex Zinenko static void injectGpuIndexOperations(Location loc, FuncOp kernelFunc) {
4460965b46SAlex Zinenko   OpBuilder OpBuilder(kernelFunc.getBody());
4560965b46SAlex Zinenko   SmallVector<Value *, 12> indexOps;
4660965b46SAlex Zinenko   createForAllDimensions<gpu::BlockId>(OpBuilder, loc, indexOps);
4760965b46SAlex Zinenko   createForAllDimensions<gpu::ThreadId>(OpBuilder, loc, indexOps);
4860965b46SAlex Zinenko   createForAllDimensions<gpu::GridDim>(OpBuilder, loc, indexOps);
4960965b46SAlex Zinenko   createForAllDimensions<gpu::BlockDim>(OpBuilder, loc, indexOps);
5060965b46SAlex Zinenko   // Replace the leading 12 function args with the respective thread/block index
5160965b46SAlex Zinenko   // operations. Iterate backwards since args are erased and indices change.
5260965b46SAlex Zinenko   for (int i = 11; i >= 0; --i) {
5360965b46SAlex Zinenko     auto &firstBlock = kernelFunc.front();
5460965b46SAlex Zinenko     firstBlock.getArgument(i)->replaceAllUsesWith(indexOps[i]);
5560965b46SAlex Zinenko     firstBlock.eraseArgument(i);
5660965b46SAlex Zinenko   }
5760965b46SAlex Zinenko }
5860965b46SAlex Zinenko 
59*dfd06af5SStephan Herhut // Move all constant arguments of the given kernel function into the function,
60*dfd06af5SStephan Herhut // thereby reducing the number of kernel arguments.
61*dfd06af5SStephan Herhut static gpu::LaunchFuncOp inlineConstants(FuncOp kernelFunc,
62*dfd06af5SStephan Herhut                                          gpu::LaunchFuncOp launch) {
63*dfd06af5SStephan Herhut   OpBuilder kernelBuilder(kernelFunc.getBody());
64*dfd06af5SStephan Herhut   auto &firstBlock = kernelFunc.getBody().front();
65*dfd06af5SStephan Herhut   llvm::SmallVector<Value *, 8> newLaunchArgs;
66*dfd06af5SStephan Herhut   for (int i = launch.getNumKernelOperands() - 1; i >= 0; --i) {
67*dfd06af5SStephan Herhut     auto operandOp = launch.getKernelOperand(i)->getDefiningOp();
68*dfd06af5SStephan Herhut     auto constant = dyn_cast_or_null<ConstantOp>(operandOp);
69*dfd06af5SStephan Herhut     if (!constant) {
70*dfd06af5SStephan Herhut       newLaunchArgs.push_back(launch.getKernelOperand(i));
71*dfd06af5SStephan Herhut       continue;
72*dfd06af5SStephan Herhut     }
73*dfd06af5SStephan Herhut     auto newConstant = kernelBuilder.clone(*operandOp);
74*dfd06af5SStephan Herhut     firstBlock.getArgument(i)->replaceAllUsesWith(newConstant->getResult(0));
75*dfd06af5SStephan Herhut     firstBlock.eraseArgument(i);
76*dfd06af5SStephan Herhut   }
77*dfd06af5SStephan Herhut   if (newLaunchArgs.size() != launch.getNumKernelOperands()) {
78*dfd06af5SStephan Herhut     std::reverse(newLaunchArgs.begin(), newLaunchArgs.end());
79*dfd06af5SStephan Herhut     OpBuilder LaunchBuilder(launch);
80*dfd06af5SStephan Herhut     SmallVector<Type, 8> newArgumentTypes;
81*dfd06af5SStephan Herhut     newArgumentTypes.reserve(firstBlock.getNumArguments());
82*dfd06af5SStephan Herhut     for (auto value : firstBlock.getArguments()) {
83*dfd06af5SStephan Herhut       newArgumentTypes.push_back(value->getType());
84*dfd06af5SStephan Herhut     }
85*dfd06af5SStephan Herhut     kernelFunc.setType(LaunchBuilder.getFunctionType(newArgumentTypes, {}));
86*dfd06af5SStephan Herhut     auto newLaunch = LaunchBuilder.create<gpu::LaunchFuncOp>(
87*dfd06af5SStephan Herhut         launch.getLoc(), kernelFunc, launch.getGridSizeOperandValues(),
88*dfd06af5SStephan Herhut         launch.getBlockSizeOperandValues(), newLaunchArgs);
89*dfd06af5SStephan Herhut     launch.erase();
90*dfd06af5SStephan Herhut     return newLaunch;
91*dfd06af5SStephan Herhut   }
92*dfd06af5SStephan Herhut   return launch;
93*dfd06af5SStephan Herhut }
94*dfd06af5SStephan Herhut 
9560965b46SAlex Zinenko // Outline the `gpu.launch` operation body into a kernel function. Replace
9660965b46SAlex Zinenko // `gpu.return` operations by `std.return` in the generated functions.
9760965b46SAlex Zinenko static FuncOp outlineKernelFunc(gpu::LaunchOp launchOp) {
9860965b46SAlex Zinenko   Location loc = launchOp.getLoc();
9960965b46SAlex Zinenko   SmallVector<Type, 4> kernelOperandTypes(launchOp.getKernelOperandTypes());
10060965b46SAlex Zinenko   FunctionType type =
10160965b46SAlex Zinenko       FunctionType::get(kernelOperandTypes, {}, launchOp.getContext());
10260965b46SAlex Zinenko   std::string kernelFuncName =
10360965b46SAlex Zinenko       Twine(launchOp.getParentOfType<FuncOp>().getName(), "_kernel").str();
10460965b46SAlex Zinenko   FuncOp outlinedFunc = FuncOp::create(loc, kernelFuncName, type);
10560965b46SAlex Zinenko   outlinedFunc.getBody().takeBody(launchOp.getBody());
10660965b46SAlex Zinenko   Builder builder(launchOp.getContext());
10760965b46SAlex Zinenko   outlinedFunc.setAttr(gpu::GPUDialect::getKernelFuncAttrName(),
10860965b46SAlex Zinenko                        builder.getUnitAttr());
10960965b46SAlex Zinenko   injectGpuIndexOperations(loc, outlinedFunc);
1104bfae66dSRiver Riddle   outlinedFunc.walk([](mlir::gpu::Return op) {
11160965b46SAlex Zinenko     OpBuilder replacer(op);
11260965b46SAlex Zinenko     replacer.create<ReturnOp>(op.getLoc());
11360965b46SAlex Zinenko     op.erase();
11460965b46SAlex Zinenko   });
11560965b46SAlex Zinenko   return outlinedFunc;
11660965b46SAlex Zinenko }
11760965b46SAlex Zinenko 
11860965b46SAlex Zinenko // Replace `gpu.launch` operations with an `gpu.launch_func` operation launching
119*dfd06af5SStephan Herhut // `kernelFunc`. The kernel func contains the body of the `gpu.launch` with
120*dfd06af5SStephan Herhut // constant region arguments inlined.
12160965b46SAlex Zinenko static void convertToLaunchFuncOp(gpu::LaunchOp &launchOp, FuncOp kernelFunc) {
12260965b46SAlex Zinenko   OpBuilder builder(launchOp);
12360965b46SAlex Zinenko   SmallVector<Value *, 4> kernelOperandValues(
12460965b46SAlex Zinenko       launchOp.getKernelOperandValues());
125*dfd06af5SStephan Herhut   auto launchFuncOp = builder.create<gpu::LaunchFuncOp>(
12660965b46SAlex Zinenko       launchOp.getLoc(), kernelFunc, launchOp.getGridSizeOperandValues(),
12760965b46SAlex Zinenko       launchOp.getBlockSizeOperandValues(), kernelOperandValues);
128*dfd06af5SStephan Herhut   inlineConstants(kernelFunc, launchFuncOp);
12960965b46SAlex Zinenko   launchOp.erase();
13060965b46SAlex Zinenko }
13160965b46SAlex Zinenko 
13260965b46SAlex Zinenko namespace {
13360965b46SAlex Zinenko 
13460965b46SAlex Zinenko class GpuKernelOutliningPass : public ModulePass<GpuKernelOutliningPass> {
13560965b46SAlex Zinenko public:
13660965b46SAlex Zinenko   void runOnModule() override {
13760965b46SAlex Zinenko     ModuleManager moduleManager(getModule());
13860965b46SAlex Zinenko     for (auto func : getModule().getOps<FuncOp>()) {
1394bfae66dSRiver Riddle       func.walk([&](mlir::gpu::LaunchOp op) {
14060965b46SAlex Zinenko         FuncOp outlinedFunc = outlineKernelFunc(op);
14160965b46SAlex Zinenko         moduleManager.insert(outlinedFunc);
14260965b46SAlex Zinenko         convertToLaunchFuncOp(op, outlinedFunc);
14360965b46SAlex Zinenko       });
14460965b46SAlex Zinenko     }
14560965b46SAlex Zinenko   }
14660965b46SAlex Zinenko };
14760965b46SAlex Zinenko 
14860965b46SAlex Zinenko } // namespace
14960965b46SAlex Zinenko 
150926fb685SMehdi Amini std::unique_ptr<ModulePassBase> mlir::createGpuKernelOutliningPass() {
15179f53b0cSJacques Pienaar   return std::make_unique<GpuKernelOutliningPass>();
15260965b46SAlex Zinenko }
15360965b46SAlex Zinenko 
15460965b46SAlex Zinenko static PassRegistration<GpuKernelOutliningPass>
15560965b46SAlex Zinenko     pass("gpu-kernel-outlining",
15660965b46SAlex Zinenko          "Outline gpu.launch bodies to kernel functions.");
157