1"""Code generator for Code Completion Model Inference.
2
The tool runs on the Decision Forest model defined in the {model} directory,
which must contain `features.json` and `forest.json`.
It generates two files: {output_dir}/{filename}.h and {output_dir}/{filename}.cpp.
The generated files define the Example class named {cpp_class}, which has all the features as class members.
The generated runtime provides an `Evaluate` function that can be used to score a code completion candidate.
7"""
8
9import argparse
10import json
11import struct
12
13
class CppClass:
    """Splits a possibly namespace-qualified C++ class name into its parts.

    Attributes:
      ns: names of the enclosing namespaces, outermost first. Empty
        components (e.g. from a leading "::") are dropped.
      name: the unqualified class name; must not be empty.
    """

    def __init__(self, cpp_class):
        *namespaces, class_name = cpp_class.split("::")
        self.ns = list(filter(None, namespaces))
        self.name = class_name
        if not self.name:
            raise ValueError("Empty class name.")

    def ns_begin(self):
        """Returns the `namespace X {` lines opening every enclosing namespace."""
        return "\n".join("namespace {} {{".format(ns) for ns in self.ns)

    def ns_end(self):
        """Returns the `} // namespace X` lines, innermost namespace first."""
        return "\n".join(
            "}} // namespace {}".format(ns) for ns in self.ns[::-1])
34
35
def header_guard(filename):
    """Returns the header guard for the generated header.

    The guard is built from `filename` upper-cased. Any character that is not
    valid in a C/C++ identifier (e.g. '-', '.', '/') is replaced by '_' so the
    result is always a well-formed macro name. Filenames that are already
    valid identifiers produce the same guard as before.
    """
    import re

    sanitized = re.sub(r"[^A-Z0-9_]", "_", filename.upper())
    return "GENERATED_DECISION_FOREST_MODEL_%s_H" % sanitized
39
40
def boost_node(n, label, next_label):
    """Returns code snippet for a leaf/boost node.

    The leaf returns its score from the per-tree evaluation function, so
    `next_label` is unused. It is accepted only so that all node emitters
    share the same signature.

    The score is converted to a Python float before formatting. An integral
    score in the JSON (e.g. `1`) is therefore emitted as `1.0f`; the literal
    `1f` would not be valid C++.
    """
    return "%s: return %sf;" % (label, float(n['score']))
44
45
def if_greater_node(n, label, next_label):
    """Returns code snippet for a if_greater node.

    Jumps to `next_label` (the true child) if the Example feature (NUMBER) is
    greater than or equal to the node's threshold. Otherwise control falls
    through to the next statement (the false child).

    Comparing integers is much faster than comparing floats. Assuming floating
    points are represented as IEEE 754, the feature is stored order-encoded by
    the generated setter. The threshold is order-encoded here at codegen time
    (after rounding to single precision), so the integer comparison orders
    values like the float comparison would (ignoring NaN). The original
    threshold is kept in a C++ comment for readability.
    """
    threshold = n["threshold"]
    return "%s: if (E.get%s() >= %s /*%s*/) goto %s;" % (
        label, n['feature'], order_encode(threshold), threshold, next_label)
55
56
def if_member_node(n, label, next_label):
    """Returns code snippet for a if_member node.

    Jumps to `next_label` (the true child) if the Example feature (ENUM) is
    one of the enum values listed in the node's "set". Otherwise control
    falls through to the next statement (the false child).

    ENUM features are stored as a one-hot bit mask, so membership is tested
    by AND-ing the feature with the union of the `BIT(...)` masks of the set
    members. An empty set yields the mask `0`, a condition that is never
    true, instead of the ill-formed expression `E.getX() & ()`.
    """
    feature = n['feature']
    members = '|'.join(
        "BIT(%s_type::%s)" % (feature, member) for member in n["set"]) or '0'
    return "%s: if (E.get%s() & (%s)) goto %s;" % (
        label, feature, members, next_label)
68
69
def node(n, label, next_label):
    """Dispatches on `n['operation']` and returns the matching code snippet.

    Supported operations are 'boost', 'if_greater' and 'if_member'. Any other
    operation raises KeyError.
    """
    emitters = dict(
        boost=boost_node,
        if_greater=if_greater_node,
        if_member=if_member_node,
    )
    emit = emitters[n['operation']]
    return emit(n, label, next_label)
