1 //===- PatternMatchTest.cpp -----------------------------------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 
10 #include "llvm/CodeGen/GlobalISel/ConstantFoldingMIRBuilder.h"
11 #include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
12 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
13 #include "llvm/CodeGen/GlobalISel/Utils.h"
14 #include "llvm/CodeGen/MIRParser/MIRParser.h"
15 #include "llvm/CodeGen/MachineFunction.h"
16 #include "llvm/CodeGen/MachineModuleInfo.h"
17 #include "llvm/CodeGen/TargetFrameLowering.h"
18 #include "llvm/CodeGen/TargetInstrInfo.h"
19 #include "llvm/CodeGen/TargetLowering.h"
20 #include "llvm/CodeGen/TargetSubtargetInfo.h"
21 #include "llvm/Support/SourceMgr.h"
22 #include "llvm/Support/TargetRegistry.h"
23 #include "llvm/Support/TargetSelect.h"
24 #include "llvm/Target/TargetMachine.h"
25 #include "llvm/Target/TargetOptions.h"
26 #include "gtest/gtest.h"
27 
28 using namespace llvm;
29 using namespace MIPatternMatch;
30 
31 namespace {
32 
33 void initLLVM() {
34   InitializeAllTargets();
35   InitializeAllTargetMCs();
36   InitializeAllAsmPrinters();
37   InitializeAllAsmParsers();
38 
39   PassRegistry *Registry = PassRegistry::getPassRegistry();
40   initializeCore(*Registry);
41   initializeCodeGen(*Registry);
42 }
43 
44 /// Create a TargetMachine. As we lack a dedicated always available target for
45 /// unittests, we go for "AArch64".
46 std::unique_ptr<TargetMachine> createTargetMachine() {
47   Triple TargetTriple("aarch64--");
48   std::string Error;
49   const Target *T = TargetRegistry::lookupTarget("", TargetTriple, Error);
50   if (!T)
51     return nullptr;
52 
53   TargetOptions Options;
54   return std::unique_ptr<TargetMachine>(T->createTargetMachine(
55       "AArch64", "", "", Options, None, None, CodeGenOpt::Aggressive));
56 }
57 
58 std::unique_ptr<Module> parseMIR(LLVMContext &Context,
59                                  std::unique_ptr<MIRParser> &MIR,
60                                  const TargetMachine &TM, StringRef MIRCode,
61                                  const char *FuncName, MachineModuleInfo &MMI) {
62   SMDiagnostic Diagnostic;
63   std::unique_ptr<MemoryBuffer> MBuffer = MemoryBuffer::getMemBuffer(MIRCode);
64   MIR = createMIRParser(std::move(MBuffer), Context);
65   if (!MIR)
66     return nullptr;
67 
68   std::unique_ptr<Module> M = MIR->parseIRModule();
69   if (!M)
70     return nullptr;
71 
72   M->setDataLayout(TM.createDataLayout());
73 
74   if (MIR->parseMachineFunctions(*M, MMI))
75     return nullptr;
76 
77   return M;
78 }
79 
80 std::pair<std::unique_ptr<Module>, std::unique_ptr<MachineModuleInfo>>
81 createDummyModule(LLVMContext &Context, const TargetMachine &TM,
82                   StringRef MIRFunc) {
83   SmallString<512> S;
84   StringRef MIRString = (Twine(R"MIR(
85 ---
86 ...
87 name: func
88 registers:
89   - { id: 0, class: _ }
90   - { id: 1, class: _ }
91   - { id: 2, class: _ }
92   - { id: 3, class: _ }
93 body: |
94   bb.1:
95     %0(s64) = COPY $x0
96     %1(s64) = COPY $x1
97     %2(s64) = COPY $x2
98 )MIR") + Twine(MIRFunc) + Twine("...\n"))
99                             .toNullTerminatedStringRef(S);
100   std::unique_ptr<MIRParser> MIR;
101   auto MMI = make_unique<MachineModuleInfo>(&TM);
102   std::unique_ptr<Module> M =
103       parseMIR(Context, MIR, TM, MIRString, "func", *MMI);
104   return make_pair(std::move(M), std::move(MMI));
105 }
106 
107 static MachineFunction *getMFFromMMI(const Module *M,
108                                      const MachineModuleInfo *MMI) {
109   Function *F = M->getFunction("func");
110   auto *MF = MMI->getMachineFunction(*F);
111   return MF;
112 }
113 
114 static void collectCopies(SmallVectorImpl<unsigned> &Copies,
115                           MachineFunction *MF) {
116   for (auto &MBB : *MF)
117     for (MachineInstr &MI : MBB) {
118       if (MI.getOpcode() == TargetOpcode::COPY)
119         Copies.push_back(MI.getOperand(0).getReg());
120     }
121 }
122 
123 TEST(PatternMatchInstr, MatchIntConstant) {
124   LLVMContext Context;
125   std::unique_ptr<TargetMachine> TM = createTargetMachine();
126   if (!TM)
127     return;
128   auto ModuleMMIPair = createDummyModule(Context, *TM, "");
129   MachineFunction *MF =
130       getMFFromMMI(ModuleMMIPair.first.get(), ModuleMMIPair.second.get());
131   SmallVector<unsigned, 4> Copies;
132   collectCopies(Copies, MF);
133   MachineBasicBlock *EntryMBB = &*MF->begin();
134   MachineIRBuilder B(*MF);
135   MachineRegisterInfo &MRI = MF->getRegInfo();
136   B.setInsertPt(*EntryMBB, EntryMBB->end());
137   auto MIBCst = B.buildConstant(LLT::scalar(64), 42);
138   int64_t Cst;
139   bool match = mi_match(MIBCst->getOperand(0).getReg(), MRI, m_ICst(Cst));
140   ASSERT_TRUE(match);
141   ASSERT_EQ(Cst, 42);
142 }
143 
// Match generic binary ops (G_ADD, G_MUL, G_SUB, G_FMUL, G_AND, G_OR), both
// flat and nested, check operand commutativity for MUL/FMUL but not SUB, and
// check that ConstantFoldingMIRBuilder turns G_ADD/G_SUB of constants into a
// constant.
TEST(PatternMatchInstr, MatchBinaryOp) {
  LLVMContext Context;
  std::unique_ptr<TargetMachine> TM = createTargetMachine();
  if (!TM)
    return;
  auto ModuleMMIPair = createDummyModule(Context, *TM, "");
  MachineFunction *MF =
      getMFFromMMI(ModuleMMIPair.first.get(), ModuleMMIPair.second.get());
  SmallVector<unsigned, 4> Copies;
  collectCopies(Copies, MF);
  MachineBasicBlock *EntryMBB = &*MF->begin();
  MachineIRBuilder B(*MF);
  MachineRegisterInfo &MRI = MF->getRegInfo();
  B.setInsertPt(*EntryMBB, EntryMBB->end());
  LLT s64 = LLT::scalar(64);
  auto MIBAdd = B.buildAdd(s64, Copies[0], Copies[1]);
  // Test case for no bind.
  bool match =
      mi_match(MIBAdd->getOperand(0).getReg(), MRI, m_GAdd(m_Reg(), m_Reg()));
  ASSERT_TRUE(match);
  // Registers bound by the m_Reg(...) sub-patterns below.
  unsigned Src0, Src1, Src2;
  match = mi_match(MIBAdd->getOperand(0).getReg(), MRI,
                   m_GAdd(m_Reg(Src0), m_Reg(Src1)));
  ASSERT_TRUE(match);
  ASSERT_EQ(Src0, Copies[0]);
  ASSERT_EQ(Src1, Copies[1]);

  // Build MUL(ADD %0, %1), %2
  auto MIBMul = B.buildMul(s64, MIBAdd, Copies[2]);

  // Try to match MUL.
  match = mi_match(MIBMul->getOperand(0).getReg(), MRI,
                   m_GMul(m_Reg(Src0), m_Reg(Src1)));
  ASSERT_TRUE(match);
  ASSERT_EQ(Src0, MIBAdd->getOperand(0).getReg());
  ASSERT_EQ(Src1, Copies[2]);

  // Try to match MUL(ADD)
  match = mi_match(MIBMul->getOperand(0).getReg(), MRI,
                   m_GMul(m_GAdd(m_Reg(Src0), m_Reg(Src1)), m_Reg(Src2)));
  ASSERT_TRUE(match);
  ASSERT_EQ(Src0, Copies[0]);
  ASSERT_EQ(Src1, Copies[1]);
  ASSERT_EQ(Src2, Copies[2]);

  // Test Commutativity.
  auto MIBMul2 = B.buildMul(s64, Copies[0], B.buildConstant(s64, 42));
  // Try to match MUL(Cst, Reg) on src of MUL(Reg, Cst) to validate
  // commutativity.
  int64_t Cst;
  match = mi_match(MIBMul2->getOperand(0).getReg(), MRI,
                   m_GMul(m_ICst(Cst), m_Reg(Src0)));
  ASSERT_TRUE(match);
  ASSERT_EQ(Cst, 42);
  ASSERT_EQ(Src0, Copies[0]);

  // Make sure commutative doesn't work with something like SUB.
  auto MIBSub = B.buildSub(s64, Copies[0], B.buildConstant(s64, 42));
  match = mi_match(MIBSub->getOperand(0).getReg(), MRI,
                   m_GSub(m_ICst(Cst), m_Reg(Src0)));
  ASSERT_FALSE(match);

  // The second FMUL operand is an integer G_CONSTANT (from buildConstant),
  // which is what the m_ICst sub-pattern below matches.
  auto MIBFMul = B.buildInstr(TargetOpcode::G_FMUL, s64, Copies[0],
                              B.buildConstant(s64, 42));
  // Match and test commutativity for FMUL.
  match = mi_match(MIBFMul->getOperand(0).getReg(), MRI,
                   m_GFMul(m_ICst(Cst), m_Reg(Src0)));
  ASSERT_TRUE(match);
  ASSERT_EQ(Cst, 42);
  ASSERT_EQ(Src0, Copies[0]);

  // Build AND %0, %1
  auto MIBAnd = B.buildAnd(s64, Copies[0], Copies[1]);
  // Try to match AND.
  match = mi_match(MIBAnd->getOperand(0).getReg(), MRI,
                   m_GAnd(m_Reg(Src0), m_Reg(Src1)));
  ASSERT_TRUE(match);
  ASSERT_EQ(Src0, Copies[0]);
  ASSERT_EQ(Src1, Copies[1]);

  // Build OR %0, %1
  auto MIBOr = B.buildOr(s64, Copies[0], Copies[1]);
  // Try to match OR.
  match = mi_match(MIBOr->getOperand(0).getReg(), MRI,
                   m_GOr(m_Reg(Src0), m_Reg(Src1)));
  ASSERT_TRUE(match);
  ASSERT_EQ(Src0, Copies[0]);
  ASSERT_EQ(Src1, Copies[1]);

  // Use a ConstantFoldingMIRBuilder (sharing B's state) to build binary ops
  // on constants; the result is expected to be folded to a constant.
  ConstantFoldingMIRBuilder CFB(B.getState());
  LLT s32 = LLT::scalar(32);
  auto MIBCAdd =
      CFB.buildAdd(s32, CFB.buildConstant(s32, 0), CFB.buildConstant(s32, 1));
  // This should be a constant now.
  match = mi_match(MIBCAdd->getOperand(0).getReg(), MRI, m_ICst(Cst));
  ASSERT_TRUE(match);
  ASSERT_EQ(Cst, 1);
  // Same fold, but through the generic buildInstr entry point.
  auto MIBCAdd1 =
      CFB.buildInstr(TargetOpcode::G_ADD, s32, CFB.buildConstant(s32, 0),
                     CFB.buildConstant(s32, 1));
  // This should be a constant now.
  match = mi_match(MIBCAdd1->getOperand(0).getReg(), MRI, m_ICst(Cst));
  ASSERT_TRUE(match);
  ASSERT_EQ(Cst, 1);

  // Construct a ConstantFoldingMIRBuilder from the MachineFunction (rather
  // than from an existing builder state) to make sure that constructor works
  // too.
  ConstantFoldingMIRBuilder CFB1(*MF);
  CFB1.setInsertPt(*EntryMBB, EntryMBB->end());
  auto MIBCSub =
      CFB1.buildInstr(TargetOpcode::G_SUB, s32, CFB1.buildConstant(s32, 1),
                      CFB1.buildConstant(s32, 1));
  // This should be a constant now.
  match = mi_match(MIBCSub->getOperand(0).getReg(), MRI, m_ICst(Cst));
  ASSERT_TRUE(match);
  ASSERT_EQ(Cst, 0);
}
262 
263 TEST(PatternMatchInstr, MatchFPUnaryOp) {
264   LLVMContext Context;
265   std::unique_ptr<TargetMachine> TM = createTargetMachine();
266   if (!TM)
267     return;
268   auto ModuleMMIPair = createDummyModule(Context, *TM, "");
269   MachineFunction *MF =
270       getMFFromMMI(ModuleMMIPair.first.get(), ModuleMMIPair.second.get());
271   SmallVector<unsigned, 4> Copies;
272   collectCopies(Copies, MF);
273   MachineBasicBlock *EntryMBB = &*MF->begin();
274   MachineIRBuilder B(*MF);
275   MachineRegisterInfo &MRI = MF->getRegInfo();
276   B.setInsertPt(*EntryMBB, EntryMBB->end());
277 
278   // Truncate s64 to s32.
279   LLT s32 = LLT::scalar(32);
280   auto Copy0s32 = B.buildFPTrunc(s32, Copies[0]);
281 
282   // Match G_FABS.
283   auto MIBFabs = B.buildInstr(TargetOpcode::G_FABS, s32, Copy0s32);
284   bool match = mi_match(MIBFabs->getOperand(0).getReg(), MRI, m_GFabs(m_Reg()));
285   ASSERT_TRUE(match);
286   unsigned Src;
287   match = mi_match(MIBFabs->getOperand(0).getReg(), MRI, m_GFabs(m_Reg(Src)));
288   ASSERT_TRUE(match);
289   ASSERT_EQ(Src, Copy0s32->getOperand(0).getReg());
290 
291   // Build and match FConstant.
292   auto MIBFCst = B.buildFConstant(s32, .5);
293   const ConstantFP *TmpFP{};
294   match = mi_match(MIBFCst->getOperand(0).getReg(), MRI, m_GFCst(TmpFP));
295   ASSERT_TRUE(match);
296   ASSERT_TRUE(TmpFP);
297   APFloat APF((float).5);
298   auto *CFP = ConstantFP::get(Context, APF);
299   ASSERT_EQ(CFP, TmpFP);
300 
301   // Build double float.
302   LLT s64 = LLT::scalar(64);
303   auto MIBFCst64 = B.buildFConstant(s64, .5);
304   const ConstantFP *TmpFP64{};
305   match = mi_match(MIBFCst64->getOperand(0).getReg(), MRI, m_GFCst(TmpFP64));
306   ASSERT_TRUE(match);
307   ASSERT_TRUE(TmpFP64);
308   APFloat APF64(.5);
309   auto CFP64 = ConstantFP::get(Context, APF64);
310   ASSERT_EQ(CFP64, TmpFP64);
311   ASSERT_NE(TmpFP64, TmpFP);
312 
313   // Build half float.
314   LLT s16 = LLT::scalar(16);
315   auto MIBFCst16 = B.buildFConstant(s16, .5);
316   const ConstantFP *TmpFP16{};
317   match = mi_match(MIBFCst16->getOperand(0).getReg(), MRI, m_GFCst(TmpFP16));
318   ASSERT_TRUE(match);
319   ASSERT_TRUE(TmpFP16);
320   bool Ignored;
321   APFloat APF16(.5);
322   APF16.convert(APFloat::IEEEhalf(), APFloat::rmNearestTiesToEven, &Ignored);
323   auto CFP16 = ConstantFP::get(Context, APF16);
324   ASSERT_EQ(TmpFP16, CFP16);
325   ASSERT_NE(TmpFP16, TmpFP);
326 }
327 
328 TEST(PatternMatchInstr, MatchExtendsTrunc) {
329   LLVMContext Context;
330   std::unique_ptr<TargetMachine> TM = createTargetMachine();
331   if (!TM)
332     return;
333   auto ModuleMMIPair = createDummyModule(Context, *TM, "");
334   MachineFunction *MF =
335       getMFFromMMI(ModuleMMIPair.first.get(), ModuleMMIPair.second.get());
336   SmallVector<unsigned, 4> Copies;
337   collectCopies(Copies, MF);
338   MachineBasicBlock *EntryMBB = &*MF->begin();
339   MachineIRBuilder B(*MF);
340   MachineRegisterInfo &MRI = MF->getRegInfo();
341   B.setInsertPt(*EntryMBB, EntryMBB->end());
342   LLT s64 = LLT::scalar(64);
343   LLT s32 = LLT::scalar(32);
344 
345   auto MIBTrunc = B.buildTrunc(s32, Copies[0]);
346   auto MIBAExt = B.buildAnyExt(s64, MIBTrunc);
347   auto MIBZExt = B.buildZExt(s64, MIBTrunc);
348   auto MIBSExt = B.buildSExt(s64, MIBTrunc);
349   unsigned Src0;
350   bool match =
351       mi_match(MIBTrunc->getOperand(0).getReg(), MRI, m_GTrunc(m_Reg(Src0)));
352   ASSERT_TRUE(match);
353   ASSERT_EQ(Src0, Copies[0]);
354   match =
355       mi_match(MIBAExt->getOperand(0).getReg(), MRI, m_GAnyExt(m_Reg(Src0)));
356   ASSERT_TRUE(match);
357   ASSERT_EQ(Src0, MIBTrunc->getOperand(0).getReg());
358 
359   match = mi_match(MIBSExt->getOperand(0).getReg(), MRI, m_GSExt(m_Reg(Src0)));
360   ASSERT_TRUE(match);
361   ASSERT_EQ(Src0, MIBTrunc->getOperand(0).getReg());
362 
363   match = mi_match(MIBZExt->getOperand(0).getReg(), MRI, m_GZExt(m_Reg(Src0)));
364   ASSERT_TRUE(match);
365   ASSERT_EQ(Src0, MIBTrunc->getOperand(0).getReg());
366 
367   // Match ext(trunc src)
368   match = mi_match(MIBAExt->getOperand(0).getReg(), MRI,
369                    m_GAnyExt(m_GTrunc(m_Reg(Src0))));
370   ASSERT_TRUE(match);
371   ASSERT_EQ(Src0, Copies[0]);
372 
373   match = mi_match(MIBSExt->getOperand(0).getReg(), MRI,
374                    m_GSExt(m_GTrunc(m_Reg(Src0))));
375   ASSERT_TRUE(match);
376   ASSERT_EQ(Src0, Copies[0]);
377 
378   match = mi_match(MIBZExt->getOperand(0).getReg(), MRI,
379                    m_GZExt(m_GTrunc(m_Reg(Src0))));
380   ASSERT_TRUE(match);
381   ASSERT_EQ(Src0, Copies[0]);
382 }
383 
384 TEST(PatternMatchInstr, MatchSpecificType) {
385   LLVMContext Context;
386   std::unique_ptr<TargetMachine> TM = createTargetMachine();
387   if (!TM)
388     return;
389   auto ModuleMMIPair = createDummyModule(Context, *TM, "");
390   MachineFunction *MF =
391       getMFFromMMI(ModuleMMIPair.first.get(), ModuleMMIPair.second.get());
392   SmallVector<unsigned, 4> Copies;
393   collectCopies(Copies, MF);
394   MachineBasicBlock *EntryMBB = &*MF->begin();
395   MachineIRBuilder B(*MF);
396   MachineRegisterInfo &MRI = MF->getRegInfo();
397   B.setInsertPt(*EntryMBB, EntryMBB->end());
398 
399   // Try to match a 64bit add.
400   LLT s64 = LLT::scalar(64);
401   LLT s32 = LLT::scalar(32);
402   auto MIBAdd = B.buildAdd(s64, Copies[0], Copies[1]);
403   ASSERT_FALSE(mi_match(MIBAdd->getOperand(0).getReg(), MRI,
404                         m_GAdd(m_SpecificType(s32), m_Reg())));
405   ASSERT_TRUE(mi_match(MIBAdd->getOperand(0).getReg(), MRI,
406                        m_GAdd(m_SpecificType(s64), m_Reg())));
407 
408   // Try to match the destination type of a bitcast.
409   LLT v2s32 = LLT::vector(2, 32);
410   auto MIBCast = B.buildCast(v2s32, Copies[0]);
411   ASSERT_TRUE(
412       mi_match(MIBCast->getOperand(0).getReg(), MRI, m_GBitcast(m_Reg())));
413   ASSERT_TRUE(
414       mi_match(MIBCast->getOperand(0).getReg(), MRI, m_SpecificType(v2s32)));
415   ASSERT_TRUE(
416       mi_match(MIBCast->getOperand(1).getReg(), MRI, m_SpecificType(s64)));
417 
418   // Build a PTRToInt and INTTOPTR and match and test them.
419   LLT PtrTy = LLT::pointer(0, 64);
420   auto MIBIntToPtr = B.buildCast(PtrTy, Copies[0]);
421   auto MIBPtrToInt = B.buildCast(s64, MIBIntToPtr);
422   unsigned Src0;
423 
424   // match the ptrtoint(inttoptr reg)
425   bool match = mi_match(MIBPtrToInt->getOperand(0).getReg(), MRI,
426                         m_GPtrToInt(m_GIntToPtr(m_Reg(Src0))));
427   ASSERT_TRUE(match);
428   ASSERT_EQ(Src0, Copies[0]);
429 }
430 
431 TEST(PatternMatchInstr, MatchCombinators) {
432   LLVMContext Context;
433   std::unique_ptr<TargetMachine> TM = createTargetMachine();
434   if (!TM)
435     return;
436   auto ModuleMMIPair = createDummyModule(Context, *TM, "");
437   MachineFunction *MF =
438       getMFFromMMI(ModuleMMIPair.first.get(), ModuleMMIPair.second.get());
439   SmallVector<unsigned, 4> Copies;
440   collectCopies(Copies, MF);
441   MachineBasicBlock *EntryMBB = &*MF->begin();
442   MachineIRBuilder B(*MF);
443   MachineRegisterInfo &MRI = MF->getRegInfo();
444   B.setInsertPt(*EntryMBB, EntryMBB->end());
445   LLT s64 = LLT::scalar(64);
446   LLT s32 = LLT::scalar(32);
447   auto MIBAdd = B.buildAdd(s64, Copies[0], Copies[1]);
448   unsigned Src0, Src1;
449   bool match =
450       mi_match(MIBAdd->getOperand(0).getReg(), MRI,
451                m_all_of(m_SpecificType(s64), m_GAdd(m_Reg(Src0), m_Reg(Src1))));
452   ASSERT_TRUE(match);
453   ASSERT_EQ(Src0, Copies[0]);
454   ASSERT_EQ(Src1, Copies[1]);
455   // Check for s32 (which should fail).
456   match =
457       mi_match(MIBAdd->getOperand(0).getReg(), MRI,
458                m_all_of(m_SpecificType(s32), m_GAdd(m_Reg(Src0), m_Reg(Src1))));
459   ASSERT_FALSE(match);
460   match =
461       mi_match(MIBAdd->getOperand(0).getReg(), MRI,
462                m_any_of(m_SpecificType(s32), m_GAdd(m_Reg(Src0), m_Reg(Src1))));
463   ASSERT_TRUE(match);
464   ASSERT_EQ(Src0, Copies[0]);
465   ASSERT_EQ(Src1, Copies[1]);
466 
467   // Match a case where none of the predicates hold true.
468   match = mi_match(
469       MIBAdd->getOperand(0).getReg(), MRI,
470       m_any_of(m_SpecificType(LLT::scalar(16)), m_GSub(m_Reg(), m_Reg())));
471   ASSERT_FALSE(match);
472 }
473 } // namespace
474 
475 int main(int argc, char **argv) {
476   ::testing::InitGoogleTest(&argc, argv);
477   initLLVM();
478   return RUN_ALL_TESTS();
479 }
480