"""Code generator for Code Completion Model Inference.

Tool runs on the Decision Forest model defined in {model} directory.
It generates two files: {output_dir}/{filename}.h and {output_dir}/{filename}.cpp
The generated files define the Example class named {cpp_class} having all the features as class members.
The generated runtime provides an `Evaluate` function which can be used to score a code completion candidate.
"""

import argparse
import json
import re
import struct


class CppClass:
    """Holds class name and names of the enclosing namespaces."""

    def __init__(self, cpp_class):
        """Splits a (possibly namespace-qualified) C++ class name.

        `cpp_class` is split on "::"; the last component is the class name
        and the non-empty leading components are the enclosing namespaces
        (so a leading "::" is tolerated).

        Raises ValueError if the class name component is empty.
        """
        ns_and_class = cpp_class.split("::")
        self.ns = [ns for ns in ns_and_class[:-1] if ns]
        self.name = ns_and_class[-1]
        if not self.name:
            raise ValueError("Empty class name.")

    def ns_begin(self):
        """Returns snippet for opening namespace declarations."""
        return "\n".join("namespace %s {" % ns for ns in self.ns)

    def ns_end(self):
        """Returns snippet for closing namespace declarations."""
        # Namespaces are closed innermost-first.
        return "\n".join("} // namespace %s" % ns for ns in reversed(self.ns))


def header_guard(filename):
    """Returns the header guard for the generated header.

    Characters of `filename` that cannot appear in a C++ identifier
    (e.g. '-' or '.') are replaced with '_' so the guard is always a valid
    preprocessor identifier.
    """
    sanitized = re.sub(r"[^0-9A-Za-z_]", "_", filename).upper()
    return "GENERATED_DECISION_FOREST_MODEL_%s_H" % sanitized


def boost_node(n, label, next_label):
    """Returns code snippet for a leaf/boost node.

    The snippet returns the node's score as a C++ float literal.
    `next_label` is accepted for signature parity with the other node
    generators and is not used.
    """
    # Coerce to float: an integer score (JSON `1`) would otherwise be emitted
    # as `1f`, which is not a valid C++ literal.
    return "%s: return %sf;" % (label, float(n['score']))


def if_greater_node(n, label, next_label):
    """Returns code snippet for an if_greater node.

    Jumps to `next_label` if the Example feature (NUMBER) is greater than or
    equal to the threshold.
    Comparing integers is much faster than comparing floats. Assuming floating
    points are represented as IEEE 754, it order-encodes the floats to integers
    before comparing them.
    Control falls through if condition is evaluated to false.
    """
    threshold = n["threshold"]
    # The raw threshold is kept in a C++ comment next to its encoded value.
    return "%s: if (E.get%s() >= %s /*%s*/) goto %s;" % (
        label, n['feature'], order_encode(threshold), threshold, next_label)


def if_member_node(n, label, next_label):
    """Returns code snippet for an if_member node.

    Jumps to `next_label` if the Example feature (ENUM) is present in the set
    of enum values described in the node.
    Control falls through if condition is evaluated to false.
    """
    members = '|'.join(
        "BIT(%s_type::%s)" % (n['feature'], member) for member in n["set"])
    return "%s: if (E.get%s() & (%s)) goto %s;" % (
        label, n['feature'], members, next_label)


def node(n, label, next_label):
    """Returns code snippet for the node.

    Dispatches on n['operation']; an unknown operation raises KeyError.
    """
    generators = {
        'boost': boost_node,
        'if_greater': if_greater_node,
        'if_member': if_member_node,
    }
    return generators[n['operation']](n, label, next_label)


