1 //===- PatternMatchTest.cpp -----------------------------------------------===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 10 #include "llvm/CodeGen/GlobalISel/MIPatternMatch.h" 11 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" 12 #include "llvm/CodeGen/GlobalISel/Utils.h" 13 #include "llvm/CodeGen/MIRParser/MIRParser.h" 14 #include "llvm/CodeGen/MachineFunction.h" 15 #include "llvm/CodeGen/MachineModuleInfo.h" 16 #include "llvm/CodeGen/TargetFrameLowering.h" 17 #include "llvm/CodeGen/TargetInstrInfo.h" 18 #include "llvm/CodeGen/TargetLowering.h" 19 #include "llvm/CodeGen/TargetSubtargetInfo.h" 20 #include "llvm/Support/SourceMgr.h" 21 #include "llvm/Support/TargetRegistry.h" 22 #include "llvm/Support/TargetSelect.h" 23 #include "llvm/Target/TargetMachine.h" 24 #include "llvm/Target/TargetOptions.h" 25 #include "gtest/gtest.h" 26 27 using namespace llvm; 28 using namespace MIPatternMatch; 29 30 namespace { 31 32 void initLLVM() { 33 InitializeAllTargets(); 34 InitializeAllTargetMCs(); 35 InitializeAllAsmPrinters(); 36 InitializeAllAsmParsers(); 37 38 PassRegistry *Registry = PassRegistry::getPassRegistry(); 39 initializeCore(*Registry); 40 initializeCodeGen(*Registry); 41 } 42 43 /// Create a TargetMachine. As we lack a dedicated always available target for 44 /// unittests, we go for "AArch64". 45 std::unique_ptr<TargetMachine> createTargetMachine() { 46 Triple TargetTriple("aarch64--"); 47 std::string Error; 48 const Target *T = TargetRegistry::lookupTarget("", TargetTriple, Error); 49 if (!T) 50 return nullptr; 51 52 TargetOptions Options; 53 return std::unique_ptr<TargetMachine>(T->createTargetMachine( 54 "AArch64", "", "", Options, None, None, CodeGenOpt::Aggressive)); 55 } 56 57 std::unique_ptr<Module> parseMIR(LLVMContext &Context, 58 std::unique_ptr<MIRParser> &MIR, 59 const TargetMachine &TM, StringRef MIRCode, 60 const char *FuncName, MachineModuleInfo &MMI) { 61 SMDiagnostic Diagnostic; 62 std::unique_ptr<MemoryBuffer> MBuffer = MemoryBuffer::getMemBuffer(MIRCode); 63 MIR = createMIRParser(std::move(MBuffer), Context); 64 if (!MIR) 65 return nullptr; 66 67 std::unique_ptr<Module> M = MIR->parseIRModule(); 68 if (!M) 69 return nullptr; 70 71 M->setDataLayout(TM.createDataLayout()); 72 73 if (MIR->parseMachineFunctions(*M, MMI)) 74 return nullptr; 75 76 return M; 77 } 78 79 std::pair<std::unique_ptr<Module>, std::unique_ptr<MachineModuleInfo>> 80 createDummyModule(LLVMContext &Context, const TargetMachine &TM, 81 StringRef MIRFunc) { 82 SmallString<512> S; 83 StringRef MIRString = (Twine(R"MIR( 84 --- 85 ... 86 name: func 87 registers: 88 - { id: 0, class: _ } 89 - { id: 1, class: _ } 90 - { id: 2, class: _ } 91 - { id: 3, class: _ } 92 body: | 93 bb.1: 94 %0(s64) = COPY %x0 95 %1(s64) = COPY %x1 96 %2(s64) = COPY %x2 97 )MIR") + Twine(MIRFunc) + Twine("...\n")) 98 .toNullTerminatedStringRef(S); 99 std::unique_ptr<MIRParser> MIR; 100 auto MMI = make_unique<MachineModuleInfo>(&TM); 101 std::unique_ptr<Module> M = 102 parseMIR(Context, MIR, TM, MIRString, "func", *MMI); 103 return make_pair(std::move(M), std::move(MMI)); 104 } 105 106 static MachineFunction *getMFFromMMI(const Module *M, 107 const MachineModuleInfo *MMI) { 108 Function *F = M->getFunction("func"); 109 auto *MF = MMI->getMachineFunction(*F); 110 return MF; 111 } 112 113 static void collectCopies(SmallVectorImpl<unsigned> &Copies, 114 MachineFunction *MF) { 115 for (auto &MBB : *MF) 116 for (MachineInstr &MI : MBB) { 117 if (MI.getOpcode() == TargetOpcode::COPY) 118 Copies.push_back(MI.getOperand(0).getReg()); 119 } 120 } 121 122 TEST(PatternMatchInstr, MatchIntConstant) { 123 LLVMContext Context; 124 std::unique_ptr<TargetMachine> TM = createTargetMachine(); 125 if (!TM) 126 return; 127 auto ModuleMMIPair = createDummyModule(Context, *TM, ""); 128 MachineFunction *MF = 129 getMFFromMMI(ModuleMMIPair.first.get(), ModuleMMIPair.second.get()); 130 SmallVector<unsigned, 4> Copies; 131 collectCopies(Copies, MF); 132 MachineBasicBlock *EntryMBB = &*MF->begin(); 133 MachineIRBuilder B(*MF); 134 MachineRegisterInfo &MRI = MF->getRegInfo(); 135 B.setInsertPt(*EntryMBB, EntryMBB->end()); 136 auto MIBCst = B.buildConstant(LLT::scalar(64), 42); 137 uint64_t Cst; 138 bool match = mi_match(MIBCst->getOperand(0).getReg(), MRI, m_ICst(Cst)); 139 ASSERT_TRUE(match); 140 ASSERT_EQ(Cst, (uint64_t)42); 141 } 142 143 TEST(PatternMatchInstr, MatchBinaryOp) { 144 LLVMContext Context; 145 std::unique_ptr<TargetMachine> TM = createTargetMachine(); 146 if (!TM) 147 return; 148 auto ModuleMMIPair = createDummyModule(Context, *TM, ""); 149 MachineFunction *MF = 150 getMFFromMMI(ModuleMMIPair.first.get(), ModuleMMIPair.second.get()); 151 SmallVector<unsigned, 4> Copies; 152 collectCopies(Copies, MF); 153 MachineBasicBlock *EntryMBB = &*MF->begin(); 154 MachineIRBuilder B(*MF); 155 MachineRegisterInfo &MRI = MF->getRegInfo(); 156 B.setInsertPt(*EntryMBB, EntryMBB->end()); 157 LLT s64 = LLT::scalar(64); 158 auto MIBAdd = B.buildAdd(s64, Copies[0], Copies[1]); 159 // Test case for no bind. 160 bool match = 161 mi_match(MIBAdd->getOperand(0).getReg(), MRI, m_GAdd(m_Reg(), m_Reg())); 162 ASSERT_TRUE(match); 163 unsigned Src0, Src1, Src2; 164 match = mi_match(MIBAdd->getOperand(0).getReg(), MRI, 165 m_GAdd(m_Reg(Src0), m_Reg(Src1))); 166 ASSERT_TRUE(match); 167 ASSERT_EQ(Src0, Copies[0]); 168 ASSERT_EQ(Src1, Copies[1]); 169 170 // Build MUL(ADD %0, %1), %2 171 auto MIBMul = B.buildMul(s64, MIBAdd, Copies[2]); 172 173 // Try to match MUL. 174 match = mi_match(MIBMul->getOperand(0).getReg(), MRI, 175 m_GMul(m_Reg(Src0), m_Reg(Src1))); 176 ASSERT_TRUE(match); 177 ASSERT_EQ(Src0, MIBAdd->getOperand(0).getReg()); 178 ASSERT_EQ(Src1, Copies[2]); 179 180 // Try to match MUL(ADD) 181 match = mi_match(MIBMul->getOperand(0).getReg(), MRI, 182 m_GMul(m_GAdd(m_Reg(Src0), m_Reg(Src1)), m_Reg(Src2))); 183 ASSERT_TRUE(match); 184 ASSERT_EQ(Src0, Copies[0]); 185 ASSERT_EQ(Src1, Copies[1]); 186 ASSERT_EQ(Src2, Copies[2]); 187 188 // Test Commutativity. 189 auto MIBMul2 = B.buildMul(s64, Copies[0], B.buildConstant(s64, 42)); 190 // Try to match MUL(Cst, Reg) on src of MUL(Reg, Cst) to validate 191 // commutativity. 192 uint64_t Cst; 193 match = mi_match(MIBMul2->getOperand(0).getReg(), MRI, 194 m_GMul(m_ICst(Cst), m_Reg(Src0))); 195 ASSERT_TRUE(match); 196 ASSERT_EQ(Cst, (uint64_t)42); 197 ASSERT_EQ(Src0, Copies[0]); 198 199 // Make sure commutative doesn't work with something like SUB. 200 auto MIBSub = B.buildSub(s64, Copies[0], B.buildConstant(s64, 42)); 201 match = mi_match(MIBSub->getOperand(0).getReg(), MRI, 202 m_GSub(m_ICst(Cst), m_Reg(Src0))); 203 ASSERT_FALSE(match); 204 } 205 206 TEST(PatternMatchInstr, MatchExtendsTrunc) { 207 LLVMContext Context; 208 std::unique_ptr<TargetMachine> TM = createTargetMachine(); 209 if (!TM) 210 return; 211 auto ModuleMMIPair = createDummyModule(Context, *TM, ""); 212 MachineFunction *MF = 213 getMFFromMMI(ModuleMMIPair.first.get(), ModuleMMIPair.second.get()); 214 SmallVector<unsigned, 4> Copies; 215 collectCopies(Copies, MF); 216 MachineBasicBlock *EntryMBB = &*MF->begin(); 217 MachineIRBuilder B(*MF); 218 MachineRegisterInfo &MRI = MF->getRegInfo(); 219 B.setInsertPt(*EntryMBB, EntryMBB->end()); 220 LLT s64 = LLT::scalar(64); 221 LLT s32 = LLT::scalar(32); 222 223 auto MIBTrunc = B.buildTrunc(s32, Copies[0]); 224 auto MIBAExt = B.buildAnyExt(s64, MIBTrunc); 225 auto MIBZExt = B.buildZExt(s64, MIBTrunc); 226 auto MIBSExt = B.buildSExt(s64, MIBTrunc); 227 unsigned Src0; 228 bool match = 229 mi_match(MIBTrunc->getOperand(0).getReg(), MRI, m_GTrunc(m_Reg(Src0))); 230 ASSERT_TRUE(match); 231 ASSERT_EQ(Src0, Copies[0]); 232 match = 233 mi_match(MIBAExt->getOperand(0).getReg(), MRI, m_GAnyExt(m_Reg(Src0))); 234 ASSERT_TRUE(match); 235 ASSERT_EQ(Src0, MIBTrunc->getOperand(0).getReg()); 236 237 match = mi_match(MIBSExt->getOperand(0).getReg(), MRI, m_GSExt(m_Reg(Src0))); 238 ASSERT_TRUE(match); 239 ASSERT_EQ(Src0, MIBTrunc->getOperand(0).getReg()); 240 241 match = mi_match(MIBZExt->getOperand(0).getReg(), MRI, m_GZExt(m_Reg(Src0))); 242 ASSERT_TRUE(match); 243 ASSERT_EQ(Src0, MIBTrunc->getOperand(0).getReg()); 244 245 // Match ext(trunc src) 246 match = mi_match(MIBAExt->getOperand(0).getReg(), MRI, 247 m_GAnyExt(m_GTrunc(m_Reg(Src0)))); 248 ASSERT_TRUE(match); 249 ASSERT_EQ(Src0, Copies[0]); 250 251 match = mi_match(MIBSExt->getOperand(0).getReg(), MRI, 252 m_GSExt(m_GTrunc(m_Reg(Src0)))); 253 ASSERT_TRUE(match); 254 ASSERT_EQ(Src0, Copies[0]); 255 256 match = mi_match(MIBZExt->getOperand(0).getReg(), MRI, 257 m_GZExt(m_GTrunc(m_Reg(Src0)))); 258 ASSERT_TRUE(match); 259 ASSERT_EQ(Src0, Copies[0]); 260 } 261 262 TEST(PatternMatchInstr, MatchSpecificType) { 263 LLVMContext Context; 264 std::unique_ptr<TargetMachine> TM = createTargetMachine(); 265 if (!TM) 266 return; 267 auto ModuleMMIPair = createDummyModule(Context, *TM, ""); 268 MachineFunction *MF = 269 getMFFromMMI(ModuleMMIPair.first.get(), ModuleMMIPair.second.get()); 270 SmallVector<unsigned, 4> Copies; 271 collectCopies(Copies, MF); 272 MachineBasicBlock *EntryMBB = &*MF->begin(); 273 MachineIRBuilder B(*MF); 274 MachineRegisterInfo &MRI = MF->getRegInfo(); 275 B.setInsertPt(*EntryMBB, EntryMBB->end()); 276 LLT s64 = LLT::scalar(64); 277 LLT s32 = LLT::scalar(32); 278 auto MIBAdd = B.buildAdd(s64, Copies[0], Copies[1]); 279 280 // Try to match a 64bit add. 281 ASSERT_FALSE(mi_match(MIBAdd->getOperand(0).getReg(), MRI, 282 m_GAdd(m_SpecificType(s32), m_Reg()))); 283 ASSERT_TRUE(mi_match(MIBAdd->getOperand(0).getReg(), MRI, 284 m_GAdd(m_SpecificType(s64), m_Reg()))); 285 } 286 287 TEST(PatternMatchInstr, MatchCombinators) { 288 LLVMContext Context; 289 std::unique_ptr<TargetMachine> TM = createTargetMachine(); 290 if (!TM) 291 return; 292 auto ModuleMMIPair = createDummyModule(Context, *TM, ""); 293 MachineFunction *MF = 294 getMFFromMMI(ModuleMMIPair.first.get(), ModuleMMIPair.second.get()); 295 SmallVector<unsigned, 4> Copies; 296 collectCopies(Copies, MF); 297 MachineBasicBlock *EntryMBB = &*MF->begin(); 298 MachineIRBuilder B(*MF); 299 MachineRegisterInfo &MRI = MF->getRegInfo(); 300 B.setInsertPt(*EntryMBB, EntryMBB->end()); 301 LLT s64 = LLT::scalar(64); 302 LLT s32 = LLT::scalar(32); 303 auto MIBAdd = B.buildAdd(s64, Copies[0], Copies[1]); 304 unsigned Src0, Src1; 305 bool match = 306 mi_match(MIBAdd->getOperand(0).getReg(), MRI, 307 m_all_of(m_SpecificType(s64), m_GAdd(m_Reg(Src0), m_Reg(Src1)))); 308 ASSERT_TRUE(match); 309 ASSERT_EQ(Src0, Copies[0]); 310 ASSERT_EQ(Src1, Copies[1]); 311 // Check for s32 (which should fail). 312 match = 313 mi_match(MIBAdd->getOperand(0).getReg(), MRI, 314 m_all_of(m_SpecificType(s32), m_GAdd(m_Reg(Src0), m_Reg(Src1)))); 315 ASSERT_FALSE(match); 316 match = 317 mi_match(MIBAdd->getOperand(0).getReg(), MRI, 318 m_any_of(m_SpecificType(s32), m_GAdd(m_Reg(Src0), m_Reg(Src1)))); 319 ASSERT_TRUE(match); 320 ASSERT_EQ(Src0, Copies[0]); 321 ASSERT_EQ(Src1, Copies[1]); 322 } 323 } // namespace 324 325 int main(int argc, char **argv) { 326 ::testing::InitGoogleTest(&argc, argv); 327 initLLVM(); 328 return RUN_ALL_TESTS(); 329 } 330