1 //===- PatternMatchTest.cpp -----------------------------------------------===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 10 #include "llvm/CodeGen/GlobalISel/MIPatternMatch.h" 11 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" 12 #include "llvm/CodeGen/GlobalISel/Utils.h" 13 #include "llvm/CodeGen/MIRParser/MIRParser.h" 14 #include "llvm/CodeGen/MachineFunction.h" 15 #include "llvm/CodeGen/MachineModuleInfo.h" 16 #include "llvm/CodeGen/TargetFrameLowering.h" 17 #include "llvm/CodeGen/TargetInstrInfo.h" 18 #include "llvm/CodeGen/TargetLowering.h" 19 #include "llvm/CodeGen/TargetSubtargetInfo.h" 20 #include "llvm/Support/SourceMgr.h" 21 #include "llvm/Support/TargetRegistry.h" 22 #include "llvm/Support/TargetSelect.h" 23 #include "llvm/Target/TargetMachine.h" 24 #include "llvm/Target/TargetOptions.h" 25 #include "gtest/gtest.h" 26 27 using namespace llvm; 28 using namespace MIPatternMatch; 29 30 namespace { 31 32 void initLLVM() { 33 InitializeAllTargets(); 34 InitializeAllTargetMCs(); 35 InitializeAllAsmPrinters(); 36 InitializeAllAsmParsers(); 37 38 PassRegistry *Registry = PassRegistry::getPassRegistry(); 39 initializeCore(*Registry); 40 initializeCodeGen(*Registry); 41 } 42 43 /// Create a TargetMachine. As we lack a dedicated always available target for 44 /// unittests, we go for "AArch64". 45 std::unique_ptr<TargetMachine> createTargetMachine() { 46 Triple TargetTriple("aarch64--"); 47 std::string Error; 48 const Target *T = TargetRegistry::lookupTarget("", TargetTriple, Error); 49 if (!T) 50 return nullptr; 51 52 TargetOptions Options; 53 return std::unique_ptr<TargetMachine>(T->createTargetMachine( 54 "AArch64", "", "", Options, None, None, CodeGenOpt::Aggressive)); 55 } 56 57 std::unique_ptr<Module> parseMIR(LLVMContext &Context, 58 std::unique_ptr<MIRParser> &MIR, 59 const TargetMachine &TM, StringRef MIRCode, 60 const char *FuncName, MachineModuleInfo &MMI) { 61 SMDiagnostic Diagnostic; 62 std::unique_ptr<MemoryBuffer> MBuffer = MemoryBuffer::getMemBuffer(MIRCode); 63 MIR = createMIRParser(std::move(MBuffer), Context); 64 if (!MIR) 65 return nullptr; 66 67 std::unique_ptr<Module> M = MIR->parseIRModule(); 68 if (!M) 69 return nullptr; 70 71 M->setDataLayout(TM.createDataLayout()); 72 73 if (MIR->parseMachineFunctions(*M, MMI)) 74 return nullptr; 75 76 return M; 77 } 78 79 std::pair<std::unique_ptr<Module>, std::unique_ptr<MachineModuleInfo>> 80 createDummyModule(LLVMContext &Context, const TargetMachine &TM, 81 StringRef MIRFunc) { 82 SmallString<512> S; 83 StringRef MIRString = (Twine(R"MIR( 84 --- 85 ... 86 name: func 87 registers: 88 - { id: 0, class: _ } 89 - { id: 1, class: _ } 90 - { id: 2, class: _ } 91 - { id: 3, class: _ } 92 body: | 93 bb.1: 94 %0(s64) = COPY $x0 95 %1(s64) = COPY $x1 96 %2(s64) = COPY $x2 97 )MIR") + Twine(MIRFunc) + Twine("...\n")) 98 .toNullTerminatedStringRef(S); 99 std::unique_ptr<MIRParser> MIR; 100 auto MMI = make_unique<MachineModuleInfo>(&TM); 101 std::unique_ptr<Module> M = 102 parseMIR(Context, MIR, TM, MIRString, "func", *MMI); 103 return make_pair(std::move(M), std::move(MMI)); 104 } 105 106 static MachineFunction *getMFFromMMI(const Module *M, 107 const MachineModuleInfo *MMI) { 108 Function *F = M->getFunction("func"); 109 auto *MF = MMI->getMachineFunction(*F); 110 return MF; 111 } 112 113 static void collectCopies(SmallVectorImpl<unsigned> &Copies, 114 MachineFunction *MF) { 115 for (auto &MBB : *MF) 116 for (MachineInstr &MI : MBB) { 117 if (MI.getOpcode() == TargetOpcode::COPY) 118 Copies.push_back(MI.getOperand(0).getReg()); 119 } 120 } 121 122 TEST(PatternMatchInstr, MatchIntConstant) { 123 LLVMContext Context; 124 std::unique_ptr<TargetMachine> TM = createTargetMachine(); 125 if (!TM) 126 return; 127 auto ModuleMMIPair = createDummyModule(Context, *TM, ""); 128 MachineFunction *MF = 129 getMFFromMMI(ModuleMMIPair.first.get(), ModuleMMIPair.second.get()); 130 SmallVector<unsigned, 4> Copies; 131 collectCopies(Copies, MF); 132 MachineBasicBlock *EntryMBB = &*MF->begin(); 133 MachineIRBuilder B(*MF); 134 MachineRegisterInfo &MRI = MF->getRegInfo(); 135 B.setInsertPt(*EntryMBB, EntryMBB->end()); 136 auto MIBCst = B.buildConstant(LLT::scalar(64), 42); 137 uint64_t Cst; 138 bool match = mi_match(MIBCst->getOperand(0).getReg(), MRI, m_ICst(Cst)); 139 ASSERT_TRUE(match); 140 ASSERT_EQ(Cst, (uint64_t)42); 141 } 142 143 TEST(PatternMatchInstr, MatchBinaryOp) { 144 LLVMContext Context; 145 std::unique_ptr<TargetMachine> TM = createTargetMachine(); 146 if (!TM) 147 return; 148 auto ModuleMMIPair = createDummyModule(Context, *TM, ""); 149 MachineFunction *MF = 150 getMFFromMMI(ModuleMMIPair.first.get(), ModuleMMIPair.second.get()); 151 SmallVector<unsigned, 4> Copies; 152 collectCopies(Copies, MF); 153 MachineBasicBlock *EntryMBB = &*MF->begin(); 154 MachineIRBuilder B(*MF); 155 MachineRegisterInfo &MRI = MF->getRegInfo(); 156 B.setInsertPt(*EntryMBB, EntryMBB->end()); 157 LLT s64 = LLT::scalar(64); 158 auto MIBAdd = B.buildAdd(s64, Copies[0], Copies[1]); 159 // Test case for no bind. 160 bool match = 161 mi_match(MIBAdd->getOperand(0).getReg(), MRI, m_GAdd(m_Reg(), m_Reg())); 162 ASSERT_TRUE(match); 163 unsigned Src0, Src1, Src2; 164 match = mi_match(MIBAdd->getOperand(0).getReg(), MRI, 165 m_GAdd(m_Reg(Src0), m_Reg(Src1))); 166 ASSERT_TRUE(match); 167 ASSERT_EQ(Src0, Copies[0]); 168 ASSERT_EQ(Src1, Copies[1]); 169 170 // Build MUL(ADD %0, %1), %2 171 auto MIBMul = B.buildMul(s64, MIBAdd, Copies[2]); 172 173 // Try to match MUL. 174 match = mi_match(MIBMul->getOperand(0).getReg(), MRI, 175 m_GMul(m_Reg(Src0), m_Reg(Src1))); 176 ASSERT_TRUE(match); 177 ASSERT_EQ(Src0, MIBAdd->getOperand(0).getReg()); 178 ASSERT_EQ(Src1, Copies[2]); 179 180 // Try to match MUL(ADD) 181 match = mi_match(MIBMul->getOperand(0).getReg(), MRI, 182 m_GMul(m_GAdd(m_Reg(Src0), m_Reg(Src1)), m_Reg(Src2))); 183 ASSERT_TRUE(match); 184 ASSERT_EQ(Src0, Copies[0]); 185 ASSERT_EQ(Src1, Copies[1]); 186 ASSERT_EQ(Src2, Copies[2]); 187 188 // Test Commutativity. 189 auto MIBMul2 = B.buildMul(s64, Copies[0], B.buildConstant(s64, 42)); 190 // Try to match MUL(Cst, Reg) on src of MUL(Reg, Cst) to validate 191 // commutativity. 192 uint64_t Cst; 193 match = mi_match(MIBMul2->getOperand(0).getReg(), MRI, 194 m_GMul(m_ICst(Cst), m_Reg(Src0))); 195 ASSERT_TRUE(match); 196 ASSERT_EQ(Cst, (uint64_t)42); 197 ASSERT_EQ(Src0, Copies[0]); 198 199 // Make sure commutative doesn't work with something like SUB. 200 auto MIBSub = B.buildSub(s64, Copies[0], B.buildConstant(s64, 42)); 201 match = mi_match(MIBSub->getOperand(0).getReg(), MRI, 202 m_GSub(m_ICst(Cst), m_Reg(Src0))); 203 ASSERT_FALSE(match); 204 205 auto MIBFMul = B.buildInstr(TargetOpcode::G_FMUL, s64, Copies[0], 206 B.buildConstant(s64, 42)); 207 // Match and test commutativity for FMUL. 208 match = mi_match(MIBFMul->getOperand(0).getReg(), MRI, 209 m_GFMul(m_ICst(Cst), m_Reg(Src0))); 210 ASSERT_TRUE(match); 211 ASSERT_EQ(Cst, (uint64_t)42); 212 ASSERT_EQ(Src0, Copies[0]); 213 } 214 215 TEST(PatternMatchInstr, MatchExtendsTrunc) { 216 LLVMContext Context; 217 std::unique_ptr<TargetMachine> TM = createTargetMachine(); 218 if (!TM) 219 return; 220 auto ModuleMMIPair = createDummyModule(Context, *TM, ""); 221 MachineFunction *MF = 222 getMFFromMMI(ModuleMMIPair.first.get(), ModuleMMIPair.second.get()); 223 SmallVector<unsigned, 4> Copies; 224 collectCopies(Copies, MF); 225 MachineBasicBlock *EntryMBB = &*MF->begin(); 226 MachineIRBuilder B(*MF); 227 MachineRegisterInfo &MRI = MF->getRegInfo(); 228 B.setInsertPt(*EntryMBB, EntryMBB->end()); 229 LLT s64 = LLT::scalar(64); 230 LLT s32 = LLT::scalar(32); 231 232 auto MIBTrunc = B.buildTrunc(s32, Copies[0]); 233 auto MIBAExt = B.buildAnyExt(s64, MIBTrunc); 234 auto MIBZExt = B.buildZExt(s64, MIBTrunc); 235 auto MIBSExt = B.buildSExt(s64, MIBTrunc); 236 unsigned Src0; 237 bool match = 238 mi_match(MIBTrunc->getOperand(0).getReg(), MRI, m_GTrunc(m_Reg(Src0))); 239 ASSERT_TRUE(match); 240 ASSERT_EQ(Src0, Copies[0]); 241 match = 242 mi_match(MIBAExt->getOperand(0).getReg(), MRI, m_GAnyExt(m_Reg(Src0))); 243 ASSERT_TRUE(match); 244 ASSERT_EQ(Src0, MIBTrunc->getOperand(0).getReg()); 245 246 match = mi_match(MIBSExt->getOperand(0).getReg(), MRI, m_GSExt(m_Reg(Src0))); 247 ASSERT_TRUE(match); 248 ASSERT_EQ(Src0, MIBTrunc->getOperand(0).getReg()); 249 250 match = mi_match(MIBZExt->getOperand(0).getReg(), MRI, m_GZExt(m_Reg(Src0))); 251 ASSERT_TRUE(match); 252 ASSERT_EQ(Src0, MIBTrunc->getOperand(0).getReg()); 253 254 // Match ext(trunc src) 255 match = mi_match(MIBAExt->getOperand(0).getReg(), MRI, 256 m_GAnyExt(m_GTrunc(m_Reg(Src0)))); 257 ASSERT_TRUE(match); 258 ASSERT_EQ(Src0, Copies[0]); 259 260 match = mi_match(MIBSExt->getOperand(0).getReg(), MRI, 261 m_GSExt(m_GTrunc(m_Reg(Src0)))); 262 ASSERT_TRUE(match); 263 ASSERT_EQ(Src0, Copies[0]); 264 265 match = mi_match(MIBZExt->getOperand(0).getReg(), MRI, 266 m_GZExt(m_GTrunc(m_Reg(Src0)))); 267 ASSERT_TRUE(match); 268 ASSERT_EQ(Src0, Copies[0]); 269 } 270 271 TEST(PatternMatchInstr, MatchSpecificType) { 272 LLVMContext Context; 273 std::unique_ptr<TargetMachine> TM = createTargetMachine(); 274 if (!TM) 275 return; 276 auto ModuleMMIPair = createDummyModule(Context, *TM, ""); 277 MachineFunction *MF = 278 getMFFromMMI(ModuleMMIPair.first.get(), ModuleMMIPair.second.get()); 279 SmallVector<unsigned, 4> Copies; 280 collectCopies(Copies, MF); 281 MachineBasicBlock *EntryMBB = &*MF->begin(); 282 MachineIRBuilder B(*MF); 283 MachineRegisterInfo &MRI = MF->getRegInfo(); 284 B.setInsertPt(*EntryMBB, EntryMBB->end()); 285 LLT s64 = LLT::scalar(64); 286 LLT s32 = LLT::scalar(32); 287 auto MIBAdd = B.buildAdd(s64, Copies[0], Copies[1]); 288 289 // Try to match a 64bit add. 290 ASSERT_FALSE(mi_match(MIBAdd->getOperand(0).getReg(), MRI, 291 m_GAdd(m_SpecificType(s32), m_Reg()))); 292 ASSERT_TRUE(mi_match(MIBAdd->getOperand(0).getReg(), MRI, 293 m_GAdd(m_SpecificType(s64), m_Reg()))); 294 } 295 296 TEST(PatternMatchInstr, MatchCombinators) { 297 LLVMContext Context; 298 std::unique_ptr<TargetMachine> TM = createTargetMachine(); 299 if (!TM) 300 return; 301 auto ModuleMMIPair = createDummyModule(Context, *TM, ""); 302 MachineFunction *MF = 303 getMFFromMMI(ModuleMMIPair.first.get(), ModuleMMIPair.second.get()); 304 SmallVector<unsigned, 4> Copies; 305 collectCopies(Copies, MF); 306 MachineBasicBlock *EntryMBB = &*MF->begin(); 307 MachineIRBuilder B(*MF); 308 MachineRegisterInfo &MRI = MF->getRegInfo(); 309 B.setInsertPt(*EntryMBB, EntryMBB->end()); 310 LLT s64 = LLT::scalar(64); 311 LLT s32 = LLT::scalar(32); 312 auto MIBAdd = B.buildAdd(s64, Copies[0], Copies[1]); 313 unsigned Src0, Src1; 314 bool match = 315 mi_match(MIBAdd->getOperand(0).getReg(), MRI, 316 m_all_of(m_SpecificType(s64), m_GAdd(m_Reg(Src0), m_Reg(Src1)))); 317 ASSERT_TRUE(match); 318 ASSERT_EQ(Src0, Copies[0]); 319 ASSERT_EQ(Src1, Copies[1]); 320 // Check for s32 (which should fail). 321 match = 322 mi_match(MIBAdd->getOperand(0).getReg(), MRI, 323 m_all_of(m_SpecificType(s32), m_GAdd(m_Reg(Src0), m_Reg(Src1)))); 324 ASSERT_FALSE(match); 325 match = 326 mi_match(MIBAdd->getOperand(0).getReg(), MRI, 327 m_any_of(m_SpecificType(s32), m_GAdd(m_Reg(Src0), m_Reg(Src1)))); 328 ASSERT_TRUE(match); 329 ASSERT_EQ(Src0, Copies[0]); 330 ASSERT_EQ(Src1, Copies[1]); 331 } 332 } // namespace 333 334 int main(int argc, char **argv) { 335 ::testing::InitGoogleTest(&argc, argv); 336 initLLVM(); 337 return RUN_ALL_TESTS(); 338 } 339