1 //===- RegAllocScore.cpp - evaluate regalloc policy quality ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 /// Calculate a measure of the register allocation policy quality. This is used
9 /// to construct a reward for the training of the ML-driven allocation policy.
/// Currently, the score is the sum of the machine basic block
/// frequency-weighted number of loads, stores, copies, and remat instructions,
/// each factored with a relative weight.
13 //===----------------------------------------------------------------------===//
14 
15 #include "RegAllocScore.h"
16 #include "llvm/ADT/STLForwardCompat.h"
17 #include "llvm/ADT/ilist_iterator.h"
18 #include "llvm/CodeGen/MachineBasicBlock.h"
19 #include "llvm/CodeGen/MachineBlockFrequencyInfo.h"
20 #include "llvm/CodeGen/MachineFunction.h"
21 #include "llvm/CodeGen/MachineInstr.h"
22 #include "llvm/CodeGen/MachineInstrBundleIterator.h"
23 #include "llvm/CodeGen/TargetInstrInfo.h"
24 #include "llvm/Support/Debug.h"
25 #include "llvm/Support/Format.h"
26 #include "llvm/Support/raw_ostream.h"
27 #include "llvm/Target/TargetMachine.h"
28 #include <cassert>
29 #include <cstdint>
30 #include <numeric>
31 #include <vector>
32 
33 using namespace llvm;
34 cl::opt<double> CopyWeight("regalloc-copy-weight", cl::init(0.2), cl::Hidden);
35 cl::opt<double> LoadWeight("regalloc-load-weight", cl::init(4.0), cl::Hidden);
36 cl::opt<double> StoreWeight("regalloc-store-weight", cl::init(1.0), cl::Hidden);
37 cl::opt<double> CheapRematWeight("regalloc-cheap-remat-weight", cl::init(0.2),
38                                  cl::Hidden);
39 cl::opt<double> ExpensiveRematWeight("regalloc-expensive-remat-weight",
40                                      cl::init(1.0), cl::Hidden);
41 #define DEBUG_TYPE "regalloc-score"
42 
43 RegAllocScore &RegAllocScore::operator+=(const RegAllocScore &Other) {
44   CopyCounts += Other.copyCounts();
45   LoadCounts += Other.loadCounts();
46   StoreCounts += Other.storeCounts();
47   LoadStoreCounts += Other.loadStoreCounts();
48   CheapRematCounts += Other.cheapRematCounts();
49   ExpensiveRematCounts += Other.expensiveRematCounts();
50   return *this;
51 }
52 
53 bool RegAllocScore::operator==(const RegAllocScore &Other) const {
54   return copyCounts() == Other.copyCounts() &&
55          loadCounts() == Other.loadCounts() &&
56          storeCounts() == Other.storeCounts() &&
57          loadStoreCounts() == Other.loadStoreCounts() &&
58          cheapRematCounts() == Other.cheapRematCounts() &&
59          expensiveRematCounts() == Other.expensiveRematCounts();
60 }
61 
62 bool RegAllocScore::operator!=(const RegAllocScore &Other) const {
63   return !(*this == Other);
64 }
65 
66 double RegAllocScore::getScore() const {
67   double Ret = 0.0;
68   Ret += CopyWeight * copyCounts();
69   Ret += LoadWeight * loadCounts();
70   Ret += StoreWeight * storeCounts();
71   Ret += (LoadWeight + StoreWeight) * loadStoreCounts();
72   Ret += CheapRematWeight * cheapRematCounts();
73   Ret += ExpensiveRematWeight * expensiveRematCounts();
74 
75   return Ret;
76 }
77 
78 RegAllocScore
79 llvm::calculateRegAllocScore(const MachineFunction &MF,
80                              const MachineBlockFrequencyInfo &MBFI,
81                              AAResults &AAResults) {
82   return calculateRegAllocScore(
83       MF,
84       [&](const MachineBasicBlock &MBB) {
85         return MBFI.getBlockFreqRelativeToEntryBlock(&MBB);
86       },
87       [&](const MachineInstr &MI) {
88         return MF.getSubtarget().getInstrInfo()->isTriviallyReMaterializable(
89             MI, &AAResults);
90       });
91 }
92 
93 RegAllocScore llvm::calculateRegAllocScore(
94     const MachineFunction &MF,
95     llvm::function_ref<double(const MachineBasicBlock &)> GetBBFreq,
96     llvm::function_ref<bool(const MachineInstr &)>
97         IsTriviallyRematerializable) {
98   RegAllocScore Total;
99 
100   for (const MachineBasicBlock &MBB : MF) {
101     double BlockFreqRelativeToEntrypoint = GetBBFreq(MBB);
102     RegAllocScore MBBScore;
103 
104     for (const MachineInstr &MI : MBB) {
105       if (MI.isDebugInstr() || MI.isKill() || MI.isInlineAsm()) {
106         continue;
107       }
108       if (MI.isCopy()) {
109         MBBScore.onCopy(BlockFreqRelativeToEntrypoint);
110       } else if (IsTriviallyRematerializable(MI)) {
111         if (MI.getDesc().isAsCheapAsAMove()) {
112           MBBScore.onCheapRemat(BlockFreqRelativeToEntrypoint);
113         } else {
114           MBBScore.onExpensiveRemat(BlockFreqRelativeToEntrypoint);
115         }
116       } else if (MI.mayLoad() && MI.mayStore()) {
117         MBBScore.onLoadStore(BlockFreqRelativeToEntrypoint);
118       } else if (MI.mayLoad()) {
119         MBBScore.onLoad(BlockFreqRelativeToEntrypoint);
120       } else if (MI.mayStore()) {
121         MBBScore.onStore(BlockFreqRelativeToEntrypoint);
122       }
123     }
124     Total += MBBScore;
125   }
126   return Total;
127 }
128