1 //===- ReduceRegisterMasks.cpp - Specialized Delta Pass -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a function which calls the Generic Delta pass in order
10 // to reduce custom register masks from the MachineFunction.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "ReduceRegisterMasks.h"
15 #include "llvm/CodeGen/MachineFunction.h"
16 #include "llvm/CodeGen/MachineRegisterInfo.h"
17 
18 using namespace llvm;
19 
reduceMasksInFunction(Oracle & O,MachineFunction & MF)20 static void reduceMasksInFunction(Oracle &O, MachineFunction &MF) {
21   DenseSet<const uint32_t *> ConstRegisterMasks;
22   const auto *TRI = MF.getSubtarget().getRegisterInfo();
23 
24   // Track predefined/named regmasks which we ignore.
25   const unsigned NumRegs = TRI->getNumRegs();
26   for (const uint32_t *Mask : TRI->getRegMasks())
27     ConstRegisterMasks.insert(Mask);
28 
29   for (MachineBasicBlock &MBB : MF) {
30     for (MachineInstr &MI : MBB) {
31       for (MachineOperand &MO : MI.operands()) {
32         if (!MO.isRegMask())
33           continue;
34 
35         const uint32_t *OldRegMask = MO.getRegMask();
36         // We're only reducing custom reg masks.
37         if (ConstRegisterMasks.count(OldRegMask))
38           continue;
39         unsigned RegMaskSize =
40             MachineOperand::getRegMaskSize(TRI->getNumRegs());
41         std::vector<uint32_t> NewMask(RegMaskSize);
42 
43         bool MadeChange = false;
44         for (unsigned I = 0; I != NumRegs; ++I) {
45           if (OldRegMask[I / 32] & (1u << (I % 32))) {
46             if (O.shouldKeep())
47               NewMask[I / 32] |= 1u << (I % 32);
48           } else
49             MadeChange = true;
50         }
51 
52         if (MadeChange) {
53           uint32_t *UpdatedMask = MF.allocateRegMask();
54           std::memcpy(UpdatedMask, NewMask.data(),
55                       RegMaskSize * sizeof(*OldRegMask));
56           MO.setRegMask(UpdatedMask);
57         }
58       }
59     }
60   }
61 }
62 
reduceMasksInModule(Oracle & O,ReducerWorkItem & WorkItem)63 static void reduceMasksInModule(Oracle &O, ReducerWorkItem &WorkItem) {
64   for (const Function &F : WorkItem.getModule()) {
65     if (auto *MF = WorkItem.MMI->getMachineFunction(F))
66       reduceMasksInFunction(O, *MF);
67   }
68 }
69 
reduceRegisterMasksMIRDeltaPass(TestRunner & Test)70 void llvm::reduceRegisterMasksMIRDeltaPass(TestRunner &Test) {
71   outs() << "*** Reducing register masks...\n";
72   runDeltaPass(Test, reduceMasksInModule);
73 }
74