77
78
def tree(t, tree_num, node_num):
    """Returns code for inferencing a Decision Tree.
    Also returns the size of the decision tree.

    Every node starts with the label `t{tree#}_n{node#}`. Node numbers follow
    the pre-order traversal described below, so the root is `t{tree#}_n0`
    when called with `node_num=0`.

    The tree contains two types of node: Conditional node and Leaf node.
    -   Conditional node evaluates a condition. If true, it jumps to the true
        node/child ('then'). Code is generated using pre-order traversal of
        the tree, considering the false node ('else') as the first child.
        Therefore the false node is always the immediately next label and is
        reached by falling through.
    -   Leaf node returns its score from the enclosing per-tree function.

    Args:
      t: JSON object for the (sub)tree rooted at this node.
      tree_num: index of the tree in the forest; used in labels.
      node_num: number assigned to this node. Its descendants get the
        consecutive numbers that follow it.

    Returns:
      A tuple (lines, size): the list of C++ statements for the subtree and
      the number of nodes in it.
    """
    label = "t%d_n%d" % (tree_num, node_num)

    if t["operation"] == "boost":
        # boost_node does not use next_label (leaves return directly); the
        # label of the following tree is passed only for the shared signature.
        return [node(t, label=label, next_label="t%d" % (tree_num + 1))], 1

    # The false child comes first in pre-order and is reached by fall-through.
    false_code, false_size = tree(
        t['else'], tree_num=tree_num, node_num=node_num + 1)

    # The true child is numbered after the whole false subtree.
    true_node_num = node_num + false_size + 1
    true_label = "t%d_n%d" % (tree_num, true_node_num)
    true_code, true_size = tree(
        t['then'], tree_num=tree_num, node_num=true_node_num)

    code = [node(t, label=label, next_label=true_label)]
    return code + false_code + true_code, 1 + false_size + true_size
112
113
def gen_header_code(features_json, cpp_class, filename):
    """Returns code for header declaring the inference runtime.

    Declares the Example class named {cpp_class} inside relevant namespaces.
    The Example class contains all the features as class members. This
    class can be used to represent a code completion candidate.
    Also declares `float Evaluate(const {cpp_class}&)`, which can be used to
    score the Example.

    Each feature gets a setter, an always-inlined getter and a private member:
    -   NUMBER features are stored as uint32_t holding the order-encoded
        float, so the generated tree code can compare them as integers.
    -   ENUM features are stored as a uint64_t one-hot bit mask (bit V set
        for enum value V), so set membership is a single bitwise AND.

    Args:
      features_json: list of feature descriptions, each with "name" and "kind".
      cpp_class: CppClass describing the Example class and its namespaces.
      filename: base name of the generated files; used for the header guard.

    Raises:
      ValueError: if a feature has a kind other than NUMBER or ENUM.
    """
    setters = []
    getters = []
    # Class members represent all the features of the Example.
    class_members = []
    for f in features_json:
        feature = f["name"]
        kind = f["kind"]

        if kind == "NUMBER":
            # Floats are order-encoded to integers for faster comparison.
            setters.append(
                "void set%s(float V) { %s = OrderEncode(V); }" % (
                    feature, feature))
            bits = 32
        elif kind == "ENUM":
            # Shift an unsigned one: `1LL << 63` overflows a signed type.
            setters.append(
                "void set%s(unsigned V) { %s = 1ULL << V; }" % (
                    feature, feature))
            bits = 64
        else:
            raise ValueError("Unhandled feature type.", kind)

        getters.append(
            "LLVM_ATTRIBUTE_ALWAYS_INLINE uint%d_t get%s() const { return %s; }"
            % (bits, feature, feature))
        class_members.append("uint%d_t %s = 0;" % (bits, feature))

    nline = "\n  "
    guard = header_guard(filename)
    return """#ifndef %s
#define %s
#include <cstdint>
#include "llvm/Support/Compiler.h"

%s
class %s {
public:
  // Setters.
  %s

  // Getters.
  %s

private:
  %s

  // Produces an integer that sorts in the same order as F.
  // That is: a < b <==> OrderEncode(a) < OrderEncode(b).
  static uint32_t OrderEncode(float F);
};

float Evaluate(const %s&);
%s
#endif // %s
""" % (guard, guard, cpp_class.ns_begin(), cpp_class.name,
        nline.join(setters),
        nline.join(getters),
        nline.join(class_members),
        cpp_class.name, cpp_class.ns_end(), guard)
181
182
def order_encode(v):
    """Maps float `v` to an unsigned 32-bit integer with the same ordering.

    `v` is first rounded to IEEE 754 single precision. Negative floats map
    onto the low half of [0, 2**32) in reversed order and non-negative floats
    onto the high half. This preserves the ordering of non-NaN floats, with
    -0.0 and +0.0 both mapping to 2**31. It mirrors the `OrderEncode`
    function emitted into the generated C++ code.
    """
    (bits,) = struct.unpack('<I', struct.pack('<f', v))
    sign_bit = 0x80000000
    # IEEE 754 floats compare like sign-magnitude integers.
    return (2 ** 32 - bits) if bits & sign_bit else (bits + sign_bit)
