1 //===- TargetAndABI.cpp - SPIR-V target and ABI utilities -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
10 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
11 #include "mlir/IR/Builders.h"
12 #include "mlir/IR/FunctionInterfaces.h"
13 #include "mlir/IR/Operation.h"
14 #include "mlir/IR/SymbolTable.h"
15 
16 using namespace mlir;
17 
18 //===----------------------------------------------------------------------===//
19 // TargetEnv
20 //===----------------------------------------------------------------------===//
21 
TargetEnv(spirv::TargetEnvAttr targetAttr)22 spirv::TargetEnv::TargetEnv(spirv::TargetEnvAttr targetAttr)
23     : targetAttr(targetAttr) {
24   for (spirv::Extension ext : targetAttr.getExtensions())
25     givenExtensions.insert(ext);
26 
27   // Add extensions implied by the current version.
28   for (spirv::Extension ext :
29        spirv::getImpliedExtensions(targetAttr.getVersion()))
30     givenExtensions.insert(ext);
31 
32   for (spirv::Capability cap : targetAttr.getCapabilities()) {
33     givenCapabilities.insert(cap);
34 
35     // Add capabilities implied by the current capability.
36     for (spirv::Capability c : spirv::getRecursiveImpliedCapabilities(cap))
37       givenCapabilities.insert(c);
38   }
39 }
40 
getVersion() const41 spirv::Version spirv::TargetEnv::getVersion() const {
42   return targetAttr.getVersion();
43 }
44 
allows(spirv::Capability capability) const45 bool spirv::TargetEnv::allows(spirv::Capability capability) const {
46   return givenCapabilities.count(capability);
47 }
48 
49 Optional<spirv::Capability>
allows(ArrayRef<spirv::Capability> caps) const50 spirv::TargetEnv::allows(ArrayRef<spirv::Capability> caps) const {
51   const auto *chosen = llvm::find_if(caps, [this](spirv::Capability cap) {
52     return givenCapabilities.count(cap);
53   });
54   if (chosen != caps.end())
55     return *chosen;
56   return llvm::None;
57 }
58 
allows(spirv::Extension extension) const59 bool spirv::TargetEnv::allows(spirv::Extension extension) const {
60   return givenExtensions.count(extension);
61 }
62 
63 Optional<spirv::Extension>
allows(ArrayRef<spirv::Extension> exts) const64 spirv::TargetEnv::allows(ArrayRef<spirv::Extension> exts) const {
65   const auto *chosen = llvm::find_if(exts, [this](spirv::Extension ext) {
66     return givenExtensions.count(ext);
67   });
68   if (chosen != exts.end())
69     return *chosen;
70   return llvm::None;
71 }
72 
getVendorID() const73 spirv::Vendor spirv::TargetEnv::getVendorID() const {
74   return targetAttr.getVendorID();
75 }
76 
getDeviceType() const77 spirv::DeviceType spirv::TargetEnv::getDeviceType() const {
78   return targetAttr.getDeviceType();
79 }
80 
getDeviceID() const81 uint32_t spirv::TargetEnv::getDeviceID() const {
82   return targetAttr.getDeviceID();
83 }
84 
getResourceLimits() const85 spirv::ResourceLimitsAttr spirv::TargetEnv::getResourceLimits() const {
86   return targetAttr.getResourceLimits();
87 }
88 
getContext() const89 MLIRContext *spirv::TargetEnv::getContext() const {
90   return targetAttr.getContext();
91 }
92 
93 //===----------------------------------------------------------------------===//
94 // Utility functions
95 //===----------------------------------------------------------------------===//
96 
getInterfaceVarABIAttrName()97 StringRef spirv::getInterfaceVarABIAttrName() {
98   return "spv.interface_var_abi";
99 }
100 
101 spirv::InterfaceVarABIAttr
getInterfaceVarABIAttr(unsigned descriptorSet,unsigned binding,Optional<spirv::StorageClass> storageClass,MLIRContext * context)102 spirv::getInterfaceVarABIAttr(unsigned descriptorSet, unsigned binding,
103                               Optional<spirv::StorageClass> storageClass,
104                               MLIRContext *context) {
105   return spirv::InterfaceVarABIAttr::get(descriptorSet, binding, storageClass,
106                                          context);
107 }
108 
needsInterfaceVarABIAttrs(spirv::TargetEnvAttr targetAttr)109 bool spirv::needsInterfaceVarABIAttrs(spirv::TargetEnvAttr targetAttr) {
110   for (spirv::Capability cap : targetAttr.getCapabilities()) {
111     if (cap == spirv::Capability::Kernel)
112       return false;
113     if (cap == spirv::Capability::Shader)
114       return true;
115   }
116   return false;
117 }
118 
getEntryPointABIAttrName()119 StringRef spirv::getEntryPointABIAttrName() { return "spv.entry_point_abi"; }
120 
121 spirv::EntryPointABIAttr
getEntryPointABIAttr(ArrayRef<int32_t> localSize,MLIRContext * context)122 spirv::getEntryPointABIAttr(ArrayRef<int32_t> localSize, MLIRContext *context) {
123   if (localSize.empty())
124     return spirv::EntryPointABIAttr::get(context, nullptr);
125 
126   assert(localSize.size() == 3);
127   return spirv::EntryPointABIAttr::get(
128       context, DenseElementsAttr::get<int32_t>(
129                    VectorType::get(3, IntegerType::get(context, 32)), localSize)
130                    .cast<DenseIntElementsAttr>());
131 }
132 
lookupEntryPointABI(Operation * op)133 spirv::EntryPointABIAttr spirv::lookupEntryPointABI(Operation *op) {
134   while (op && !isa<FunctionOpInterface>(op))
135     op = op->getParentOp();
136   if (!op)
137     return {};
138 
139   if (auto attr = op->getAttrOfType<spirv::EntryPointABIAttr>(
140           spirv::getEntryPointABIAttrName()))
141     return attr;
142 
143   return {};
144 }
145 
lookupLocalWorkGroupSize(Operation * op)146 DenseIntElementsAttr spirv::lookupLocalWorkGroupSize(Operation *op) {
147   if (auto entryPoint = spirv::lookupEntryPointABI(op))
148     return entryPoint.getLocalSize();
149 
150   return {};
151 }
152 
153 spirv::ResourceLimitsAttr
getDefaultResourceLimits(MLIRContext * context)154 spirv::getDefaultResourceLimits(MLIRContext *context) {
155   // All the fields have default values. Here we just provide a nicer way to
156   // construct a default resource limit attribute.
157   Builder b(context);
158   return spirv::ResourceLimitsAttr::get(
159       context,
160       /*max_compute_shared_memory_size=*/16384,
161       /*max_compute_workgroup_invocations=*/128,
162       /*max_compute_workgroup_size=*/b.getI32ArrayAttr({128, 128, 64}),
163       /*subgroup_size=*/32,
164       /*cooperative_matrix_properties_nv=*/ArrayAttr());
165 }
166 
getTargetEnvAttrName()167 StringRef spirv::getTargetEnvAttrName() { return "spv.target_env"; }
168 
getDefaultTargetEnv(MLIRContext * context)169 spirv::TargetEnvAttr spirv::getDefaultTargetEnv(MLIRContext *context) {
170   auto triple = spirv::VerCapExtAttr::get(spirv::Version::V_1_0,
171                                           {spirv::Capability::Shader},
172                                           ArrayRef<Extension>(), context);
173   return spirv::TargetEnvAttr::get(triple, spirv::Vendor::Unknown,
174                                    spirv::DeviceType::Unknown,
175                                    spirv::TargetEnvAttr::kUnknownDeviceID,
176                                    spirv::getDefaultResourceLimits(context));
177 }
178 
lookupTargetEnv(Operation * op)179 spirv::TargetEnvAttr spirv::lookupTargetEnv(Operation *op) {
180   while (op) {
181     op = SymbolTable::getNearestSymbolTable(op);
182     if (!op)
183       break;
184 
185     if (auto attr = op->getAttrOfType<spirv::TargetEnvAttr>(
186             spirv::getTargetEnvAttrName()))
187       return attr;
188 
189     op = op->getParentOp();
190   }
191 
192   return {};
193 }
194 
lookupTargetEnvOrDefault(Operation * op)195 spirv::TargetEnvAttr spirv::lookupTargetEnvOrDefault(Operation *op) {
196   if (spirv::TargetEnvAttr attr = spirv::lookupTargetEnv(op))
197     return attr;
198 
199   return getDefaultTargetEnv(op->getContext());
200 }
201 
202 spirv::AddressingModel
getAddressingModel(spirv::TargetEnvAttr targetAttr)203 spirv::getAddressingModel(spirv::TargetEnvAttr targetAttr) {
204   for (spirv::Capability cap : targetAttr.getCapabilities()) {
205     // TODO: Physical64 is hard-coded here, but some information should come
206     // from TargetEnvAttr to selected between Physical32 and Physical64.
207     if (cap == Capability::Kernel)
208       return spirv::AddressingModel::Physical64;
209     // TODO PhysicalStorageBuffer64 is hard-coded here, but some information
210     // should come from TargetEnvAttr to select between PhysicalStorageBuffer64
211     // and PhysicalStorageBuffer64EXT
212     if (cap == Capability::PhysicalStorageBufferAddresses)
213       return spirv::AddressingModel::PhysicalStorageBuffer64;
214   }
215   // Logical addressing doesn't need any capabilities so return it as default.
216   return spirv::AddressingModel::Logical;
217 }
218 
219 FailureOr<spirv::ExecutionModel>
getExecutionModel(spirv::TargetEnvAttr targetAttr)220 spirv::getExecutionModel(spirv::TargetEnvAttr targetAttr) {
221   for (spirv::Capability cap : targetAttr.getCapabilities()) {
222     if (cap == spirv::Capability::Kernel)
223       return spirv::ExecutionModel::Kernel;
224     if (cap == spirv::Capability::Shader)
225       return spirv::ExecutionModel::GLCompute;
226   }
227   return failure();
228 }
229 
230 FailureOr<spirv::MemoryModel>
getMemoryModel(spirv::TargetEnvAttr targetAttr)231 spirv::getMemoryModel(spirv::TargetEnvAttr targetAttr) {
232   for (spirv::Capability cap : targetAttr.getCapabilities()) {
233     if (cap == spirv::Capability::Addresses)
234       return spirv::MemoryModel::OpenCL;
235     if (cap == spirv::Capability::Shader)
236       return spirv::MemoryModel::GLSL450;
237   }
238   return failure();
239 }
240