1 //===- TestAvailability.cpp - Test pass for setting Entry point ABI info --===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass that sets the spv.entry_point_abi attribute on 10 // functions that are to be lowered as entry point functions. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/GPU/IR/GPUDialect.h" 15 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 16 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" 17 #include "mlir/Pass/Pass.h" 18 19 using namespace mlir; 20 21 namespace { 22 /// Pass to set the spv.entry_point_abi 23 struct TestSpirvEntryPointABIPass 24 : public PassWrapper<TestSpirvEntryPointABIPass, 25 OperationPass<gpu::GPUModuleOp>> { 26 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSpirvEntryPointABIPass) 27 28 StringRef getArgument() const final { return "test-spirv-entry-point-abi"; } 29 StringRef getDescription() const final { 30 return "Set the spv.entry_point_abi attribute on GPU kernel function " 31 "within the " 32 "module, intended for testing only"; 33 } 34 void getDependentDialects(DialectRegistry ®istry) const override { 35 registry.insert<spirv::SPIRVDialect>(); 36 } 37 TestSpirvEntryPointABIPass() = default; 38 TestSpirvEntryPointABIPass(const TestSpirvEntryPointABIPass &) {} 39 void runOnOperation() override; 40 41 private: 42 Pass::ListOption<int32_t> workgroupSize{ 43 *this, "workgroup-size", 44 llvm::cl::desc( 45 "Workgroup size to use for all gpu.func kernels in the module, " 46 "specified with x-dimension first, y-dimension next and z-dimension " 47 "last. Unspecified dimensions will be set to 1"), 48 llvm::cl::ZeroOrMore}; 49 }; 50 } // namespace 51 52 void TestSpirvEntryPointABIPass::runOnOperation() { 53 gpu::GPUModuleOp gpuModule = getOperation(); 54 MLIRContext *context = &getContext(); 55 StringRef attrName = spirv::getEntryPointABIAttrName(); 56 for (gpu::GPUFuncOp gpuFunc : gpuModule.getOps<gpu::GPUFuncOp>()) { 57 if (!gpu::GPUDialect::isKernel(gpuFunc) || gpuFunc->getAttr(attrName)) 58 continue; 59 SmallVector<int32_t, 3> workgroupSizeVec(workgroupSize.begin(), 60 workgroupSize.end()); 61 workgroupSizeVec.resize(3, 1); 62 gpuFunc->setAttr(attrName, 63 spirv::getEntryPointABIAttr(workgroupSizeVec, context)); 64 } 65 } 66 67 namespace mlir { 68 void registerTestSpirvEntryPointABIPass() { 69 PassRegistration<TestSpirvEntryPointABIPass>(); 70 } 71 } // namespace mlir 72