1 //===- PatternMatchTest.cpp -----------------------------------------------===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 10 #include "llvm/CodeGen/GlobalISel/MIPatternMatch.h" 11 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" 12 #include "llvm/CodeGen/GlobalISel/Utils.h" 13 #include "llvm/CodeGen/MIRParser/MIRParser.h" 14 #include "llvm/CodeGen/MachineFunction.h" 15 #include "llvm/CodeGen/MachineModuleInfo.h" 16 #include "llvm/CodeGen/TargetFrameLowering.h" 17 #include "llvm/CodeGen/TargetInstrInfo.h" 18 #include "llvm/CodeGen/TargetLowering.h" 19 #include "llvm/CodeGen/TargetSubtargetInfo.h" 20 #include "llvm/Support/SourceMgr.h" 21 #include "llvm/Support/TargetRegistry.h" 22 #include "llvm/Support/TargetSelect.h" 23 #include "llvm/Target/TargetMachine.h" 24 #include "llvm/Target/TargetOptions.h" 25 #include "gtest/gtest.h" 26 27 using namespace llvm; 28 using namespace MIPatternMatch; 29 30 namespace { 31 32 void initLLVM() { 33 InitializeAllTargets(); 34 InitializeAllTargetMCs(); 35 InitializeAllAsmPrinters(); 36 InitializeAllAsmParsers(); 37 38 PassRegistry *Registry = PassRegistry::getPassRegistry(); 39 initializeCore(*Registry); 40 initializeCodeGen(*Registry); 41 } 42 43 /// Create a TargetMachine. As we lack a dedicated always available target for 44 /// unittests, we go for "AArch64". 45 std::unique_ptr<TargetMachine> createTargetMachine() { 46 Triple TargetTriple("aarch64--"); 47 std::string Error; 48 const Target *T = TargetRegistry::lookupTarget("", TargetTriple, Error); 49 if (!T) 50 return nullptr; 51 52 TargetOptions Options; 53 return std::unique_ptr<TargetMachine>(T->createTargetMachine( 54 "AArch64", "", "", Options, None, None, CodeGenOpt::Aggressive)); 55 } 56 57 std::unique_ptr<Module> parseMIR(LLVMContext &Context, 58 std::unique_ptr<MIRParser> &MIR, 59 const TargetMachine &TM, StringRef MIRCode, 60 const char *FuncName, MachineModuleInfo &MMI) { 61 SMDiagnostic Diagnostic; 62 std::unique_ptr<MemoryBuffer> MBuffer = MemoryBuffer::getMemBuffer(MIRCode); 63 MIR = createMIRParser(std::move(MBuffer), Context); 64 if (!MIR) 65 return nullptr; 66 67 std::unique_ptr<Module> M = MIR->parseIRModule(); 68 if (!M) 69 return nullptr; 70 71 M->setDataLayout(TM.createDataLayout()); 72 73 if (MIR->parseMachineFunctions(*M, MMI)) 74 return nullptr; 75 76 return M; 77 } 78 79 std::pair<std::unique_ptr<Module>, std::unique_ptr<MachineModuleInfo>> 80 createDummyModule(LLVMContext &Context, const TargetMachine &TM, 81 StringRef MIRFunc) { 82 SmallString<512> S; 83 StringRef MIRString = (Twine(R"MIR( 84 --- 85 ... 86 name: func 87 registers: 88 - { id: 0, class: _ } 89 - { id: 1, class: _ } 90 - { id: 2, class: _ } 91 - { id: 3, class: _ } 92 body: | 93 bb.1: 94 %0(s64) = COPY $x0 95 %1(s64) = COPY $x1 96 %2(s64) = COPY $x2 97 )MIR") + Twine(MIRFunc) + Twine("...\n")) 98 .toNullTerminatedStringRef(S); 99 std::unique_ptr<MIRParser> MIR; 100 auto MMI = make_unique<MachineModuleInfo>(&TM); 101 std::unique_ptr<Module> M = 102 parseMIR(Context, MIR, TM, MIRString, "func", *MMI); 103 return make_pair(std::move(M), std::move(MMI)); 104 } 105 106 static MachineFunction *getMFFromMMI(const Module *M, 107 const MachineModuleInfo *MMI) { 108 Function *F = M->getFunction("func"); 109 auto *MF = MMI->getMachineFunction(*F); 110 return MF; 111 } 112 113 static void collectCopies(SmallVectorImpl<unsigned> &Copies, 114 MachineFunction *MF) { 115 for (auto &MBB : *MF) 116 for (MachineInstr &MI : MBB) { 117 if (MI.getOpcode() == TargetOpcode::COPY) 118 Copies.push_back(MI.getOperand(0).getReg()); 119 } 120 } 121 122 TEST(PatternMatchInstr, MatchIntConstant) { 123 LLVMContext Context; 124 std::unique_ptr<TargetMachine> TM = createTargetMachine(); 125 if (!TM) 126 return; 127 auto ModuleMMIPair = createDummyModule(Context, *TM, ""); 128 MachineFunction *MF = 129 getMFFromMMI(ModuleMMIPair.first.get(), ModuleMMIPair.second.get()); 130 SmallVector<unsigned, 4> Copies; 131 collectCopies(Copies, MF); 132 MachineBasicBlock *EntryMBB = &*MF->begin(); 133 MachineIRBuilder B(*MF); 134 MachineRegisterInfo &MRI = MF->getRegInfo(); 135 B.setInsertPt(*EntryMBB, EntryMBB->end()); 136 auto MIBCst = B.buildConstant(LLT::scalar(64), 42); 137 uint64_t Cst; 138 bool match = mi_match(MIBCst->getOperand(0).getReg(), MRI, m_ICst(Cst)); 139 ASSERT_TRUE(match); 140 ASSERT_EQ(Cst, (uint64_t)42); 141 } 142 143 TEST(PatternMatchInstr, MatchBinaryOp) { 144 LLVMContext Context; 145 std::unique_ptr<TargetMachine> TM = createTargetMachine(); 146 if (!TM) 147 return; 148 auto ModuleMMIPair = createDummyModule(Context, *TM, ""); 149 MachineFunction *MF = 150 getMFFromMMI(ModuleMMIPair.first.get(), ModuleMMIPair.second.get()); 151 SmallVector<unsigned, 4> Copies; 152 collectCopies(Copies, MF); 153 MachineBasicBlock *EntryMBB = &*MF->begin(); 154 MachineIRBuilder B(*MF); 155 MachineRegisterInfo &MRI = MF->getRegInfo(); 156 B.setInsertPt(*EntryMBB, EntryMBB->end()); 157 LLT s64 = LLT::scalar(64); 158 auto MIBAdd = B.buildAdd(s64, Copies[0], Copies[1]); 159 // Test case for no bind. 160 bool match = 161 mi_match(MIBAdd->getOperand(0).getReg(), MRI, m_GAdd(m_Reg(), m_Reg())); 162 ASSERT_TRUE(match); 163 unsigned Src0, Src1, Src2; 164 match = mi_match(MIBAdd->getOperand(0).getReg(), MRI, 165 m_GAdd(m_Reg(Src0), m_Reg(Src1))); 166 ASSERT_TRUE(match); 167 ASSERT_EQ(Src0, Copies[0]); 168 ASSERT_EQ(Src1, Copies[1]); 169 170 // Build MUL(ADD %0, %1), %2 171 auto MIBMul = B.buildMul(s64, MIBAdd, Copies[2]); 172 173 // Try to match MUL. 174 match = mi_match(MIBMul->getOperand(0).getReg(), MRI, 175 m_GMul(m_Reg(Src0), m_Reg(Src1))); 176 ASSERT_TRUE(match); 177 ASSERT_EQ(Src0, MIBAdd->getOperand(0).getReg()); 178 ASSERT_EQ(Src1, Copies[2]); 179 180 // Try to match MUL(ADD) 181 match = mi_match(MIBMul->getOperand(0).getReg(), MRI, 182 m_GMul(m_GAdd(m_Reg(Src0), m_Reg(Src1)), m_Reg(Src2))); 183 ASSERT_TRUE(match); 184 ASSERT_EQ(Src0, Copies[0]); 185 ASSERT_EQ(Src1, Copies[1]); 186 ASSERT_EQ(Src2, Copies[2]); 187 188 // Test Commutativity. 189 auto MIBMul2 = B.buildMul(s64, Copies[0], B.buildConstant(s64, 42)); 190 // Try to match MUL(Cst, Reg) on src of MUL(Reg, Cst) to validate 191 // commutativity. 192 uint64_t Cst; 193 match = mi_match(MIBMul2->getOperand(0).getReg(), MRI, 194 m_GMul(m_ICst(Cst), m_Reg(Src0))); 195 ASSERT_TRUE(match); 196 ASSERT_EQ(Cst, (uint64_t)42); 197 ASSERT_EQ(Src0, Copies[0]); 198 199 // Make sure commutative doesn't work with something like SUB. 200 auto MIBSub = B.buildSub(s64, Copies[0], B.buildConstant(s64, 42)); 201 match = mi_match(MIBSub->getOperand(0).getReg(), MRI, 202 m_GSub(m_ICst(Cst), m_Reg(Src0))); 203 ASSERT_FALSE(match); 204 205 auto MIBFMul = B.buildInstr(TargetOpcode::G_FMUL, s64, Copies[0], 206 B.buildConstant(s64, 42)); 207 // Match and test commutativity for FMUL. 208 match = mi_match(MIBFMul->getOperand(0).getReg(), MRI, 209 m_GFMul(m_ICst(Cst), m_Reg(Src0))); 210 ASSERT_TRUE(match); 211 ASSERT_EQ(Cst, (uint64_t)42); 212 ASSERT_EQ(Src0, Copies[0]); 213 214 // Build AND %0, %1 215 auto MIBAnd = B.buildAnd(s64, Copies[0], Copies[1]); 216 // Try to match AND. 217 match = mi_match(MIBAnd->getOperand(0).getReg(), MRI, 218 m_GAnd(m_Reg(Src0), m_Reg(Src1))); 219 ASSERT_TRUE(match); 220 ASSERT_EQ(Src0, Copies[0]); 221 ASSERT_EQ(Src1, Copies[1]); 222 223 // Build OR %0, %1 224 auto MIBOr = B.buildOr(s64, Copies[0], Copies[1]); 225 // Try to match OR. 226 match = mi_match(MIBOr->getOperand(0).getReg(), MRI, 227 m_GOr(m_Reg(Src0), m_Reg(Src1))); 228 ASSERT_TRUE(match); 229 ASSERT_EQ(Src0, Copies[0]); 230 ASSERT_EQ(Src1, Copies[1]); 231 } 232 233 TEST(PatternMatchInstr, MatchFPUnaryOp) { 234 LLVMContext Context; 235 std::unique_ptr<TargetMachine> TM = createTargetMachine(); 236 if (!TM) 237 return; 238 auto ModuleMMIPair = createDummyModule(Context, *TM, ""); 239 MachineFunction *MF = 240 getMFFromMMI(ModuleMMIPair.first.get(), ModuleMMIPair.second.get()); 241 SmallVector<unsigned, 4> Copies; 242 collectCopies(Copies, MF); 243 MachineBasicBlock *EntryMBB = &*MF->begin(); 244 MachineIRBuilder B(*MF); 245 MachineRegisterInfo &MRI = MF->getRegInfo(); 246 B.setInsertPt(*EntryMBB, EntryMBB->end()); 247 248 // Truncate s64 to s32. 249 LLT s32 = LLT::scalar(32); 250 auto Copy0s32 = B.buildFPTrunc(s32, Copies[0]); 251 252 // Match G_FABS. 253 auto MIBFabs = B.buildInstr(TargetOpcode::G_FABS, s32, Copy0s32); 254 bool match = mi_match(MIBFabs->getOperand(0).getReg(), MRI, m_GFabs(m_Reg())); 255 ASSERT_TRUE(match); 256 unsigned Src; 257 match = mi_match(MIBFabs->getOperand(0).getReg(), MRI, m_GFabs(m_Reg(Src))); 258 ASSERT_TRUE(match); 259 ASSERT_EQ(Src, Copy0s32->getOperand(0).getReg()); 260 } 261 262 TEST(PatternMatchInstr, MatchExtendsTrunc) { 263 LLVMContext Context; 264 std::unique_ptr<TargetMachine> TM = createTargetMachine(); 265 if (!TM) 266 return; 267 auto ModuleMMIPair = createDummyModule(Context, *TM, ""); 268 MachineFunction *MF = 269 getMFFromMMI(ModuleMMIPair.first.get(), ModuleMMIPair.second.get()); 270 SmallVector<unsigned, 4> Copies; 271 collectCopies(Copies, MF); 272 MachineBasicBlock *EntryMBB = &*MF->begin(); 273 MachineIRBuilder B(*MF); 274 MachineRegisterInfo &MRI = MF->getRegInfo(); 275 B.setInsertPt(*EntryMBB, EntryMBB->end()); 276 LLT s64 = LLT::scalar(64); 277 LLT s32 = LLT::scalar(32); 278 279 auto MIBTrunc = B.buildTrunc(s32, Copies[0]); 280 auto MIBAExt = B.buildAnyExt(s64, MIBTrunc); 281 auto MIBZExt = B.buildZExt(s64, MIBTrunc); 282 auto MIBSExt = B.buildSExt(s64, MIBTrunc); 283 unsigned Src0; 284 bool match = 285 mi_match(MIBTrunc->getOperand(0).getReg(), MRI, m_GTrunc(m_Reg(Src0))); 286 ASSERT_TRUE(match); 287 ASSERT_EQ(Src0, Copies[0]); 288 match = 289 mi_match(MIBAExt->getOperand(0).getReg(), MRI, m_GAnyExt(m_Reg(Src0))); 290 ASSERT_TRUE(match); 291 ASSERT_EQ(Src0, MIBTrunc->getOperand(0).getReg()); 292 293 match = mi_match(MIBSExt->getOperand(0).getReg(), MRI, m_GSExt(m_Reg(Src0))); 294 ASSERT_TRUE(match); 295 ASSERT_EQ(Src0, MIBTrunc->getOperand(0).getReg()); 296 297 match = mi_match(MIBZExt->getOperand(0).getReg(), MRI, m_GZExt(m_Reg(Src0))); 298 ASSERT_TRUE(match); 299 ASSERT_EQ(Src0, MIBTrunc->getOperand(0).getReg()); 300 301 // Match ext(trunc src) 302 match = mi_match(MIBAExt->getOperand(0).getReg(), MRI, 303 m_GAnyExt(m_GTrunc(m_Reg(Src0)))); 304 ASSERT_TRUE(match); 305 ASSERT_EQ(Src0, Copies[0]); 306 307 match = mi_match(MIBSExt->getOperand(0).getReg(), MRI, 308 m_GSExt(m_GTrunc(m_Reg(Src0)))); 309 ASSERT_TRUE(match); 310 ASSERT_EQ(Src0, Copies[0]); 311 312 match = mi_match(MIBZExt->getOperand(0).getReg(), MRI, 313 m_GZExt(m_GTrunc(m_Reg(Src0)))); 314 ASSERT_TRUE(match); 315 ASSERT_EQ(Src0, Copies[0]); 316 } 317 318 TEST(PatternMatchInstr, MatchSpecificType) { 319 LLVMContext Context; 320 std::unique_ptr<TargetMachine> TM = createTargetMachine(); 321 if (!TM) 322 return; 323 auto ModuleMMIPair = createDummyModule(Context, *TM, ""); 324 MachineFunction *MF = 325 getMFFromMMI(ModuleMMIPair.first.get(), ModuleMMIPair.second.get()); 326 SmallVector<unsigned, 4> Copies; 327 collectCopies(Copies, MF); 328 MachineBasicBlock *EntryMBB = &*MF->begin(); 329 MachineIRBuilder B(*MF); 330 MachineRegisterInfo &MRI = MF->getRegInfo(); 331 B.setInsertPt(*EntryMBB, EntryMBB->end()); 332 333 // Try to match a 64bit add. 334 LLT s64 = LLT::scalar(64); 335 LLT s32 = LLT::scalar(32); 336 auto MIBAdd = B.buildAdd(s64, Copies[0], Copies[1]); 337 ASSERT_FALSE(mi_match(MIBAdd->getOperand(0).getReg(), MRI, 338 m_GAdd(m_SpecificType(s32), m_Reg()))); 339 ASSERT_TRUE(mi_match(MIBAdd->getOperand(0).getReg(), MRI, 340 m_GAdd(m_SpecificType(s64), m_Reg()))); 341 342 // Try to match the destination type of a bitcast. 343 LLT v2s32 = LLT::vector(2, 32); 344 auto MIBCast = B.buildCast(v2s32, Copies[0]); 345 ASSERT_TRUE( 346 mi_match(MIBCast->getOperand(0).getReg(), MRI, m_GBitcast(m_Reg()))); 347 ASSERT_TRUE( 348 mi_match(MIBCast->getOperand(0).getReg(), MRI, m_SpecificType(v2s32))); 349 ASSERT_TRUE( 350 mi_match(MIBCast->getOperand(1).getReg(), MRI, m_SpecificType(s64))); 351 352 // Build a PTRToInt and INTTOPTR and match and test them. 353 LLT PtrTy = LLT::pointer(0, 64); 354 auto MIBIntToPtr = B.buildCast(PtrTy, Copies[0]); 355 auto MIBPtrToInt = B.buildCast(s64, MIBIntToPtr); 356 unsigned Src0; 357 358 // match the ptrtoint(inttoptr reg) 359 bool match = mi_match(MIBPtrToInt->getOperand(0).getReg(), MRI, 360 m_GPtrToInt(m_GIntToPtr(m_Reg(Src0)))); 361 ASSERT_TRUE(match); 362 ASSERT_EQ(Src0, Copies[0]); 363 } 364 365 TEST(PatternMatchInstr, MatchCombinators) { 366 LLVMContext Context; 367 std::unique_ptr<TargetMachine> TM = createTargetMachine(); 368 if (!TM) 369 return; 370 auto ModuleMMIPair = createDummyModule(Context, *TM, ""); 371 MachineFunction *MF = 372 getMFFromMMI(ModuleMMIPair.first.get(), ModuleMMIPair.second.get()); 373 SmallVector<unsigned, 4> Copies; 374 collectCopies(Copies, MF); 375 MachineBasicBlock *EntryMBB = &*MF->begin(); 376 MachineIRBuilder B(*MF); 377 MachineRegisterInfo &MRI = MF->getRegInfo(); 378 B.setInsertPt(*EntryMBB, EntryMBB->end()); 379 LLT s64 = LLT::scalar(64); 380 LLT s32 = LLT::scalar(32); 381 auto MIBAdd = B.buildAdd(s64, Copies[0], Copies[1]); 382 unsigned Src0, Src1; 383 bool match = 384 mi_match(MIBAdd->getOperand(0).getReg(), MRI, 385 m_all_of(m_SpecificType(s64), m_GAdd(m_Reg(Src0), m_Reg(Src1)))); 386 ASSERT_TRUE(match); 387 ASSERT_EQ(Src0, Copies[0]); 388 ASSERT_EQ(Src1, Copies[1]); 389 // Check for s32 (which should fail). 390 match = 391 mi_match(MIBAdd->getOperand(0).getReg(), MRI, 392 m_all_of(m_SpecificType(s32), m_GAdd(m_Reg(Src0), m_Reg(Src1)))); 393 ASSERT_FALSE(match); 394 match = 395 mi_match(MIBAdd->getOperand(0).getReg(), MRI, 396 m_any_of(m_SpecificType(s32), m_GAdd(m_Reg(Src0), m_Reg(Src1)))); 397 ASSERT_TRUE(match); 398 ASSERT_EQ(Src0, Copies[0]); 399 ASSERT_EQ(Src1, Copies[1]); 400 401 // Match a case where none of the predicates hold true. 402 match = mi_match( 403 MIBAdd->getOperand(0).getReg(), MRI, 404 m_any_of(m_SpecificType(LLT::scalar(16)), m_GSub(m_Reg(), m_Reg()))); 405 ASSERT_FALSE(match); 406 } 407 } // namespace 408 409 int main(int argc, char **argv) { 410 ::testing::InitGoogleTest(&argc, argv); 411 initLLVM(); 412 return RUN_ALL_TESTS(); 413 } 414