1 //===- PatternMatchTest.cpp -----------------------------------------------===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 10 #include "llvm/CodeGen/GlobalISel/ConstantFoldingMIRBuilder.h" 11 #include "llvm/CodeGen/GlobalISel/MIPatternMatch.h" 12 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" 13 #include "llvm/CodeGen/GlobalISel/Utils.h" 14 #include "llvm/CodeGen/MIRParser/MIRParser.h" 15 #include "llvm/CodeGen/MachineFunction.h" 16 #include "llvm/CodeGen/MachineModuleInfo.h" 17 #include "llvm/CodeGen/TargetFrameLowering.h" 18 #include "llvm/CodeGen/TargetInstrInfo.h" 19 #include "llvm/CodeGen/TargetLowering.h" 20 #include "llvm/CodeGen/TargetSubtargetInfo.h" 21 #include "llvm/Support/SourceMgr.h" 22 #include "llvm/Support/TargetRegistry.h" 23 #include "llvm/Support/TargetSelect.h" 24 #include "llvm/Target/TargetMachine.h" 25 #include "llvm/Target/TargetOptions.h" 26 #include "gtest/gtest.h" 27 28 using namespace llvm; 29 using namespace MIPatternMatch; 30 31 namespace { 32 33 void initLLVM() { 34 InitializeAllTargets(); 35 InitializeAllTargetMCs(); 36 InitializeAllAsmPrinters(); 37 InitializeAllAsmParsers(); 38 39 PassRegistry *Registry = PassRegistry::getPassRegistry(); 40 initializeCore(*Registry); 41 initializeCodeGen(*Registry); 42 } 43 44 /// Create a TargetMachine. As we lack a dedicated always available target for 45 /// unittests, we go for "AArch64". 46 std::unique_ptr<TargetMachine> createTargetMachine() { 47 Triple TargetTriple("aarch64--"); 48 std::string Error; 49 const Target *T = TargetRegistry::lookupTarget("", TargetTriple, Error); 50 if (!T) 51 return nullptr; 52 53 TargetOptions Options; 54 return std::unique_ptr<TargetMachine>(T->createTargetMachine( 55 "AArch64", "", "", Options, None, None, CodeGenOpt::Aggressive)); 56 } 57 58 std::unique_ptr<Module> parseMIR(LLVMContext &Context, 59 std::unique_ptr<MIRParser> &MIR, 60 const TargetMachine &TM, StringRef MIRCode, 61 const char *FuncName, MachineModuleInfo &MMI) { 62 SMDiagnostic Diagnostic; 63 std::unique_ptr<MemoryBuffer> MBuffer = MemoryBuffer::getMemBuffer(MIRCode); 64 MIR = createMIRParser(std::move(MBuffer), Context); 65 if (!MIR) 66 return nullptr; 67 68 std::unique_ptr<Module> M = MIR->parseIRModule(); 69 if (!M) 70 return nullptr; 71 72 M->setDataLayout(TM.createDataLayout()); 73 74 if (MIR->parseMachineFunctions(*M, MMI)) 75 return nullptr; 76 77 return M; 78 } 79 80 std::pair<std::unique_ptr<Module>, std::unique_ptr<MachineModuleInfo>> 81 createDummyModule(LLVMContext &Context, const TargetMachine &TM, 82 StringRef MIRFunc) { 83 SmallString<512> S; 84 StringRef MIRString = (Twine(R"MIR( 85 --- 86 ... 87 name: func 88 registers: 89 - { id: 0, class: _ } 90 - { id: 1, class: _ } 91 - { id: 2, class: _ } 92 - { id: 3, class: _ } 93 body: | 94 bb.1: 95 %0(s64) = COPY $x0 96 %1(s64) = COPY $x1 97 %2(s64) = COPY $x2 98 )MIR") + Twine(MIRFunc) + Twine("...\n")) 99 .toNullTerminatedStringRef(S); 100 std::unique_ptr<MIRParser> MIR; 101 auto MMI = make_unique<MachineModuleInfo>(&TM); 102 std::unique_ptr<Module> M = 103 parseMIR(Context, MIR, TM, MIRString, "func", *MMI); 104 return make_pair(std::move(M), std::move(MMI)); 105 } 106 107 static MachineFunction *getMFFromMMI(const Module *M, 108 const MachineModuleInfo *MMI) { 109 Function *F = M->getFunction("func"); 110 auto *MF = MMI->getMachineFunction(*F); 111 return MF; 112 } 113 114 static void collectCopies(SmallVectorImpl<unsigned> &Copies, 115 MachineFunction *MF) { 116 for (auto &MBB : *MF) 117 for (MachineInstr &MI : MBB) { 118 if (MI.getOpcode() == TargetOpcode::COPY) 119 Copies.push_back(MI.getOperand(0).getReg()); 120 } 121 } 122 123 TEST(PatternMatchInstr, MatchIntConstant) { 124 LLVMContext Context; 125 std::unique_ptr<TargetMachine> TM = createTargetMachine(); 126 if (!TM) 127 return; 128 auto ModuleMMIPair = createDummyModule(Context, *TM, ""); 129 MachineFunction *MF = 130 getMFFromMMI(ModuleMMIPair.first.get(), ModuleMMIPair.second.get()); 131 SmallVector<unsigned, 4> Copies; 132 collectCopies(Copies, MF); 133 MachineBasicBlock *EntryMBB = &*MF->begin(); 134 MachineIRBuilder B(*MF); 135 MachineRegisterInfo &MRI = MF->getRegInfo(); 136 B.setInsertPt(*EntryMBB, EntryMBB->end()); 137 auto MIBCst = B.buildConstant(LLT::scalar(64), 42); 138 int64_t Cst; 139 bool match = mi_match(MIBCst->getOperand(0).getReg(), MRI, m_ICst(Cst)); 140 ASSERT_TRUE(match); 141 ASSERT_EQ(Cst, 42); 142 } 143 144 TEST(PatternMatchInstr, MatchBinaryOp) { 145 LLVMContext Context; 146 std::unique_ptr<TargetMachine> TM = createTargetMachine(); 147 if (!TM) 148 return; 149 auto ModuleMMIPair = createDummyModule(Context, *TM, ""); 150 MachineFunction *MF = 151 getMFFromMMI(ModuleMMIPair.first.get(), ModuleMMIPair.second.get()); 152 SmallVector<unsigned, 4> Copies; 153 collectCopies(Copies, MF); 154 MachineBasicBlock *EntryMBB = &*MF->begin(); 155 MachineIRBuilder B(*MF); 156 MachineRegisterInfo &MRI = MF->getRegInfo(); 157 B.setInsertPt(*EntryMBB, EntryMBB->end()); 158 LLT s64 = LLT::scalar(64); 159 auto MIBAdd = B.buildAdd(s64, Copies[0], Copies[1]); 160 // Test case for no bind. 161 bool match = 162 mi_match(MIBAdd->getOperand(0).getReg(), MRI, m_GAdd(m_Reg(), m_Reg())); 163 ASSERT_TRUE(match); 164 unsigned Src0, Src1, Src2; 165 match = mi_match(MIBAdd->getOperand(0).getReg(), MRI, 166 m_GAdd(m_Reg(Src0), m_Reg(Src1))); 167 ASSERT_TRUE(match); 168 ASSERT_EQ(Src0, Copies[0]); 169 ASSERT_EQ(Src1, Copies[1]); 170 171 // Build MUL(ADD %0, %1), %2 172 auto MIBMul = B.buildMul(s64, MIBAdd, Copies[2]); 173 174 // Try to match MUL. 175 match = mi_match(MIBMul->getOperand(0).getReg(), MRI, 176 m_GMul(m_Reg(Src0), m_Reg(Src1))); 177 ASSERT_TRUE(match); 178 ASSERT_EQ(Src0, MIBAdd->getOperand(0).getReg()); 179 ASSERT_EQ(Src1, Copies[2]); 180 181 // Try to match MUL(ADD) 182 match = mi_match(MIBMul->getOperand(0).getReg(), MRI, 183 m_GMul(m_GAdd(m_Reg(Src0), m_Reg(Src1)), m_Reg(Src2))); 184 ASSERT_TRUE(match); 185 ASSERT_EQ(Src0, Copies[0]); 186 ASSERT_EQ(Src1, Copies[1]); 187 ASSERT_EQ(Src2, Copies[2]); 188 189 // Test Commutativity. 190 auto MIBMul2 = B.buildMul(s64, Copies[0], B.buildConstant(s64, 42)); 191 // Try to match MUL(Cst, Reg) on src of MUL(Reg, Cst) to validate 192 // commutativity. 193 int64_t Cst; 194 match = mi_match(MIBMul2->getOperand(0).getReg(), MRI, 195 m_GMul(m_ICst(Cst), m_Reg(Src0))); 196 ASSERT_TRUE(match); 197 ASSERT_EQ(Cst, 42); 198 ASSERT_EQ(Src0, Copies[0]); 199 200 // Make sure commutative doesn't work with something like SUB. 201 auto MIBSub = B.buildSub(s64, Copies[0], B.buildConstant(s64, 42)); 202 match = mi_match(MIBSub->getOperand(0).getReg(), MRI, 203 m_GSub(m_ICst(Cst), m_Reg(Src0))); 204 ASSERT_FALSE(match); 205 206 auto MIBFMul = B.buildInstr(TargetOpcode::G_FMUL, s64, Copies[0], 207 B.buildConstant(s64, 42)); 208 // Match and test commutativity for FMUL. 209 match = mi_match(MIBFMul->getOperand(0).getReg(), MRI, 210 m_GFMul(m_ICst(Cst), m_Reg(Src0))); 211 ASSERT_TRUE(match); 212 ASSERT_EQ(Cst, 42); 213 ASSERT_EQ(Src0, Copies[0]); 214 215 // Build AND %0, %1 216 auto MIBAnd = B.buildAnd(s64, Copies[0], Copies[1]); 217 // Try to match AND. 218 match = mi_match(MIBAnd->getOperand(0).getReg(), MRI, 219 m_GAnd(m_Reg(Src0), m_Reg(Src1))); 220 ASSERT_TRUE(match); 221 ASSERT_EQ(Src0, Copies[0]); 222 ASSERT_EQ(Src1, Copies[1]); 223 224 // Build OR %0, %1 225 auto MIBOr = B.buildOr(s64, Copies[0], Copies[1]); 226 // Try to match OR. 227 match = mi_match(MIBOr->getOperand(0).getReg(), MRI, 228 m_GOr(m_Reg(Src0), m_Reg(Src1))); 229 ASSERT_TRUE(match); 230 ASSERT_EQ(Src0, Copies[0]); 231 ASSERT_EQ(Src1, Copies[1]); 232 233 // Try to use the FoldableInstructionsBuilder to build binary ops. 234 ConstantFoldingMIRBuilder CFB(B.getState()); 235 LLT s32 = LLT::scalar(32); 236 auto MIBCAdd = 237 CFB.buildAdd(s32, CFB.buildConstant(s32, 0), CFB.buildConstant(s32, 1)); 238 // This should be a constant now. 239 match = mi_match(MIBCAdd->getOperand(0).getReg(), MRI, m_ICst(Cst)); 240 ASSERT_TRUE(match); 241 ASSERT_EQ(Cst, 1); 242 auto MIBCAdd1 = 243 CFB.buildInstr(TargetOpcode::G_ADD, s32, CFB.buildConstant(s32, 0), 244 CFB.buildConstant(s32, 1)); 245 // This should be a constant now. 246 match = mi_match(MIBCAdd1->getOperand(0).getReg(), MRI, m_ICst(Cst)); 247 ASSERT_TRUE(match); 248 ASSERT_EQ(Cst, 1); 249 250 // Try one of the other constructors of MachineIRBuilder to make sure it's 251 // compatible. 252 ConstantFoldingMIRBuilder CFB1(*MF); 253 CFB1.setInsertPt(*EntryMBB, EntryMBB->end()); 254 auto MIBCSub = 255 CFB1.buildInstr(TargetOpcode::G_SUB, s32, CFB1.buildConstant(s32, 1), 256 CFB1.buildConstant(s32, 1)); 257 // This should be a constant now. 258 match = mi_match(MIBCSub->getOperand(0).getReg(), MRI, m_ICst(Cst)); 259 ASSERT_TRUE(match); 260 ASSERT_EQ(Cst, 0); 261 } 262 263 TEST(PatternMatchInstr, MatchFPUnaryOp) { 264 LLVMContext Context; 265 std::unique_ptr<TargetMachine> TM = createTargetMachine(); 266 if (!TM) 267 return; 268 auto ModuleMMIPair = createDummyModule(Context, *TM, ""); 269 MachineFunction *MF = 270 getMFFromMMI(ModuleMMIPair.first.get(), ModuleMMIPair.second.get()); 271 SmallVector<unsigned, 4> Copies; 272 collectCopies(Copies, MF); 273 MachineBasicBlock *EntryMBB = &*MF->begin(); 274 MachineIRBuilder B(*MF); 275 MachineRegisterInfo &MRI = MF->getRegInfo(); 276 B.setInsertPt(*EntryMBB, EntryMBB->end()); 277 278 // Truncate s64 to s32. 279 LLT s32 = LLT::scalar(32); 280 auto Copy0s32 = B.buildFPTrunc(s32, Copies[0]); 281 282 // Match G_FABS. 283 auto MIBFabs = B.buildInstr(TargetOpcode::G_FABS, s32, Copy0s32); 284 bool match = mi_match(MIBFabs->getOperand(0).getReg(), MRI, m_GFabs(m_Reg())); 285 ASSERT_TRUE(match); 286 unsigned Src; 287 match = mi_match(MIBFabs->getOperand(0).getReg(), MRI, m_GFabs(m_Reg(Src))); 288 ASSERT_TRUE(match); 289 ASSERT_EQ(Src, Copy0s32->getOperand(0).getReg()); 290 291 // Build and match FConstant. 292 auto MIBFCst = B.buildFConstant(s32, .5); 293 const ConstantFP *TmpFP{}; 294 match = mi_match(MIBFCst->getOperand(0).getReg(), MRI, m_GFCst(TmpFP)); 295 ASSERT_TRUE(match); 296 ASSERT_TRUE(TmpFP); 297 APFloat APF((float).5); 298 auto *CFP = ConstantFP::get(Context, APF); 299 ASSERT_EQ(CFP, TmpFP); 300 301 // Build double float. 302 LLT s64 = LLT::scalar(64); 303 auto MIBFCst64 = B.buildFConstant(s64, .5); 304 const ConstantFP *TmpFP64{}; 305 match = mi_match(MIBFCst64->getOperand(0).getReg(), MRI, m_GFCst(TmpFP64)); 306 ASSERT_TRUE(match); 307 ASSERT_TRUE(TmpFP64); 308 APFloat APF64(.5); 309 auto CFP64 = ConstantFP::get(Context, APF64); 310 ASSERT_EQ(CFP64, TmpFP64); 311 ASSERT_NE(TmpFP64, TmpFP); 312 313 // Build half float. 314 LLT s16 = LLT::scalar(16); 315 auto MIBFCst16 = B.buildFConstant(s16, .5); 316 const ConstantFP *TmpFP16{}; 317 match = mi_match(MIBFCst16->getOperand(0).getReg(), MRI, m_GFCst(TmpFP16)); 318 ASSERT_TRUE(match); 319 ASSERT_TRUE(TmpFP16); 320 bool Ignored; 321 APFloat APF16(.5); 322 APF16.convert(APFloat::IEEEhalf(), APFloat::rmNearestTiesToEven, &Ignored); 323 auto CFP16 = ConstantFP::get(Context, APF16); 324 ASSERT_EQ(TmpFP16, CFP16); 325 ASSERT_NE(TmpFP16, TmpFP); 326 } 327 328 TEST(PatternMatchInstr, MatchExtendsTrunc) { 329 LLVMContext Context; 330 std::unique_ptr<TargetMachine> TM = createTargetMachine(); 331 if (!TM) 332 return; 333 auto ModuleMMIPair = createDummyModule(Context, *TM, ""); 334 MachineFunction *MF = 335 getMFFromMMI(ModuleMMIPair.first.get(), ModuleMMIPair.second.get()); 336 SmallVector<unsigned, 4> Copies; 337 collectCopies(Copies, MF); 338 MachineBasicBlock *EntryMBB = &*MF->begin(); 339 MachineIRBuilder B(*MF); 340 MachineRegisterInfo &MRI = MF->getRegInfo(); 341 B.setInsertPt(*EntryMBB, EntryMBB->end()); 342 LLT s64 = LLT::scalar(64); 343 LLT s32 = LLT::scalar(32); 344 345 auto MIBTrunc = B.buildTrunc(s32, Copies[0]); 346 auto MIBAExt = B.buildAnyExt(s64, MIBTrunc); 347 auto MIBZExt = B.buildZExt(s64, MIBTrunc); 348 auto MIBSExt = B.buildSExt(s64, MIBTrunc); 349 unsigned Src0; 350 bool match = 351 mi_match(MIBTrunc->getOperand(0).getReg(), MRI, m_GTrunc(m_Reg(Src0))); 352 ASSERT_TRUE(match); 353 ASSERT_EQ(Src0, Copies[0]); 354 match = 355 mi_match(MIBAExt->getOperand(0).getReg(), MRI, m_GAnyExt(m_Reg(Src0))); 356 ASSERT_TRUE(match); 357 ASSERT_EQ(Src0, MIBTrunc->getOperand(0).getReg()); 358 359 match = mi_match(MIBSExt->getOperand(0).getReg(), MRI, m_GSExt(m_Reg(Src0))); 360 ASSERT_TRUE(match); 361 ASSERT_EQ(Src0, MIBTrunc->getOperand(0).getReg()); 362 363 match = mi_match(MIBZExt->getOperand(0).getReg(), MRI, m_GZExt(m_Reg(Src0))); 364 ASSERT_TRUE(match); 365 ASSERT_EQ(Src0, MIBTrunc->getOperand(0).getReg()); 366 367 // Match ext(trunc src) 368 match = mi_match(MIBAExt->getOperand(0).getReg(), MRI, 369 m_GAnyExt(m_GTrunc(m_Reg(Src0)))); 370 ASSERT_TRUE(match); 371 ASSERT_EQ(Src0, Copies[0]); 372 373 match = mi_match(MIBSExt->getOperand(0).getReg(), MRI, 374 m_GSExt(m_GTrunc(m_Reg(Src0)))); 375 ASSERT_TRUE(match); 376 ASSERT_EQ(Src0, Copies[0]); 377 378 match = mi_match(MIBZExt->getOperand(0).getReg(), MRI, 379 m_GZExt(m_GTrunc(m_Reg(Src0)))); 380 ASSERT_TRUE(match); 381 ASSERT_EQ(Src0, Copies[0]); 382 } 383 384 TEST(PatternMatchInstr, MatchSpecificType) { 385 LLVMContext Context; 386 std::unique_ptr<TargetMachine> TM = createTargetMachine(); 387 if (!TM) 388 return; 389 auto ModuleMMIPair = createDummyModule(Context, *TM, ""); 390 MachineFunction *MF = 391 getMFFromMMI(ModuleMMIPair.first.get(), ModuleMMIPair.second.get()); 392 SmallVector<unsigned, 4> Copies; 393 collectCopies(Copies, MF); 394 MachineBasicBlock *EntryMBB = &*MF->begin(); 395 MachineIRBuilder B(*MF); 396 MachineRegisterInfo &MRI = MF->getRegInfo(); 397 B.setInsertPt(*EntryMBB, EntryMBB->end()); 398 399 // Try to match a 64bit add. 400 LLT s64 = LLT::scalar(64); 401 LLT s32 = LLT::scalar(32); 402 auto MIBAdd = B.buildAdd(s64, Copies[0], Copies[1]); 403 ASSERT_FALSE(mi_match(MIBAdd->getOperand(0).getReg(), MRI, 404 m_GAdd(m_SpecificType(s32), m_Reg()))); 405 ASSERT_TRUE(mi_match(MIBAdd->getOperand(0).getReg(), MRI, 406 m_GAdd(m_SpecificType(s64), m_Reg()))); 407 408 // Try to match the destination type of a bitcast. 409 LLT v2s32 = LLT::vector(2, 32); 410 auto MIBCast = B.buildCast(v2s32, Copies[0]); 411 ASSERT_TRUE( 412 mi_match(MIBCast->getOperand(0).getReg(), MRI, m_GBitcast(m_Reg()))); 413 ASSERT_TRUE( 414 mi_match(MIBCast->getOperand(0).getReg(), MRI, m_SpecificType(v2s32))); 415 ASSERT_TRUE( 416 mi_match(MIBCast->getOperand(1).getReg(), MRI, m_SpecificType(s64))); 417 418 // Build a PTRToInt and INTTOPTR and match and test them. 419 LLT PtrTy = LLT::pointer(0, 64); 420 auto MIBIntToPtr = B.buildCast(PtrTy, Copies[0]); 421 auto MIBPtrToInt = B.buildCast(s64, MIBIntToPtr); 422 unsigned Src0; 423 424 // match the ptrtoint(inttoptr reg) 425 bool match = mi_match(MIBPtrToInt->getOperand(0).getReg(), MRI, 426 m_GPtrToInt(m_GIntToPtr(m_Reg(Src0)))); 427 ASSERT_TRUE(match); 428 ASSERT_EQ(Src0, Copies[0]); 429 } 430 431 TEST(PatternMatchInstr, MatchCombinators) { 432 LLVMContext Context; 433 std::unique_ptr<TargetMachine> TM = createTargetMachine(); 434 if (!TM) 435 return; 436 auto ModuleMMIPair = createDummyModule(Context, *TM, ""); 437 MachineFunction *MF = 438 getMFFromMMI(ModuleMMIPair.first.get(), ModuleMMIPair.second.get()); 439 SmallVector<unsigned, 4> Copies; 440 collectCopies(Copies, MF); 441 MachineBasicBlock *EntryMBB = &*MF->begin(); 442 MachineIRBuilder B(*MF); 443 MachineRegisterInfo &MRI = MF->getRegInfo(); 444 B.setInsertPt(*EntryMBB, EntryMBB->end()); 445 LLT s64 = LLT::scalar(64); 446 LLT s32 = LLT::scalar(32); 447 auto MIBAdd = B.buildAdd(s64, Copies[0], Copies[1]); 448 unsigned Src0, Src1; 449 bool match = 450 mi_match(MIBAdd->getOperand(0).getReg(), MRI, 451 m_all_of(m_SpecificType(s64), m_GAdd(m_Reg(Src0), m_Reg(Src1)))); 452 ASSERT_TRUE(match); 453 ASSERT_EQ(Src0, Copies[0]); 454 ASSERT_EQ(Src1, Copies[1]); 455 // Check for s32 (which should fail). 456 match = 457 mi_match(MIBAdd->getOperand(0).getReg(), MRI, 458 m_all_of(m_SpecificType(s32), m_GAdd(m_Reg(Src0), m_Reg(Src1)))); 459 ASSERT_FALSE(match); 460 match = 461 mi_match(MIBAdd->getOperand(0).getReg(), MRI, 462 m_any_of(m_SpecificType(s32), m_GAdd(m_Reg(Src0), m_Reg(Src1)))); 463 ASSERT_TRUE(match); 464 ASSERT_EQ(Src0, Copies[0]); 465 ASSERT_EQ(Src1, Copies[1]); 466 467 // Match a case where none of the predicates hold true. 468 match = mi_match( 469 MIBAdd->getOperand(0).getReg(), MRI, 470 m_any_of(m_SpecificType(LLT::scalar(16)), m_GSub(m_Reg(), m_Reg()))); 471 ASSERT_FALSE(match); 472 } 473 } // namespace 474 475 int main(int argc, char **argv) { 476 ::testing::InitGoogleTest(&argc, argv); 477 initLLVM(); 478 return RUN_ALL_TESTS(); 479 } 480