def tree(t, tree_num, node_num):
    """Returns code for inferencing a Decision Tree.
    Also returns the size of the decision tree.

    The result is a tuple (list of C++ statements, number of nodes).
    A node of the tree starts with label `t{tree#}_n{node#}`.

    The tree contains two types of node: Conditional node and Leaf node.
    -   Conditional node evaluates a condition. If true, it jumps to the true
        node/child. Code is generated using pre-order traversal of the tree
        considering false node as the first child. Therefore the false node is
        always the immediately next label.
    -   Leaf node returns its score from the generated function.
    """
    label = "t%d_n%d" % (tree_num, node_num)

    if t["operation"] == "boost":
        # The leaf generator ignores next_label; it is passed for uniformity.
        leaf = node(t, label=label, next_label="t%d" % (tree_num + 1))
        return [leaf], 1

    # The false subtree is emitted right after this node so that a failed
    # condition simply falls through to it.
    false_code, false_size = tree(
        t['else'], tree_num=tree_num, node_num=node_num + 1)

    # The true subtree is numbered after every node of the false subtree.
    true_node_num = node_num + false_size + 1
    true_label = "t%d_n%d" % (tree_num, true_node_num)
    true_code, true_size = tree(
        t['then'], tree_num=tree_num, node_num=true_node_num)

    condition = node(t, label=label, next_label=true_label)
    return [condition] + false_code + true_code, 1 + false_size + true_size


def gen_header_code(features_json, cpp_class, filename):
    """Returns code for header declaring the inference runtime.

    Declares the Example class named {cpp_class} inside relevant namespaces.
    The Example class contains all the features as class members. This
    class can be used to represent a code completion candidate.
    Provides `float Evaluate()` function which can be used to score the Example.

    Raises ValueError if a feature's kind is neither NUMBER nor ENUM.
    """
    setters = []
    getters = []
    # Class members represent all the features of the Example.
    class_members = []
    for f in features_json:
        feature = f["name"]
        kind = f["kind"]

        if kind == "NUMBER":
            # Floats are order-encoded to integers for faster comparison.
            bits = 32
            setters.append(
                "void set%s(float V) { %s = OrderEncode(V); }" % (
                    feature, feature))
        elif kind == "ENUM":
            # Enum features are stored as a one-hot 64-bit mask.
            bits = 64
            setters.append(
                "void set%s(unsigned V) { %s = 1LL << V; }" % (
                    feature, feature))
        else:
            raise ValueError("Unhandled feature type.", kind)

        class_members.append("uint%d_t %s = 0;" % (bits, feature))
        getters.append(
            "LLVM_ATTRIBUTE_ALWAYS_INLINE uint%d_t get%s() const { return %s; }"
            % (bits, feature, feature))

    nline = "\n  "
    guard = header_guard(filename)
    return """#ifndef %s
#define %s
#include <cstdint>
#include "llvm/Support/Compiler.h"

%s
class %s {
public:
  // Setters.
  %s

  // Getters.
  %s

private:
  %s

  // Produces an integer that sorts in the same order as F.
  // That is: a < b <==> orderEncode(a) < orderEncode(b).
  static uint32_t OrderEncode(float F);
};

float Evaluate(const %s&);
%s
#endif // %s
""" % (guard, guard, cpp_class.ns_begin(), cpp_class.name,
       nline.join(setters),
       nline.join(getters),
       nline.join(class_members),
       cpp_class.name, cpp_class.ns_end(), guard)


def order_encode(v):
    """Returns an unsigned 32-bit integer that sorts in the same order as v.

    `v` is packed as a little-endian single-precision float and its bit
    pattern is remapped the same way as the generated C++ `OrderEncode`.
    """
    bits = struct.unpack('<I', struct.pack('<f', v))[0]
    top_bit = 1 << 31
    # IEEE 754 floats compare like sign-magnitude integers.
    if bits & top_bit:  # Negative float
        return (1 << 32) - bits  # low half of integers, order reversed.
    return top_bit + bits  # top half of integers


def evaluate_func(forest_json, cpp_class):
    """Generates evaluation functions for each tree and combines them in
    `float Evaluate(const {Example}&)` function. This function can be
    used to score an Example."""
    lines = []

    # Generate evaluation function of each tree.
    lines.append("namespace {")
    for tree_num, tree_json in enumerate(forest_json):
        lines.append(
            "LLVM_ATTRIBUTE_NOINLINE float EvaluateTree%d(const %s& E) {"
            % (tree_num, cpp_class.name))
        statements, _ = tree(tree_json, tree_num=tree_num, node_num=0)
        lines.extend("  " + statement for statement in statements)
        lines.append("}")
        lines.append("")
    lines.append("} // namespace")
    lines.append("")

    # Combine the scores of all trees in the final function.
    # MSAN will timeout if these functions are inlined.
    lines.append("float Evaluate(const %s& E) {" % cpp_class.name)
    lines.append("  float Score = 0;")
    lines.extend(
        "  Score += EvaluateTree%d(E);" % tree_num
        for tree_num in range(len(forest_json)))
    lines.append("  return Score;")
    lines.append("}")

    return "\n".join(lines) + "\n"


