1 //===- PatternMatchTest.cpp -----------------------------------------------===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 10 #include "llvm/CodeGen/GlobalISel/MIPatternMatch.h" 11 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" 12 #include "llvm/CodeGen/GlobalISel/Utils.h" 13 #include "llvm/CodeGen/MIRParser/MIRParser.h" 14 #include "llvm/CodeGen/MachineFunction.h" 15 #include "llvm/CodeGen/MachineModuleInfo.h" 16 #include "llvm/CodeGen/TargetFrameLowering.h" 17 #include "llvm/CodeGen/TargetInstrInfo.h" 18 #include "llvm/CodeGen/TargetLowering.h" 19 #include "llvm/CodeGen/TargetSubtargetInfo.h" 20 #include "llvm/Support/SourceMgr.h" 21 #include "llvm/Support/TargetRegistry.h" 22 #include "llvm/Support/TargetSelect.h" 23 #include "llvm/Target/TargetMachine.h" 24 #include "llvm/Target/TargetOptions.h" 25 #include "gtest/gtest.h" 26 27 using namespace llvm; 28 using namespace MIPatternMatch; 29 30 namespace { 31 32 void initLLVM() { 33 InitializeAllTargets(); 34 InitializeAllTargetMCs(); 35 InitializeAllAsmPrinters(); 36 InitializeAllAsmParsers(); 37 38 PassRegistry *Registry = PassRegistry::getPassRegistry(); 39 initializeCore(*Registry); 40 initializeCodeGen(*Registry); 41 } 42 43 /// Create a TargetMachine. As we lack a dedicated always available target for 44 /// unittests, we go for "AArch64". 45 std::unique_ptr<TargetMachine> createTargetMachine() { 46 Triple TargetTriple("aarch64--"); 47 std::string Error; 48 const Target *T = TargetRegistry::lookupTarget("", TargetTriple, Error); 49 if (!T) 50 return nullptr; 51 52 TargetOptions Options; 53 return std::unique_ptr<TargetMachine>(T->createTargetMachine( 54 "AArch64", "", "", Options, None, None, CodeGenOpt::Aggressive)); 55 } 56 57 std::unique_ptr<Module> parseMIR(LLVMContext &Context, 58 std::unique_ptr<MIRParser> &MIR, 59 const TargetMachine &TM, StringRef MIRCode, 60 const char *FuncName, MachineModuleInfo &MMI) { 61 SMDiagnostic Diagnostic; 62 std::unique_ptr<MemoryBuffer> MBuffer = MemoryBuffer::getMemBuffer(MIRCode); 63 MIR = createMIRParser(std::move(MBuffer), Context); 64 if (!MIR) 65 return nullptr; 66 67 std::unique_ptr<Module> M = MIR->parseIRModule(); 68 if (!M) 69 return nullptr; 70 71 M->setDataLayout(TM.createDataLayout()); 72 73 if (MIR->parseMachineFunctions(*M, MMI)) 74 return nullptr; 75 76 return M; 77 } 78 79 std::pair<std::unique_ptr<Module>, std::unique_ptr<MachineModuleInfo>> 80 createDummyModule(LLVMContext &Context, const TargetMachine &TM, 81 StringRef MIRFunc) { 82 SmallString<512> S; 83 StringRef MIRString = (Twine(R"MIR( 84 --- 85 ... 86 name: func 87 registers: 88 - { id: 0, class: _ } 89 - { id: 1, class: _ } 90 - { id: 2, class: _ } 91 - { id: 3, class: _ } 92 body: | 93 bb.1: 94 %0(s64) = COPY $x0 95 %1(s64) = COPY $x1 96 %2(s64) = COPY $x2 97 )MIR") + Twine(MIRFunc) + Twine("...\n")) 98 .toNullTerminatedStringRef(S); 99 std::unique_ptr<MIRParser> MIR; 100 auto MMI = make_unique<MachineModuleInfo>(&TM); 101 std::unique_ptr<Module> M = 102 parseMIR(Context, MIR, TM, MIRString, "func", *MMI); 103 return make_pair(std::move(M), std::move(MMI)); 104 } 105 106 static MachineFunction *getMFFromMMI(const Module *M, 107 const MachineModuleInfo *MMI) { 108 Function *F = M->getFunction("func"); 109 auto *MF = MMI->getMachineFunction(*F); 110 return MF; 111 } 112 113 static void collectCopies(SmallVectorImpl<unsigned> &Copies, 114 MachineFunction *MF) { 115 for (auto &MBB : *MF) 116 for (MachineInstr &MI : MBB) { 117 if (MI.getOpcode() == TargetOpcode::COPY) 118 Copies.push_back(MI.getOperand(0).getReg()); 119 } 120 } 121 122 TEST(PatternMatchInstr, MatchIntConstant) { 123 LLVMContext Context; 124 std::unique_ptr<TargetMachine> TM = createTargetMachine(); 125 if (!TM) 126 return; 127 auto ModuleMMIPair = createDummyModule(Context, *TM, ""); 128 MachineFunction *MF = 129 getMFFromMMI(ModuleMMIPair.first.get(), ModuleMMIPair.second.get()); 130 SmallVector<unsigned, 4> Copies; 131 collectCopies(Copies, MF); 132 MachineBasicBlock *EntryMBB = &*MF->begin(); 133 MachineIRBuilder B(*MF); 134 MachineRegisterInfo &MRI = MF->getRegInfo(); 135 B.setInsertPt(*EntryMBB, EntryMBB->end()); 136 auto MIBCst = B.buildConstant(LLT::scalar(64), 42); 137 int64_t Cst; 138 bool match = mi_match(MIBCst->getOperand(0).getReg(), MRI, m_ICst(Cst)); 139 ASSERT_TRUE(match); 140 ASSERT_EQ(Cst, 42); 141 } 142 143 TEST(PatternMatchInstr, MatchBinaryOp) { 144 LLVMContext Context; 145 std::unique_ptr<TargetMachine> TM = createTargetMachine(); 146 if (!TM) 147 return; 148 auto ModuleMMIPair = createDummyModule(Context, *TM, ""); 149 MachineFunction *MF = 150 getMFFromMMI(ModuleMMIPair.first.get(), ModuleMMIPair.second.get()); 151 SmallVector<unsigned, 4> Copies; 152 collectCopies(Copies, MF); 153 MachineBasicBlock *EntryMBB = &*MF->begin(); 154 MachineIRBuilder B(*MF); 155 MachineRegisterInfo &MRI = MF->getRegInfo(); 156 B.setInsertPt(*EntryMBB, EntryMBB->end()); 157 LLT s64 = LLT::scalar(64); 158 auto MIBAdd = B.buildAdd(s64, Copies[0], Copies[1]); 159 // Test case for no bind. 160 bool match = 161 mi_match(MIBAdd->getOperand(0).getReg(), MRI, m_GAdd(m_Reg(), m_Reg())); 162 ASSERT_TRUE(match); 163 unsigned Src0, Src1, Src2; 164 match = mi_match(MIBAdd->getOperand(0).getReg(), MRI, 165 m_GAdd(m_Reg(Src0), m_Reg(Src1))); 166 ASSERT_TRUE(match); 167 ASSERT_EQ(Src0, Copies[0]); 168 ASSERT_EQ(Src1, Copies[1]); 169 170 // Build MUL(ADD %0, %1), %2 171 auto MIBMul = B.buildMul(s64, MIBAdd, Copies[2]); 172 173 // Try to match MUL. 174 match = mi_match(MIBMul->getOperand(0).getReg(), MRI, 175 m_GMul(m_Reg(Src0), m_Reg(Src1))); 176 ASSERT_TRUE(match); 177 ASSERT_EQ(Src0, MIBAdd->getOperand(0).getReg()); 178 ASSERT_EQ(Src1, Copies[2]); 179 180 // Try to match MUL(ADD) 181 match = mi_match(MIBMul->getOperand(0).getReg(), MRI, 182 m_GMul(m_GAdd(m_Reg(Src0), m_Reg(Src1)), m_Reg(Src2))); 183 ASSERT_TRUE(match); 184 ASSERT_EQ(Src0, Copies[0]); 185 ASSERT_EQ(Src1, Copies[1]); 186 ASSERT_EQ(Src2, Copies[2]); 187 188 // Test Commutativity. 189 auto MIBMul2 = B.buildMul(s64, Copies[0], B.buildConstant(s64, 42)); 190 // Try to match MUL(Cst, Reg) on src of MUL(Reg, Cst) to validate 191 // commutativity. 192 int64_t Cst; 193 match = mi_match(MIBMul2->getOperand(0).getReg(), MRI, 194 m_GMul(m_ICst(Cst), m_Reg(Src0))); 195 ASSERT_TRUE(match); 196 ASSERT_EQ(Cst, 42); 197 ASSERT_EQ(Src0, Copies[0]); 198 199 // Make sure commutative doesn't work with something like SUB. 200 auto MIBSub = B.buildSub(s64, Copies[0], B.buildConstant(s64, 42)); 201 match = mi_match(MIBSub->getOperand(0).getReg(), MRI, 202 m_GSub(m_ICst(Cst), m_Reg(Src0))); 203 ASSERT_FALSE(match); 204 205 auto MIBFMul = B.buildInstr(TargetOpcode::G_FMUL, s64, Copies[0], 206 B.buildConstant(s64, 42)); 207 // Match and test commutativity for FMUL. 208 match = mi_match(MIBFMul->getOperand(0).getReg(), MRI, 209 m_GFMul(m_ICst(Cst), m_Reg(Src0))); 210 ASSERT_TRUE(match); 211 ASSERT_EQ(Cst, 42); 212 ASSERT_EQ(Src0, Copies[0]); 213 214 // Build AND %0, %1 215 auto MIBAnd = B.buildAnd(s64, Copies[0], Copies[1]); 216 // Try to match AND. 217 match = mi_match(MIBAnd->getOperand(0).getReg(), MRI, 218 m_GAnd(m_Reg(Src0), m_Reg(Src1))); 219 ASSERT_TRUE(match); 220 ASSERT_EQ(Src0, Copies[0]); 221 ASSERT_EQ(Src1, Copies[1]); 222 223 // Build OR %0, %1 224 auto MIBOr = B.buildOr(s64, Copies[0], Copies[1]); 225 // Try to match OR. 226 match = mi_match(MIBOr->getOperand(0).getReg(), MRI, 227 m_GOr(m_Reg(Src0), m_Reg(Src1))); 228 ASSERT_TRUE(match); 229 ASSERT_EQ(Src0, Copies[0]); 230 ASSERT_EQ(Src1, Copies[1]); 231 } 232 233 TEST(PatternMatchInstr, MatchFPUnaryOp) { 234 LLVMContext Context; 235 std::unique_ptr<TargetMachine> TM = createTargetMachine(); 236 if (!TM) 237 return; 238 auto ModuleMMIPair = createDummyModule(Context, *TM, ""); 239 MachineFunction *MF = 240 getMFFromMMI(ModuleMMIPair.first.get(), ModuleMMIPair.second.get()); 241 SmallVector<unsigned, 4> Copies; 242 collectCopies(Copies, MF); 243 MachineBasicBlock *EntryMBB = &*MF->begin(); 244 MachineIRBuilder B(*MF); 245 MachineRegisterInfo &MRI = MF->getRegInfo(); 246 B.setInsertPt(*EntryMBB, EntryMBB->end()); 247 248 // Truncate s64 to s32. 249 LLT s32 = LLT::scalar(32); 250 auto Copy0s32 = B.buildFPTrunc(s32, Copies[0]); 251 252 // Match G_FABS. 253 auto MIBFabs = B.buildInstr(TargetOpcode::G_FABS, s32, Copy0s32); 254 bool match = mi_match(MIBFabs->getOperand(0).getReg(), MRI, m_GFabs(m_Reg())); 255 ASSERT_TRUE(match); 256 unsigned Src; 257 match = mi_match(MIBFabs->getOperand(0).getReg(), MRI, m_GFabs(m_Reg(Src))); 258 ASSERT_TRUE(match); 259 ASSERT_EQ(Src, Copy0s32->getOperand(0).getReg()); 260 261 // Build and match FConstant. 262 auto MIBFCst = B.buildFConstant(s32, .5); 263 const ConstantFP *TmpFP{}; 264 match = mi_match(MIBFCst->getOperand(0).getReg(), MRI, m_GFCst(TmpFP)); 265 ASSERT_TRUE(match); 266 ASSERT_TRUE(TmpFP); 267 APFloat APF((float).5); 268 auto *CFP = ConstantFP::get(Context, APF); 269 ASSERT_EQ(CFP, TmpFP); 270 271 // Build double float. 272 LLT s64 = LLT::scalar(64); 273 auto MIBFCst64 = B.buildFConstant(s64, .5); 274 const ConstantFP *TmpFP64{}; 275 match = mi_match(MIBFCst64->getOperand(0).getReg(), MRI, m_GFCst(TmpFP64)); 276 ASSERT_TRUE(match); 277 ASSERT_TRUE(TmpFP64); 278 APFloat APF64(.5); 279 auto CFP64 = ConstantFP::get(Context, APF64); 280 ASSERT_EQ(CFP64, TmpFP64); 281 ASSERT_NE(TmpFP64, TmpFP); 282 283 // Build half float. 284 LLT s16 = LLT::scalar(16); 285 auto MIBFCst16 = B.buildFConstant(s16, .5); 286 const ConstantFP *TmpFP16{}; 287 match = mi_match(MIBFCst16->getOperand(0).getReg(), MRI, m_GFCst(TmpFP16)); 288 ASSERT_TRUE(match); 289 ASSERT_TRUE(TmpFP16); 290 bool Ignored; 291 APFloat APF16(.5); 292 APF16.convert(APFloat::IEEEhalf(), APFloat::rmNearestTiesToEven, &Ignored); 293 auto CFP16 = ConstantFP::get(Context, APF16); 294 ASSERT_EQ(TmpFP16, CFP16); 295 ASSERT_NE(TmpFP16, TmpFP); 296 } 297 298 TEST(PatternMatchInstr, MatchExtendsTrunc) { 299 LLVMContext Context; 300 std::unique_ptr<TargetMachine> TM = createTargetMachine(); 301 if (!TM) 302 return; 303 auto ModuleMMIPair = createDummyModule(Context, *TM, ""); 304 MachineFunction *MF = 305 getMFFromMMI(ModuleMMIPair.first.get(), ModuleMMIPair.second.get()); 306 SmallVector<unsigned, 4> Copies; 307 collectCopies(Copies, MF); 308 MachineBasicBlock *EntryMBB = &*MF->begin(); 309 MachineIRBuilder B(*MF); 310 MachineRegisterInfo &MRI = MF->getRegInfo(); 311 B.setInsertPt(*EntryMBB, EntryMBB->end()); 312 LLT s64 = LLT::scalar(64); 313 LLT s32 = LLT::scalar(32); 314 315 auto MIBTrunc = B.buildTrunc(s32, Copies[0]); 316 auto MIBAExt = B.buildAnyExt(s64, MIBTrunc); 317 auto MIBZExt = B.buildZExt(s64, MIBTrunc); 318 auto MIBSExt = B.buildSExt(s64, MIBTrunc); 319 unsigned Src0; 320 bool match = 321 mi_match(MIBTrunc->getOperand(0).getReg(), MRI, m_GTrunc(m_Reg(Src0))); 322 ASSERT_TRUE(match); 323 ASSERT_EQ(Src0, Copies[0]); 324 match = 325 mi_match(MIBAExt->getOperand(0).getReg(), MRI, m_GAnyExt(m_Reg(Src0))); 326 ASSERT_TRUE(match); 327 ASSERT_EQ(Src0, MIBTrunc->getOperand(0).getReg()); 328 329 match = mi_match(MIBSExt->getOperand(0).getReg(), MRI, m_GSExt(m_Reg(Src0))); 330 ASSERT_TRUE(match); 331 ASSERT_EQ(Src0, MIBTrunc->getOperand(0).getReg()); 332 333 match = mi_match(MIBZExt->getOperand(0).getReg(), MRI, m_GZExt(m_Reg(Src0))); 334 ASSERT_TRUE(match); 335 ASSERT_EQ(Src0, MIBTrunc->getOperand(0).getReg()); 336 337 // Match ext(trunc src) 338 match = mi_match(MIBAExt->getOperand(0).getReg(), MRI, 339 m_GAnyExt(m_GTrunc(m_Reg(Src0)))); 340 ASSERT_TRUE(match); 341 ASSERT_EQ(Src0, Copies[0]); 342 343 match = mi_match(MIBSExt->getOperand(0).getReg(), MRI, 344 m_GSExt(m_GTrunc(m_Reg(Src0)))); 345 ASSERT_TRUE(match); 346 ASSERT_EQ(Src0, Copies[0]); 347 348 match = mi_match(MIBZExt->getOperand(0).getReg(), MRI, 349 m_GZExt(m_GTrunc(m_Reg(Src0)))); 350 ASSERT_TRUE(match); 351 ASSERT_EQ(Src0, Copies[0]); 352 } 353 354 TEST(PatternMatchInstr, MatchSpecificType) { 355 LLVMContext Context; 356 std::unique_ptr<TargetMachine> TM = createTargetMachine(); 357 if (!TM) 358 return; 359 auto ModuleMMIPair = createDummyModule(Context, *TM, ""); 360 MachineFunction *MF = 361 getMFFromMMI(ModuleMMIPair.first.get(), ModuleMMIPair.second.get()); 362 SmallVector<unsigned, 4> Copies; 363 collectCopies(Copies, MF); 364 MachineBasicBlock *EntryMBB = &*MF->begin(); 365 MachineIRBuilder B(*MF); 366 MachineRegisterInfo &MRI = MF->getRegInfo(); 367 B.setInsertPt(*EntryMBB, EntryMBB->end()); 368 369 // Try to match a 64bit add. 370 LLT s64 = LLT::scalar(64); 371 LLT s32 = LLT::scalar(32); 372 auto MIBAdd = B.buildAdd(s64, Copies[0], Copies[1]); 373 ASSERT_FALSE(mi_match(MIBAdd->getOperand(0).getReg(), MRI, 374 m_GAdd(m_SpecificType(s32), m_Reg()))); 375 ASSERT_TRUE(mi_match(MIBAdd->getOperand(0).getReg(), MRI, 376 m_GAdd(m_SpecificType(s64), m_Reg()))); 377 378 // Try to match the destination type of a bitcast. 379 LLT v2s32 = LLT::vector(2, 32); 380 auto MIBCast = B.buildCast(v2s32, Copies[0]); 381 ASSERT_TRUE( 382 mi_match(MIBCast->getOperand(0).getReg(), MRI, m_GBitcast(m_Reg()))); 383 ASSERT_TRUE( 384 mi_match(MIBCast->getOperand(0).getReg(), MRI, m_SpecificType(v2s32))); 385 ASSERT_TRUE( 386 mi_match(MIBCast->getOperand(1).getReg(), MRI, m_SpecificType(s64))); 387 388 // Build a PTRToInt and INTTOPTR and match and test them. 389 LLT PtrTy = LLT::pointer(0, 64); 390 auto MIBIntToPtr = B.buildCast(PtrTy, Copies[0]); 391 auto MIBPtrToInt = B.buildCast(s64, MIBIntToPtr); 392 unsigned Src0; 393 394 // match the ptrtoint(inttoptr reg) 395 bool match = mi_match(MIBPtrToInt->getOperand(0).getReg(), MRI, 396 m_GPtrToInt(m_GIntToPtr(m_Reg(Src0)))); 397 ASSERT_TRUE(match); 398 ASSERT_EQ(Src0, Copies[0]); 399 } 400 401 TEST(PatternMatchInstr, MatchCombinators) { 402 LLVMContext Context; 403 std::unique_ptr<TargetMachine> TM = createTargetMachine(); 404 if (!TM) 405 return; 406 auto ModuleMMIPair = createDummyModule(Context, *TM, ""); 407 MachineFunction *MF = 408 getMFFromMMI(ModuleMMIPair.first.get(), ModuleMMIPair.second.get()); 409 SmallVector<unsigned, 4> Copies; 410 collectCopies(Copies, MF); 411 MachineBasicBlock *EntryMBB = &*MF->begin(); 412 MachineIRBuilder B(*MF); 413 MachineRegisterInfo &MRI = MF->getRegInfo(); 414 B.setInsertPt(*EntryMBB, EntryMBB->end()); 415 LLT s64 = LLT::scalar(64); 416 LLT s32 = LLT::scalar(32); 417 auto MIBAdd = B.buildAdd(s64, Copies[0], Copies[1]); 418 unsigned Src0, Src1; 419 bool match = 420 mi_match(MIBAdd->getOperand(0).getReg(), MRI, 421 m_all_of(m_SpecificType(s64), m_GAdd(m_Reg(Src0), m_Reg(Src1)))); 422 ASSERT_TRUE(match); 423 ASSERT_EQ(Src0, Copies[0]); 424 ASSERT_EQ(Src1, Copies[1]); 425 // Check for s32 (which should fail). 426 match = 427 mi_match(MIBAdd->getOperand(0).getReg(), MRI, 428 m_all_of(m_SpecificType(s32), m_GAdd(m_Reg(Src0), m_Reg(Src1)))); 429 ASSERT_FALSE(match); 430 match = 431 mi_match(MIBAdd->getOperand(0).getReg(), MRI, 432 m_any_of(m_SpecificType(s32), m_GAdd(m_Reg(Src0), m_Reg(Src1)))); 433 ASSERT_TRUE(match); 434 ASSERT_EQ(Src0, Copies[0]); 435 ASSERT_EQ(Src1, Copies[1]); 436 437 // Match a case where none of the predicates hold true. 438 match = mi_match( 439 MIBAdd->getOperand(0).getReg(), MRI, 440 m_any_of(m_SpecificType(LLT::scalar(16)), m_GSub(m_Reg(), m_Reg()))); 441 ASSERT_FALSE(match); 442 } 443 } // namespace 444 445 int main(int argc, char **argv) { 446 ::testing::InitGoogleTest(&argc, argv); 447 initLLVM(); 448 return RUN_ALL_TESTS(); 449 } 450