190
191
def evaluate_func(forest_json, cpp_class):
    """Generates evaluation functions for each tree and combines them in
    `float Evaluate(const {Example}&)`, which returns the sum of all tree
    scores.

    Each tree gets its own `EvaluateTree{N}` function inside an anonymous
    namespace, marked LLVM_ATTRIBUTE_NOINLINE. Per the original note, MSAN
    will time out if these functions are inlined.
    """
    example = cpp_class.name
    parts = ["namespace {\n"]

    # Generate evaluation function of each tree.
    for index, tree_json in enumerate(forest_json):
        parts.append(
            "LLVM_ATTRIBUTE_NOINLINE float EvaluateTree%d(const %s& E) {\n"
            % (index, example))
        body_lines = tree(tree_json, tree_num=index, node_num=0)[0]
        parts.append("  " + "\n  ".join(body_lines) + "\n")
        parts.append("}\n\n")
    parts.append("} // namespace\n\n")

    # Combine the scores of all trees in the final function.
    parts.append("float Evaluate(const %s& E) {\n" % example)
    parts.append("  float Score = 0;\n")
    parts.extend("  Score += EvaluateTree%d(E);\n" % index
                 for index in range(len(forest_json)))
    parts.append("  return Score;\n")
    parts.append("}\n")
    return "".join(parts)
221
222
def gen_cpp_code(forest_json, features_json, filename, cpp_class):
    """Generates code for the .cpp file.

    The .cpp file does the following:
    -   Includes the generated header and the headers declaring ENUM feature
        types.
    -   Defines a `{feature}_type` alias per ENUM feature.
    -   Defines `{cpp_class}::OrderEncode`.
    -   Contains the per-tree evaluation functions and `Evaluate`.

    Args:
      forest_json: list of decision trees (JSON) making up the forest.
      features_json: list of feature descriptions. ENUM features must have
        "header" and "type" entries.
      filename: base name of the generated files; `{filename}.h` is included.
      cpp_class: CppClass describing the Example class and its namespaces.
    """
    # Headers
    # Required by OrderEncode(float F) for std::numeric_limits.
    angled_include = ['#include <limits>']

    # Include generated header.
    quoted_headers = {filename + '.h', 'llvm/ADT/bit.h'}
    # Headers required by ENUM features used by the model.
    quoted_headers |= {f["header"]
                       for f in features_json if f["kind"] == "ENUM"}
    quoted_include = ['#include "%s"' % h for h in sorted(quoted_headers)]

    # using-decl for ENUM features.
    using_decls = "\n".join("using %s_type = %s;" % (
        feature['name'], feature['type'])
        for feature in features_json
        if feature["kind"] == "ENUM")
    nl = "\n"
    # BIT shifts an unsigned one (ENUM masks are uint64_t) and parenthesizes
    # its argument. OrderEncode relies solely on llvm::bit_cast; a second
    # memcpy of the same bits would be redundant.
    return """%s

%s

#define BIT(X) (1ULL << (X))

%s

%s

uint32_t %s::OrderEncode(float F) {
  static_assert(std::numeric_limits<float>::is_iec559, "");
  constexpr uint32_t TopBit = ~(~uint32_t{0} >> 1);

  // Get the bits of the float. Endianness is the same as for integers.
  uint32_t U = llvm::bit_cast<uint32_t>(F);
  // IEEE 754 floats compare like sign-magnitude integers.
  if (U & TopBit)    // Negative float.
    return 0 - U;    // Map onto the low half of integers, order reversed.
  return U + TopBit; // Positive floats map onto the high half of integers.
}

%s
%s
""" % (nl.join(angled_include), nl.join(quoted_include), cpp_class.ns_begin(),
       using_decls, cpp_class.name, evaluate_func(forest_json, cpp_class),
       cpp_class.ns_end())
273
274
def main():
    """Parses command-line flags, reads the model and writes the generated files.

    Reads `{model}/features.json` and `{model}/forest.json`, then writes
    `{output_dir}/{filename}.cpp` and `{output_dir}/{filename}.h`.
    """
    import os

    parser = argparse.ArgumentParser('DecisionForestCodegen')
    # Every flag is needed to produce output. Marking them required gives a
    # usage error up front instead of a crash on a None value later.
    parser.add_argument('--filename', required=True, help='output file name.')
    parser.add_argument('--output_dir', required=True,
                        help='output directory.')
    parser.add_argument('--model', required=True,
                        help='path to model directory.')
    parser.add_argument(
        '--cpp_class',
        required=True,
        help='The name of the class (which may be namespace-qualified) created in the generated header.'
    )
    ns = parser.parse_args()

    output_dir = ns.output_dir
    filename = ns.filename
    header_file = os.path.join(output_dir, filename + ".h")
    cpp_file = os.path.join(output_dir, filename + ".cpp")
    cpp_class = CppClass(cpp_class=ns.cpp_class)

    model_file = os.path.join(ns.model, "forest.json")
    features_file = os.path.join(ns.model, "features.json")

    # JSON is UTF-8 by specification; do not depend on the locale encoding.
    with open(features_file, encoding='utf-8') as f:
        features_json = json.load(f)

    with open(model_file, encoding='utf-8') as m:
        forest_json = json.load(m)

    with open(cpp_file, 'w', encoding='utf-8') as output_cc:
        output_cc.write(
            gen_cpp_code(forest_json=forest_json,
                         features_json=features_json,
                         filename=filename,
                         cpp_class=cpp_class))

    with open(header_file, 'w', encoding='utf-8') as output_h:
        output_h.write(gen_header_code(
            features_json=features_json,
            cpp_class=cpp_class,
            filename=filename))


if __name__ == '__main__':
    main()
317