1 //===- AArch64MacroFusion.cpp - AArch64 Macro Fusion ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file This file contains the AArch64 implementation of the DAG scheduling
10 /// mutation to pair instructions back to back.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "AArch64MacroFusion.h"
15 #include "AArch64Subtarget.h"
16 #include "llvm/CodeGen/MacroFusion.h"
17 #include "llvm/CodeGen/TargetInstrInfo.h"
18
19 using namespace llvm;
20
21 /// CMN, CMP, TST followed by Bcc
isArithmeticBccPair(const MachineInstr * FirstMI,const MachineInstr & SecondMI,bool CmpOnly)22 static bool isArithmeticBccPair(const MachineInstr *FirstMI,
23 const MachineInstr &SecondMI, bool CmpOnly) {
24 if (SecondMI.getOpcode() != AArch64::Bcc)
25 return false;
26
27 // Assume the 1st instr to be a wildcard if it is unspecified.
28 if (FirstMI == nullptr)
29 return true;
30
31 // If we're in CmpOnly mode, we only fuse arithmetic instructions that
32 // discard their result.
33 if (CmpOnly && !(FirstMI->getOperand(0).getReg() == AArch64::XZR ||
34 FirstMI->getOperand(0).getReg() == AArch64::WZR)) {
35 return false;
36 }
37
38 switch (FirstMI->getOpcode()) {
39 case AArch64::ADDSWri:
40 case AArch64::ADDSWrr:
41 case AArch64::ADDSXri:
42 case AArch64::ADDSXrr:
43 case AArch64::ANDSWri:
44 case AArch64::ANDSWrr:
45 case AArch64::ANDSXri:
46 case AArch64::ANDSXrr:
47 case AArch64::SUBSWri:
48 case AArch64::SUBSWrr:
49 case AArch64::SUBSXri:
50 case AArch64::SUBSXrr:
51 case AArch64::BICSWrr:
52 case AArch64::BICSXrr:
53 return true;
54 case AArch64::ADDSWrs:
55 case AArch64::ADDSXrs:
56 case AArch64::ANDSWrs:
57 case AArch64::ANDSXrs:
58 case AArch64::SUBSWrs:
59 case AArch64::SUBSXrs:
60 case AArch64::BICSWrs:
61 case AArch64::BICSXrs:
62 // Shift value can be 0 making these behave like the "rr" variant...
63 return !AArch64InstrInfo::hasShiftedReg(*FirstMI);
64 }
65
66 return false;
67 }
68
69 /// ALU operations followed by CBZ/CBNZ.
isArithmeticCbzPair(const MachineInstr * FirstMI,const MachineInstr & SecondMI)70 static bool isArithmeticCbzPair(const MachineInstr *FirstMI,
71 const MachineInstr &SecondMI) {
72 if (SecondMI.getOpcode() != AArch64::CBZW &&
73 SecondMI.getOpcode() != AArch64::CBZX &&
74 SecondMI.getOpcode() != AArch64::CBNZW &&
75 SecondMI.getOpcode() != AArch64::CBNZX)
76 return false;
77
78 // Assume the 1st instr to be a wildcard if it is unspecified.
79 if (FirstMI == nullptr)
80 return true;
81
82 switch (FirstMI->getOpcode()) {
83 case AArch64::ADDWri:
84 case AArch64::ADDWrr:
85 case AArch64::ADDXri:
86 case AArch64::ADDXrr:
87 case AArch64::ANDWri:
88 case AArch64::ANDWrr:
89 case AArch64::ANDXri:
90 case AArch64::ANDXrr:
91 case AArch64::EORWri:
92 case AArch64::EORWrr:
93 case AArch64::EORXri:
94 case AArch64::EORXrr:
95 case AArch64::ORRWri:
96 case AArch64::ORRWrr:
97 case AArch64::ORRXri:
98 case AArch64::ORRXrr:
99 case AArch64::SUBWri:
100 case AArch64::SUBWrr:
101 case AArch64::SUBXri:
102 case AArch64::SUBXrr:
103 return true;
104 case AArch64::ADDWrs:
105 case AArch64::ADDXrs:
106 case AArch64::ANDWrs:
107 case AArch64::ANDXrs:
108 case AArch64::SUBWrs:
109 case AArch64::SUBXrs:
110 case AArch64::BICWrs:
111 case AArch64::BICXrs:
112 // Shift value can be 0 making these behave like the "rr" variant...
113 return !AArch64InstrInfo::hasShiftedReg(*FirstMI);
114 }
115
116 return false;
117 }
118
119 /// AES crypto encoding or decoding.
isAESPair(const MachineInstr * FirstMI,const MachineInstr & SecondMI)120 static bool isAESPair(const MachineInstr *FirstMI,
121 const MachineInstr &SecondMI) {
122 // Assume the 1st instr to be a wildcard if it is unspecified.
123 switch (SecondMI.getOpcode()) {
124 // AES encode.
125 case AArch64::AESMCrr:
126 case AArch64::AESMCrrTied:
127 return FirstMI == nullptr || FirstMI->getOpcode() == AArch64::AESErr;
128 // AES decode.
129 case AArch64::AESIMCrr:
130 case AArch64::AESIMCrrTied:
131 return FirstMI == nullptr || FirstMI->getOpcode() == AArch64::AESDrr;
132 }
133
134 return false;
135 }
136
137 /// AESE/AESD/PMULL + EOR.
isCryptoEORPair(const MachineInstr * FirstMI,const MachineInstr & SecondMI)138 static bool isCryptoEORPair(const MachineInstr *FirstMI,
139 const MachineInstr &SecondMI) {
140 if (SecondMI.getOpcode() != AArch64::EORv16i8)
141 return false;
142
143 // Assume the 1st instr to be a wildcard if it is unspecified.
144 if (FirstMI == nullptr)
145 return true;
146
147 switch (FirstMI->getOpcode()) {
148 case AArch64::AESErr:
149 case AArch64::AESDrr:
150 case AArch64::PMULLv16i8:
151 case AArch64::PMULLv8i8:
152 case AArch64::PMULLv1i64:
153 case AArch64::PMULLv2i64:
154 return true;
155 }
156
157 return false;
158 }
159
isAdrpAddPair(const MachineInstr * FirstMI,const MachineInstr & SecondMI)160 static bool isAdrpAddPair(const MachineInstr *FirstMI,
161 const MachineInstr &SecondMI) {
162 // Assume the 1st instr to be a wildcard if it is unspecified.
163 if ((FirstMI == nullptr || FirstMI->getOpcode() == AArch64::ADRP) &&
164 SecondMI.getOpcode() == AArch64::ADDXri)
165 return true;
166 return false;
167 }
168
169 /// Literal generation.
isLiteralsPair(const MachineInstr * FirstMI,const MachineInstr & SecondMI)170 static bool isLiteralsPair(const MachineInstr *FirstMI,
171 const MachineInstr &SecondMI) {
172 // Assume the 1st instr to be a wildcard if it is unspecified.
173 // 32 bit immediate.
174 if ((FirstMI == nullptr || FirstMI->getOpcode() == AArch64::MOVZWi) &&
175 (SecondMI.getOpcode() == AArch64::MOVKWi &&
176 SecondMI.getOperand(3).getImm() == 16))
177 return true;
178
179 // Lower half of 64 bit immediate.
180 if((FirstMI == nullptr || FirstMI->getOpcode() == AArch64::MOVZXi) &&
181 (SecondMI.getOpcode() == AArch64::MOVKXi &&
182 SecondMI.getOperand(3).getImm() == 16))
183 return true;
184
185 // Upper half of 64 bit immediate.
186 if ((FirstMI == nullptr ||
187 (FirstMI->getOpcode() == AArch64::MOVKXi &&
188 FirstMI->getOperand(3).getImm() == 32)) &&
189 (SecondMI.getOpcode() == AArch64::MOVKXi &&
190 SecondMI.getOperand(3).getImm() == 48))
191 return true;
192
193 return false;
194 }
195
196 /// Fuse address generation and loads or stores.
isAddressLdStPair(const MachineInstr * FirstMI,const MachineInstr & SecondMI)197 static bool isAddressLdStPair(const MachineInstr *FirstMI,
198 const MachineInstr &SecondMI) {
199 switch (SecondMI.getOpcode()) {
200 case AArch64::STRBBui:
201 case AArch64::STRBui:
202 case AArch64::STRDui:
203 case AArch64::STRHHui:
204 case AArch64::STRHui:
205 case AArch64::STRQui:
206 case AArch64::STRSui:
207 case AArch64::STRWui:
208 case AArch64::STRXui:
209 case AArch64::LDRBBui:
210 case AArch64::LDRBui:
211 case AArch64::LDRDui:
212 case AArch64::LDRHHui:
213 case AArch64::LDRHui:
214 case AArch64::LDRQui:
215 case AArch64::LDRSui:
216 case AArch64::LDRWui:
217 case AArch64::LDRXui:
218 case AArch64::LDRSBWui:
219 case AArch64::LDRSBXui:
220 case AArch64::LDRSHWui:
221 case AArch64::LDRSHXui:
222 case AArch64::LDRSWui:
223 // Assume the 1st instr to be a wildcard if it is unspecified.
224 if (FirstMI == nullptr)
225 return true;
226
227 switch (FirstMI->getOpcode()) {
228 case AArch64::ADR:
229 return SecondMI.getOperand(2).getImm() == 0;
230 case AArch64::ADRP:
231 return true;
232 }
233 }
234
235 return false;
236 }
237
238 /// Compare and conditional select.
isCCSelectPair(const MachineInstr * FirstMI,const MachineInstr & SecondMI)239 static bool isCCSelectPair(const MachineInstr *FirstMI,
240 const MachineInstr &SecondMI) {
241 // 32 bits
242 if (SecondMI.getOpcode() == AArch64::CSELWr) {
243 // Assume the 1st instr to be a wildcard if it is unspecified.
244 if (FirstMI == nullptr)
245 return true;
246
247 if (FirstMI->definesRegister(AArch64::WZR))
248 switch (FirstMI->getOpcode()) {
249 case AArch64::SUBSWrs:
250 return !AArch64InstrInfo::hasShiftedReg(*FirstMI);
251 case AArch64::SUBSWrx:
252 return !AArch64InstrInfo::hasExtendedReg(*FirstMI);
253 case AArch64::SUBSWrr:
254 case AArch64::SUBSWri:
255 return true;
256 }
257 }
258
259 // 64 bits
260 if (SecondMI.getOpcode() == AArch64::CSELXr) {
261 // Assume the 1st instr to be a wildcard if it is unspecified.
262 if (FirstMI == nullptr)
263 return true;
264
265 if (FirstMI->definesRegister(AArch64::XZR))
266 switch (FirstMI->getOpcode()) {
267 case AArch64::SUBSXrs:
268 return !AArch64InstrInfo::hasShiftedReg(*FirstMI);
269 case AArch64::SUBSXrx:
270 case AArch64::SUBSXrx64:
271 return !AArch64InstrInfo::hasExtendedReg(*FirstMI);
272 case AArch64::SUBSXrr:
273 case AArch64::SUBSXri:
274 return true;
275 }
276 }
277
278 return false;
279 }
280
281 // Arithmetic and logic.
isArithmeticLogicPair(const MachineInstr * FirstMI,const MachineInstr & SecondMI)282 static bool isArithmeticLogicPair(const MachineInstr *FirstMI,
283 const MachineInstr &SecondMI) {
284 if (AArch64InstrInfo::hasShiftedReg(SecondMI))
285 return false;
286
287 switch (SecondMI.getOpcode()) {
288 // Arithmetic
289 case AArch64::ADDWrr:
290 case AArch64::ADDXrr:
291 case AArch64::SUBWrr:
292 case AArch64::SUBXrr:
293 case AArch64::ADDWrs:
294 case AArch64::ADDXrs:
295 case AArch64::SUBWrs:
296 case AArch64::SUBXrs:
297 // Logic
298 case AArch64::ANDWrr:
299 case AArch64::ANDXrr:
300 case AArch64::BICWrr:
301 case AArch64::BICXrr:
302 case AArch64::EONWrr:
303 case AArch64::EONXrr:
304 case AArch64::EORWrr:
305 case AArch64::EORXrr:
306 case AArch64::ORNWrr:
307 case AArch64::ORNXrr:
308 case AArch64::ORRWrr:
309 case AArch64::ORRXrr:
310 case AArch64::ANDWrs:
311 case AArch64::ANDXrs:
312 case AArch64::BICWrs:
313 case AArch64::BICXrs:
314 case AArch64::EONWrs:
315 case AArch64::EONXrs:
316 case AArch64::EORWrs:
317 case AArch64::EORXrs:
318 case AArch64::ORNWrs:
319 case AArch64::ORNXrs:
320 case AArch64::ORRWrs:
321 case AArch64::ORRXrs:
322 // Assume the 1st instr to be a wildcard if it is unspecified.
323 if (FirstMI == nullptr)
324 return true;
325
326 // Arithmetic
327 switch (FirstMI->getOpcode()) {
328 case AArch64::ADDWrr:
329 case AArch64::ADDXrr:
330 case AArch64::ADDSWrr:
331 case AArch64::ADDSXrr:
332 case AArch64::SUBWrr:
333 case AArch64::SUBXrr:
334 case AArch64::SUBSWrr:
335 case AArch64::SUBSXrr:
336 return true;
337 case AArch64::ADDWrs:
338 case AArch64::ADDXrs:
339 case AArch64::ADDSWrs:
340 case AArch64::ADDSXrs:
341 case AArch64::SUBWrs:
342 case AArch64::SUBXrs:
343 case AArch64::SUBSWrs:
344 case AArch64::SUBSXrs:
345 return !AArch64InstrInfo::hasShiftedReg(*FirstMI);
346 }
347 break;
348
349 // Arithmetic, setting flags.
350 case AArch64::ADDSWrr:
351 case AArch64::ADDSXrr:
352 case AArch64::SUBSWrr:
353 case AArch64::SUBSXrr:
354 case AArch64::ADDSWrs:
355 case AArch64::ADDSXrs:
356 case AArch64::SUBSWrs:
357 case AArch64::SUBSXrs:
358 // Assume the 1st instr to be a wildcard if it is unspecified.
359 if (FirstMI == nullptr)
360 return true;
361
362 // Arithmetic, not setting flags.
363 switch (FirstMI->getOpcode()) {
364 case AArch64::ADDWrr:
365 case AArch64::ADDXrr:
366 case AArch64::SUBWrr:
367 case AArch64::SUBXrr:
368 return true;
369 case AArch64::ADDWrs:
370 case AArch64::ADDXrs:
371 case AArch64::SUBWrs:
372 case AArch64::SUBXrs:
373 return !AArch64InstrInfo::hasShiftedReg(*FirstMI);
374 }
375 break;
376 }
377
378 return false;
379 }
380
381 /// \brief Check if the instr pair, FirstMI and SecondMI, should be fused
382 /// together. Given SecondMI, when FirstMI is unspecified, then check if
383 /// SecondMI may be part of a fused pair at all.
shouldScheduleAdjacent(const TargetInstrInfo & TII,const TargetSubtargetInfo & TSI,const MachineInstr * FirstMI,const MachineInstr & SecondMI)384 static bool shouldScheduleAdjacent(const TargetInstrInfo &TII,
385 const TargetSubtargetInfo &TSI,
386 const MachineInstr *FirstMI,
387 const MachineInstr &SecondMI) {
388 const AArch64Subtarget &ST = static_cast<const AArch64Subtarget&>(TSI);
389
390 // All checking functions assume that the 1st instr is a wildcard if it is
391 // unspecified.
392 if (ST.hasCmpBccFusion() || ST.hasArithmeticBccFusion()) {
393 bool CmpOnly = !ST.hasArithmeticBccFusion();
394 if (isArithmeticBccPair(FirstMI, SecondMI, CmpOnly))
395 return true;
396 }
397 if (ST.hasArithmeticCbzFusion() && isArithmeticCbzPair(FirstMI, SecondMI))
398 return true;
399 if (ST.hasFuseAES() && isAESPair(FirstMI, SecondMI))
400 return true;
401 if (ST.hasFuseCryptoEOR() && isCryptoEORPair(FirstMI, SecondMI))
402 return true;
403 if (ST.hasFuseAdrpAdd() && isAdrpAddPair(FirstMI, SecondMI))
404 return true;
405 if (ST.hasFuseLiterals() && isLiteralsPair(FirstMI, SecondMI))
406 return true;
407 if (ST.hasFuseAddress() && isAddressLdStPair(FirstMI, SecondMI))
408 return true;
409 if (ST.hasFuseCCSelect() && isCCSelectPair(FirstMI, SecondMI))
410 return true;
411 if (ST.hasFuseArithmeticLogic() && isArithmeticLogicPair(FirstMI, SecondMI))
412 return true;
413
414 return false;
415 }
416
417 std::unique_ptr<ScheduleDAGMutation>
createAArch64MacroFusionDAGMutation()418 llvm::createAArch64MacroFusionDAGMutation() {
419 return createMacroFusionDAGMutation(shouldScheduleAdjacent);
420 }
421