def gen_cpp_code(forest_json, features_json, filename, cpp_class):
    """Generates code for the .cpp file.

    The output contains the includes, the BIT helper macro, the using-decls
    for ENUM feature types, the definition of {cpp_class}::OrderEncode and
    the evaluation functions produced by `evaluate_func`.
    """
    # Headers
    # Required by OrderEncode(float F).
    angled_include = ['#include <%s>' % h for h in ["cstring", "limits"]]

    # Include generated header.
    quoted_headers = {filename + '.h', 'llvm/ADT/bit.h'}
    # Headers required by ENUM features used by the model.
    quoted_headers |= {f["header"]
                       for f in features_json if f["kind"] == "ENUM"}
    # Sorted so the generated file is deterministic.
    quoted_include = ['#include "%s"' % h for h in sorted(quoted_headers)]

    # using-decl for ENUM features.
    using_decls = "\n".join(
        "using %s_type = %s;" % (feature['name'], feature['type'])
        for feature in features_json
        if feature["kind"] == "ENUM")
    nl = "\n"
    # NOTE(review): the emitted OrderEncode initializes U with llvm::bit_cast
    # and then overwrites it with an equivalent std::memcpy; one of the two
    # looks redundant. Left unchanged to keep the generated code identical.
    return """%s

%s

#define BIT(X) (1LL << X)

%s

%s

uint32_t %s::OrderEncode(float F) {
  static_assert(std::numeric_limits<float>::is_iec559, "");
  constexpr uint32_t TopBit = ~(~uint32_t{0} >> 1);

  // Get the bits of the float. Endianness is the same as for integers.
  uint32_t U = llvm::bit_cast<uint32_t>(F);
  std::memcpy(&U, &F, sizeof(U));
  // IEEE 754 floats compare like sign-magnitude integers.
  if (U & TopBit)    // Negative float.
    return 0 - U;    // Map onto the low half of integers, order reversed.
  return U + TopBit; // Positive floats map onto the high half of integers.
}

%s
%s
""" % (nl.join(angled_include), nl.join(quoted_include), cpp_class.ns_begin(),
       using_decls, cpp_class.name, evaluate_func(forest_json, cpp_class),
       cpp_class.ns_end())


def main():
    """Parses the command line, loads the model and writes the .h/.cpp files.

    Reads {model}/features.json and {model}/forest.json and writes
    {output_dir}/{filename}.cpp and {output_dir}/{filename}.h.
    """
    parser = argparse.ArgumentParser('DecisionForestCodegen')
    # Every option is needed to produce output, so let argparse report a
    # missing one instead of failing later on a None value.
    parser.add_argument('--filename', required=True, help='output file name.')
    parser.add_argument('--output_dir', required=True,
                        help='output directory.')
    parser.add_argument('--model', required=True,
                        help='path to model directory.')
    parser.add_argument(
        '--cpp_class',
        required=True,
        help='The name of the class (which may be a namespace-qualified) created in generated header.'
    )
    ns = parser.parse_args()

    output_dir = ns.output_dir
    filename = ns.filename
    header_file = "%s/%s.h" % (output_dir, filename)
    cpp_file = "%s/%s.cpp" % (output_dir, filename)
    cpp_class = CppClass(cpp_class=ns.cpp_class)

    model_file = "%s/forest.json" % ns.model
    features_file = "%s/features.json" % ns.model

    with open(features_file) as f:
        features_json = json.load(f)

    with open(model_file) as m:
        forest_json = json.load(m)

    with open(cpp_file, 'w') as output_cc:
        output_cc.write(
            gen_cpp_code(forest_json=forest_json,
                         features_json=features_json,
                         filename=filename,
                         cpp_class=cpp_class))

    with open(header_file, 'w') as output_h:
        output_h.write(gen_header_code(
            features_json=features_json,
            cpp_class=cpp_class,
            filename=filename))


if __name__ == '__main__':
    main()