1 //===- SPIRVLegalizerInfo.cpp --- SPIR-V Legalization Rules ------*- C++ -*-==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements the targeting of the MachineLegalizer class for SPIR-V.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "SPIRVLegalizerInfo.h"
14 #include "SPIRV.h"
15 #include "SPIRVGlobalRegistry.h"
16 #include "SPIRVSubtarget.h"
17 #include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
18 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
19 #include "llvm/CodeGen/MachineInstr.h"
20 #include "llvm/CodeGen/MachineRegisterInfo.h"
21 #include "llvm/CodeGen/TargetOpcodes.h"
22
23 using namespace llvm;
24 using namespace llvm::LegalizeActions;
25 using namespace llvm::LegalityPredicates;
26
27 static const std::set<unsigned> TypeFoldingSupportingOpcs = {
28 TargetOpcode::G_ADD,
29 TargetOpcode::G_FADD,
30 TargetOpcode::G_SUB,
31 TargetOpcode::G_FSUB,
32 TargetOpcode::G_MUL,
33 TargetOpcode::G_FMUL,
34 TargetOpcode::G_SDIV,
35 TargetOpcode::G_UDIV,
36 TargetOpcode::G_FDIV,
37 TargetOpcode::G_SREM,
38 TargetOpcode::G_UREM,
39 TargetOpcode::G_FREM,
40 TargetOpcode::G_FNEG,
41 TargetOpcode::G_CONSTANT,
42 TargetOpcode::G_FCONSTANT,
43 TargetOpcode::G_AND,
44 TargetOpcode::G_OR,
45 TargetOpcode::G_XOR,
46 TargetOpcode::G_SHL,
47 TargetOpcode::G_ASHR,
48 TargetOpcode::G_LSHR,
49 TargetOpcode::G_SELECT,
50 TargetOpcode::G_EXTRACT_VECTOR_ELT,
51 };
52
isTypeFoldingSupported(unsigned Opcode)53 bool isTypeFoldingSupported(unsigned Opcode) {
54 return TypeFoldingSupportingOpcs.count(Opcode) > 0;
55 }
56
SPIRVLegalizerInfo(const SPIRVSubtarget & ST)57 SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
58 using namespace TargetOpcode;
59
60 this->ST = &ST;
61 GR = ST.getSPIRVGlobalRegistry();
62
63 const LLT s1 = LLT::scalar(1);
64 const LLT s8 = LLT::scalar(8);
65 const LLT s16 = LLT::scalar(16);
66 const LLT s32 = LLT::scalar(32);
67 const LLT s64 = LLT::scalar(64);
68
69 const LLT v16s64 = LLT::fixed_vector(16, 64);
70 const LLT v16s32 = LLT::fixed_vector(16, 32);
71 const LLT v16s16 = LLT::fixed_vector(16, 16);
72 const LLT v16s8 = LLT::fixed_vector(16, 8);
73 const LLT v16s1 = LLT::fixed_vector(16, 1);
74
75 const LLT v8s64 = LLT::fixed_vector(8, 64);
76 const LLT v8s32 = LLT::fixed_vector(8, 32);
77 const LLT v8s16 = LLT::fixed_vector(8, 16);
78 const LLT v8s8 = LLT::fixed_vector(8, 8);
79 const LLT v8s1 = LLT::fixed_vector(8, 1);
80
81 const LLT v4s64 = LLT::fixed_vector(4, 64);
82 const LLT v4s32 = LLT::fixed_vector(4, 32);
83 const LLT v4s16 = LLT::fixed_vector(4, 16);
84 const LLT v4s8 = LLT::fixed_vector(4, 8);
85 const LLT v4s1 = LLT::fixed_vector(4, 1);
86
87 const LLT v3s64 = LLT::fixed_vector(3, 64);
88 const LLT v3s32 = LLT::fixed_vector(3, 32);
89 const LLT v3s16 = LLT::fixed_vector(3, 16);
90 const LLT v3s8 = LLT::fixed_vector(3, 8);
91 const LLT v3s1 = LLT::fixed_vector(3, 1);
92
93 const LLT v2s64 = LLT::fixed_vector(2, 64);
94 const LLT v2s32 = LLT::fixed_vector(2, 32);
95 const LLT v2s16 = LLT::fixed_vector(2, 16);
96 const LLT v2s8 = LLT::fixed_vector(2, 8);
97 const LLT v2s1 = LLT::fixed_vector(2, 1);
98
99 const unsigned PSize = ST.getPointerSize();
100 const LLT p0 = LLT::pointer(0, PSize); // Function
101 const LLT p1 = LLT::pointer(1, PSize); // CrossWorkgroup
102 const LLT p2 = LLT::pointer(2, PSize); // UniformConstant
103 const LLT p3 = LLT::pointer(3, PSize); // Workgroup
104 const LLT p4 = LLT::pointer(4, PSize); // Generic
105 const LLT p5 = LLT::pointer(5, PSize); // Input
106
107 // TODO: remove copy-pasting here by using concatenation in some way.
108 auto allPtrsScalarsAndVectors = {
109 p0, p1, p2, p3, p4, p5, s1, s8, s16,
110 s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8,
111 v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64, v8s1,
112 v8s8, v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
113
114 auto allScalarsAndVectors = {
115 s1, s8, s16, s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64,
116 v3s1, v3s8, v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64,
117 v8s1, v8s8, v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
118
119 auto allIntScalarsAndVectors = {s8, s16, s32, s64, v2s8, v2s16,
120 v2s32, v2s64, v3s8, v3s16, v3s32, v3s64,
121 v4s8, v4s16, v4s32, v4s64, v8s8, v8s16,
122 v8s32, v8s64, v16s8, v16s16, v16s32, v16s64};
123
124 auto allBoolScalarsAndVectors = {s1, v2s1, v3s1, v4s1, v8s1, v16s1};
125
126 auto allIntScalars = {s8, s16, s32, s64};
127
128 auto allFloatScalarsAndVectors = {
129 s16, s32, s64, v2s16, v2s32, v2s64, v3s16, v3s32, v3s64,
130 v4s16, v4s32, v4s64, v8s16, v8s32, v8s64, v16s16, v16s32, v16s64};
131
132 auto allFloatAndIntScalars = allIntScalars;
133
134 auto allPtrs = {p0, p1, p2, p3, p4, p5};
135 auto allWritablePtrs = {p0, p1, p3, p4};
136
137 for (auto Opc : TypeFoldingSupportingOpcs)
138 getActionDefinitionsBuilder(Opc).custom();
139
140 getActionDefinitionsBuilder(G_GLOBAL_VALUE).alwaysLegal();
141
142 // TODO: add proper rules for vectors legalization.
143 getActionDefinitionsBuilder({G_BUILD_VECTOR, G_SHUFFLE_VECTOR}).alwaysLegal();
144
145 getActionDefinitionsBuilder({G_MEMCPY, G_MEMMOVE})
146 .legalIf(all(typeInSet(0, allWritablePtrs), typeInSet(1, allPtrs)));
147
148 getActionDefinitionsBuilder(G_ADDRSPACE_CAST)
149 .legalForCartesianProduct(allPtrs, allPtrs);
150
151 getActionDefinitionsBuilder({G_LOAD, G_STORE}).legalIf(typeInSet(1, allPtrs));
152
153 getActionDefinitionsBuilder(G_BITREVERSE).legalFor(allFloatScalarsAndVectors);
154
155 getActionDefinitionsBuilder(G_FMA).legalFor(allFloatScalarsAndVectors);
156
157 getActionDefinitionsBuilder({G_FPTOSI, G_FPTOUI})
158 .legalForCartesianProduct(allIntScalarsAndVectors,
159 allFloatScalarsAndVectors);
160
161 getActionDefinitionsBuilder({G_SITOFP, G_UITOFP})
162 .legalForCartesianProduct(allFloatScalarsAndVectors,
163 allScalarsAndVectors);
164
165 getActionDefinitionsBuilder({G_SMIN, G_SMAX, G_UMIN, G_UMAX, G_ABS})
166 .legalFor(allIntScalarsAndVectors);
167
168 getActionDefinitionsBuilder(G_CTPOP).legalForCartesianProduct(
169 allIntScalarsAndVectors, allIntScalarsAndVectors);
170
171 getActionDefinitionsBuilder(G_PHI).legalFor(allPtrsScalarsAndVectors);
172
173 getActionDefinitionsBuilder(G_BITCAST).legalIf(all(
174 typeInSet(0, allPtrsScalarsAndVectors),
175 typeInSet(1, allPtrsScalarsAndVectors),
176 LegalityPredicate(([=](const LegalityQuery &Query) {
177 return Query.Types[0].getSizeInBits() == Query.Types[1].getSizeInBits();
178 }))));
179
180 getActionDefinitionsBuilder(G_IMPLICIT_DEF).alwaysLegal();
181
182 getActionDefinitionsBuilder(G_INTTOPTR)
183 .legalForCartesianProduct(allPtrs, allIntScalars);
184 getActionDefinitionsBuilder(G_PTRTOINT)
185 .legalForCartesianProduct(allIntScalars, allPtrs);
186 getActionDefinitionsBuilder(G_PTR_ADD).legalForCartesianProduct(
187 allPtrs, allIntScalars);
188
189 // ST.canDirectlyComparePointers() for pointer args is supported in
190 // legalizeCustom().
191 getActionDefinitionsBuilder(G_ICMP).customIf(
192 all(typeInSet(0, allBoolScalarsAndVectors),
193 typeInSet(1, allPtrsScalarsAndVectors)));
194
195 getActionDefinitionsBuilder(G_FCMP).legalIf(
196 all(typeInSet(0, allBoolScalarsAndVectors),
197 typeInSet(1, allFloatScalarsAndVectors)));
198
199 getActionDefinitionsBuilder({G_ATOMICRMW_OR, G_ATOMICRMW_ADD, G_ATOMICRMW_AND,
200 G_ATOMICRMW_MAX, G_ATOMICRMW_MIN,
201 G_ATOMICRMW_SUB, G_ATOMICRMW_XOR,
202 G_ATOMICRMW_UMAX, G_ATOMICRMW_UMIN})
203 .legalForCartesianProduct(allIntScalars, allWritablePtrs);
204
205 getActionDefinitionsBuilder(G_ATOMICRMW_XCHG)
206 .legalForCartesianProduct(allFloatAndIntScalars, allWritablePtrs);
207
208 getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG_WITH_SUCCESS).lower();
209 // TODO: add proper legalization rules.
210 getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG).alwaysLegal();
211
212 getActionDefinitionsBuilder({G_UADDO, G_USUBO, G_SMULO, G_UMULO})
213 .alwaysLegal();
214
215 // Extensions.
216 getActionDefinitionsBuilder({G_TRUNC, G_ZEXT, G_SEXT, G_ANYEXT})
217 .legalForCartesianProduct(allScalarsAndVectors);
218
219 // FP conversions.
220 getActionDefinitionsBuilder({G_FPTRUNC, G_FPEXT})
221 .legalForCartesianProduct(allFloatScalarsAndVectors);
222
223 // Pointer-handling.
224 getActionDefinitionsBuilder(G_FRAME_INDEX).legalFor({p0});
225
226 // Control-flow.
227 getActionDefinitionsBuilder(G_BRCOND).legalFor({s1});
228
229 getActionDefinitionsBuilder({G_FPOW,
230 G_FEXP,
231 G_FEXP2,
232 G_FLOG,
233 G_FLOG2,
234 G_FABS,
235 G_FMINNUM,
236 G_FMAXNUM,
237 G_FCEIL,
238 G_FCOS,
239 G_FSIN,
240 G_FSQRT,
241 G_FFLOOR,
242 G_FRINT,
243 G_FNEARBYINT,
244 G_INTRINSIC_ROUND,
245 G_INTRINSIC_TRUNC,
246 G_FMINIMUM,
247 G_FMAXIMUM,
248 G_INTRINSIC_ROUNDEVEN})
249 .legalFor(allFloatScalarsAndVectors);
250
251 getActionDefinitionsBuilder(G_FCOPYSIGN)
252 .legalForCartesianProduct(allFloatScalarsAndVectors,
253 allFloatScalarsAndVectors);
254
255 getActionDefinitionsBuilder(G_FPOWI).legalForCartesianProduct(
256 allFloatScalarsAndVectors, allIntScalarsAndVectors);
257
258 getLegacyLegalizerInfo().computeTables();
259 verify(*ST.getInstrInfo());
260 }
261
convertPtrToInt(Register Reg,LLT ConvTy,SPIRVType * SpirvType,LegalizerHelper & Helper,MachineRegisterInfo & MRI,SPIRVGlobalRegistry * GR)262 static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpirvType,
263 LegalizerHelper &Helper,
264 MachineRegisterInfo &MRI,
265 SPIRVGlobalRegistry *GR) {
266 Register ConvReg = MRI.createGenericVirtualRegister(ConvTy);
267 GR->assignSPIRVTypeToVReg(SpirvType, ConvReg, Helper.MIRBuilder.getMF());
268 Helper.MIRBuilder.buildInstr(TargetOpcode::G_PTRTOINT)
269 .addDef(ConvReg)
270 .addUse(Reg);
271 return ConvReg;
272 }
273
legalizeCustom(LegalizerHelper & Helper,MachineInstr & MI) const274 bool SPIRVLegalizerInfo::legalizeCustom(LegalizerHelper &Helper,
275 MachineInstr &MI) const {
276 auto Opc = MI.getOpcode();
277 MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
278 if (!isTypeFoldingSupported(Opc)) {
279 assert(Opc == TargetOpcode::G_ICMP);
280 assert(GR->getSPIRVTypeForVReg(MI.getOperand(0).getReg()));
281 auto &Op0 = MI.getOperand(2);
282 auto &Op1 = MI.getOperand(3);
283 Register Reg0 = Op0.getReg();
284 Register Reg1 = Op1.getReg();
285 CmpInst::Predicate Cond =
286 static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate());
287 if ((!ST->canDirectlyComparePointers() ||
288 (Cond != CmpInst::ICMP_EQ && Cond != CmpInst::ICMP_NE)) &&
289 MRI.getType(Reg0).isPointer() && MRI.getType(Reg1).isPointer()) {
290 LLT ConvT = LLT::scalar(ST->getPointerSize());
291 Type *LLVMTy = IntegerType::get(MI.getMF()->getFunction().getContext(),
292 ST->getPointerSize());
293 SPIRVType *SpirvTy = GR->getOrCreateSPIRVType(LLVMTy, Helper.MIRBuilder);
294 Op0.setReg(convertPtrToInt(Reg0, ConvT, SpirvTy, Helper, MRI, GR));
295 Op1.setReg(convertPtrToInt(Reg1, ConvT, SpirvTy, Helper, MRI, GR));
296 }
297 return true;
298 }
299 // TODO: implement legalization for other opcodes.
300 return true;
301 }
302