1 //===- TestAvailability.cpp - Test pass for setting Entry point ABI info --===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass that sets the spv.entry_point_abi attribute on 10 // functions that are to be lowered as entry point functions. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/GPU/IR/GPUDialect.h" 15 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 16 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" 17 #include "mlir/Pass/Pass.h" 18 19 using namespace mlir; 20 21 namespace { 22 /// Pass to set the spv.entry_point_abi 23 struct TestSpirvEntryPointABIPass 24 : public PassWrapper<TestSpirvEntryPointABIPass, 25 OperationPass<gpu::GPUModuleOp>> { 26 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSpirvEntryPointABIPass) 27 28 StringRef getArgument() const final { return "test-spirv-entry-point-abi"; } 29 StringRef getDescription() const final { 30 return "Set the spv.entry_point_abi attribute on GPU kernel function " 31 "within the " 32 "module, intended for testing only"; 33 } 34 void getDependentDialects(DialectRegistry ®istry) const override { 35 registry.insert<spirv::SPIRVDialect>(); 36 } 37 TestSpirvEntryPointABIPass() = default; 38 TestSpirvEntryPointABIPass(const TestSpirvEntryPointABIPass &) {} 39 void runOnOperation() override; 40 41 private: 42 Pass::ListOption<int32_t> workgroupSize{ 43 *this, "workgroup-size", 44 llvm::cl::desc( 45 "Workgroup size to use for all gpu.func kernels in the module, " 46 "specified with x-dimension first, y-dimension next and z-dimension " 47 "last. Unspecified dimensions will be set to 1")}; 48 }; 49 } // namespace 50 51 void TestSpirvEntryPointABIPass::runOnOperation() { 52 gpu::GPUModuleOp gpuModule = getOperation(); 53 MLIRContext *context = &getContext(); 54 StringRef attrName = spirv::getEntryPointABIAttrName(); 55 for (gpu::GPUFuncOp gpuFunc : gpuModule.getOps<gpu::GPUFuncOp>()) { 56 if (!gpu::GPUDialect::isKernel(gpuFunc) || gpuFunc->getAttr(attrName)) 57 continue; 58 SmallVector<int32_t, 3> workgroupSizeVec(workgroupSize.begin(), 59 workgroupSize.end()); 60 workgroupSizeVec.resize(3, 1); 61 gpuFunc->setAttr(attrName, 62 spirv::getEntryPointABIAttr(workgroupSizeVec, context)); 63 } 64 } 65 66 namespace mlir { 67 void registerTestSpirvEntryPointABIPass() { 68 PassRegistration<TestSpirvEntryPointABIPass>(); 69 } 70 } // namespace mlir 71