1 //===- VarLenCodeEmitterGen.cpp - CEG for variable-length insts -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // The CodeEmitterGen component for variable-length instructions.
10 //
// The basic CodeEmitterGen is designed almost exclusively for fixed-
// length instructions. A good analogy for its encoding scheme is how printf
// works: the (immutable) format string represents the fixed values in the
// encoded instruction. Placeholders (i.e. %something), on the other hand,
// represent the encoding of instruction operands.
16 // ```
17 // printf("1101 %src 1001 %dst", <encoded value for operand `src`>,
18 //                               <encoded value for operand `dst`>);
19 // ```
20 // VarLenCodeEmitterGen in this file provides an alternative encoding scheme
21 // that works more like a C++ stream operator:
22 // ```
23 // OS << 0b1101;
24 // if (Cond)
25 //   OS << OperandEncoding0;
26 // OS << 0b1001 << OperandEncoding1;
27 // ```
// You are free to concatenate encoding fragments of arbitrary types (and
// sizes) at any bit position, which gives more flexibility when defining
// encodings for variable-length instructions.
31 //
32 // In a more specific way, instruction encoding is represented by a DAG type
33 // `Inst` field. Here is an example:
34 // ```
35 // dag Inst = (descend 0b1101, (operand "$src", 4), 0b1001,
36 //                     (operand "$dst", 4));
37 // ```
38 // It represents the following instruction encoding:
39 // ```
40 // MSB                                                     LSB
41 // 1101<encoding for operand src>1001<encoding for operand dst>
42 // ```
43 // For more details about DAG operators in the above snippet, please
44 // refer to \file include/llvm/Target/Target.td.
45 //
// VarLenCodeEmitter will convert the above DAG into the same helper function
// generated by CodeEmitter, `MCCodeEmitter::getBinaryCodeForInstr` (except
// for a few details).
49 //
50 //===----------------------------------------------------------------------===//
51 
52 #include "VarLenCodeEmitterGen.h"
53 #include "CodeGenInstruction.h"
54 #include "CodeGenTarget.h"
55 #include "SubtargetFeatureInfo.h"
56 #include "llvm/ADT/ArrayRef.h"
57 #include "llvm/ADT/DenseMap.h"
58 #include "llvm/Support/raw_ostream.h"
59 #include "llvm/TableGen/Error.h"
60 #include "llvm/TableGen/Record.h"
61 
62 using namespace llvm;
63 
64 namespace {
65 
66 class VarLenCodeEmitterGen {
67   RecordKeeper &Records;
68 
69   struct EncodingSegment {
70     unsigned BitWidth;
71     const Init *Value;
72     StringRef CustomEncoder = "";
73   };
74 
75   class VarLenInst {
76     RecordVal *TheDef;
77     size_t NumBits;
78 
79     // Set if any of the segment is not fixed value.
80     bool HasDynamicSegment;
81 
82     SmallVector<EncodingSegment, 4> Segments;
83 
84     void buildRec(const DagInit *DI);
85 
86     StringRef getCustomEncoderName(const Init *EI) const {
87       if (const auto *DI = dyn_cast<DagInit>(EI)) {
88         if (DI->getNumArgs() && isa<StringInit>(DI->getArg(0)))
89           return cast<StringInit>(DI->getArg(0))->getValue();
90       }
91       return "";
92     }
93 
94   public:
95     VarLenInst() : TheDef(nullptr), NumBits(0U), HasDynamicSegment(false) {}
96 
97     explicit VarLenInst(const DagInit *DI, RecordVal *TheDef);
98 
99     /// Number of bits
100     size_t size() const { return NumBits; }
101 
102     using const_iterator = decltype(Segments)::const_iterator;
103 
104     const_iterator begin() const { return Segments.begin(); }
105     const_iterator end() const { return Segments.end(); }
106     size_t getNumSegments() const { return Segments.size(); }
107 
108     bool isFixedValueOnly() const { return !HasDynamicSegment; }
109   };
110 
111   DenseMap<Record *, VarLenInst> VarLenInsts;
112 
113   // Emit based values (i.e. fixed bits in the encoded instructions)
114   void emitInstructionBaseValues(
115       raw_ostream &OS,
116       ArrayRef<const CodeGenInstruction *> NumberedInstructions,
117       CodeGenTarget &Target, int HwMode = -1);
118 
119   std::string getInstructionCase(Record *R, CodeGenTarget &Target);
120   std::string getInstructionCaseForEncoding(Record *R, Record *EncodingDef,
121                                             CodeGenTarget &Target);
122 
123 public:
124   explicit VarLenCodeEmitterGen(RecordKeeper &R) : Records(R) {}
125 
126   void run(raw_ostream &OS);
127 };
128 
129 } // end anonymous namespace
130 
131 VarLenCodeEmitterGen::VarLenInst::VarLenInst(const DagInit *DI,
132                                              RecordVal *TheDef)
133     : TheDef(TheDef), NumBits(0U) {
134   buildRec(DI);
135   for (const auto &S : Segments)
136     NumBits += S.BitWidth;
137 }
138 
139 void VarLenCodeEmitterGen::VarLenInst::buildRec(const DagInit *DI) {
140   assert(TheDef && "The def record is nullptr ?");
141 
142   std::string Op = DI->getOperator()->getAsString();
143 
144   if (Op == "ascend" || Op == "descend") {
145     bool Reverse = Op == "descend";
146     int i = Reverse ? DI->getNumArgs() - 1 : 0;
147     int e = Reverse ? -1 : DI->getNumArgs();
148     int s = Reverse ? -1 : 1;
149     for (; i != e; i += s) {
150       const Init *Arg = DI->getArg(i);
151       if (const auto *BI = dyn_cast<BitsInit>(Arg)) {
152         if (!BI->isComplete())
153           PrintFatalError(TheDef->getLoc(),
154                           "Expecting complete bits init in `" + Op + "`");
155         Segments.push_back({BI->getNumBits(), BI});
156       } else if (const auto *BI = dyn_cast<BitInit>(Arg)) {
157         if (!BI->isConcrete())
158           PrintFatalError(TheDef->getLoc(),
159                           "Expecting concrete bit init in `" + Op + "`");
160         Segments.push_back({1, BI});
161       } else if (const auto *SubDI = dyn_cast<DagInit>(Arg)) {
162         buildRec(SubDI);
163       } else {
164         PrintFatalError(TheDef->getLoc(), "Unrecognized type of argument in `" +
165                                               Op + "`: " + Arg->getAsString());
166       }
167     }
168   } else if (Op == "operand") {
169     // (operand <operand name>, <# of bits>, [(encoder <custom encoder>)])
170     if (DI->getNumArgs() < 2)
171       PrintFatalError(TheDef->getLoc(),
172                       "Expecting at least 2 arguments for `operand`");
173     HasDynamicSegment = true;
174     const Init *OperandName = DI->getArg(0), *NumBits = DI->getArg(1);
175     if (!isa<StringInit>(OperandName) || !isa<IntInit>(NumBits))
176       PrintFatalError(TheDef->getLoc(), "Invalid argument types for `operand`");
177 
178     auto NumBitsVal = cast<IntInit>(NumBits)->getValue();
179     if (NumBitsVal <= 0)
180       PrintFatalError(TheDef->getLoc(), "Invalid number of bits for `operand`");
181 
182     StringRef CustomEncoder;
183     if (DI->getNumArgs() >= 3)
184       CustomEncoder = getCustomEncoderName(DI->getArg(2));
185     Segments.push_back(
186         {static_cast<unsigned>(NumBitsVal), OperandName, CustomEncoder});
187   } else if (Op == "slice") {
188     // (slice <operand name>, <high / low bit>, <low / high bit>,
189     //        [(encoder <custom encoder>)])
190     if (DI->getNumArgs() < 3)
191       PrintFatalError(TheDef->getLoc(),
192                       "Expecting at least 3 arguments for `slice`");
193     HasDynamicSegment = true;
194     Init *OperandName = DI->getArg(0), *HiBit = DI->getArg(1),
195          *LoBit = DI->getArg(2);
196     if (!isa<StringInit>(OperandName) || !isa<IntInit>(HiBit) ||
197         !isa<IntInit>(LoBit))
198       PrintFatalError(TheDef->getLoc(), "Invalid argument types for `slice`");
199 
200     auto HiBitVal = cast<IntInit>(HiBit)->getValue(),
201          LoBitVal = cast<IntInit>(LoBit)->getValue();
202     if (HiBitVal < 0 || LoBitVal < 0)
203       PrintFatalError(TheDef->getLoc(), "Invalid bit range for `slice`");
204     bool NeedSwap = false;
205     unsigned NumBits = 0U;
206     if (HiBitVal < LoBitVal) {
207       NeedSwap = true;
208       NumBits = static_cast<unsigned>(LoBitVal - HiBitVal + 1);
209     } else {
210       NumBits = static_cast<unsigned>(HiBitVal - LoBitVal + 1);
211     }
212 
213     StringRef CustomEncoder;
214     if (DI->getNumArgs() >= 4)
215       CustomEncoder = getCustomEncoderName(DI->getArg(3));
216 
217     if (NeedSwap) {
218       // Normalization: Hi bit should always be the second argument.
219       Init *const NewArgs[] = {OperandName, LoBit, HiBit};
220       Segments.push_back({NumBits,
221                           DagInit::get(DI->getOperator(), nullptr, NewArgs, {}),
222                           CustomEncoder});
223     } else {
224       Segments.push_back({NumBits, DI, CustomEncoder});
225     }
226   }
227 }
228 
229 void VarLenCodeEmitterGen::run(raw_ostream &OS) {
230   CodeGenTarget Target(Records);
231   auto Insts = Records.getAllDerivedDefinitions("Instruction");
232 
233   auto NumberedInstructions = Target.getInstructionsByEnumValue();
234   const CodeGenHwModes &HWM = Target.getHwModes();
235 
236   // The set of HwModes used by instruction encodings.
237   std::set<unsigned> HwModes;
238   for (const CodeGenInstruction *CGI : NumberedInstructions) {
239     Record *R = CGI->TheDef;
240 
241     // Create the corresponding VarLenInst instance.
242     if (R->getValueAsString("Namespace") == "TargetOpcode" ||
243         R->getValueAsBit("isPseudo"))
244       continue;
245 
246     if (const RecordVal *RV = R->getValue("EncodingInfos")) {
247       if (auto *DI = dyn_cast_or_null<DefInit>(RV->getValue())) {
248         EncodingInfoByHwMode EBM(DI->getDef(), HWM);
249         for (auto &KV : EBM) {
250           HwModes.insert(KV.first);
251           Record *EncodingDef = KV.second;
252           RecordVal *RV = EncodingDef->getValue("Inst");
253           DagInit *DI = cast<DagInit>(RV->getValue());
254           VarLenInsts.insert({EncodingDef, VarLenInst(DI, RV)});
255         }
256         continue;
257       }
258     }
259     RecordVal *RV = R->getValue("Inst");
260     DagInit *DI = cast<DagInit>(RV->getValue());
261     VarLenInsts.insert({R, VarLenInst(DI, RV)});
262   }
263 
264   // Emit function declaration
265   OS << "void " << Target.getName()
266      << "MCCodeEmitter::getBinaryCodeForInstr(const MCInst &MI,\n"
267      << "    SmallVectorImpl<MCFixup> &Fixups,\n"
268      << "    APInt &Inst,\n"
269      << "    APInt &Scratch,\n"
270      << "    const MCSubtargetInfo &STI) const {\n";
271 
272   // Emit instruction base values
273   if (HwModes.empty()) {
274     emitInstructionBaseValues(OS, NumberedInstructions, Target);
275   } else {
276     for (unsigned HwMode : HwModes)
277       emitInstructionBaseValues(OS, NumberedInstructions, Target, (int)HwMode);
278   }
279 
280   if (!HwModes.empty()) {
281     OS << "  const unsigned **Index;\n";
282     OS << "  const uint64_t *InstBits;\n";
283     OS << "  unsigned HwMode = STI.getHwMode();\n";
284     OS << "  switch (HwMode) {\n";
285     OS << "  default: llvm_unreachable(\"Unknown hardware mode!\"); break;\n";
286     for (unsigned I : HwModes) {
287       OS << "  case " << I << ": InstBits = InstBits_" << HWM.getMode(I).Name
288          << "; Index = Index_" << HWM.getMode(I).Name << "; break;\n";
289     }
290     OS << "  };\n";
291   }
292 
293   // Emit helper function to retrieve base values.
294   OS << "  auto getInstBits = [&](unsigned Opcode) -> APInt {\n"
295      << "    unsigned NumBits = Index[Opcode][0];\n"
296      << "    if (!NumBits)\n"
297      << "      return APInt::getZeroWidth();\n"
298      << "    unsigned Idx = Index[Opcode][1];\n"
299      << "    ArrayRef<uint64_t> Data(&InstBits[Idx], "
300      << "APInt::getNumWords(NumBits));\n"
301      << "    return APInt(NumBits, Data);\n"
302      << "  };\n";
303 
304   // Map to accumulate all the cases.
305   std::map<std::string, std::vector<std::string>> CaseMap;
306 
307   // Construct all cases statement for each opcode
308   for (Record *R : Insts) {
309     if (R->getValueAsString("Namespace") == "TargetOpcode" ||
310         R->getValueAsBit("isPseudo"))
311       continue;
312     std::string InstName =
313         (R->getValueAsString("Namespace") + "::" + R->getName()).str();
314     std::string Case = getInstructionCase(R, Target);
315 
316     CaseMap[Case].push_back(std::move(InstName));
317   }
318 
319   // Emit initial function code
320   OS << "  const unsigned opcode = MI.getOpcode();\n"
321      << "  switch (opcode) {\n";
322 
323   // Emit each case statement
324   for (const auto &C : CaseMap) {
325     const std::string &Case = C.first;
326     const auto &InstList = C.second;
327 
328     ListSeparator LS("\n");
329     for (const auto &InstName : InstList)
330       OS << LS << "    case " << InstName << ":";
331 
332     OS << " {\n";
333     OS << Case;
334     OS << "      break;\n"
335        << "    }\n";
336   }
337   // Default case: unhandled opcode
338   OS << "  default:\n"
339      << "    std::string msg;\n"
340      << "    raw_string_ostream Msg(msg);\n"
341      << "    Msg << \"Not supported instr: \" << MI;\n"
342      << "    report_fatal_error(Msg.str().c_str());\n"
343      << "  }\n";
344   OS << "}\n\n";
345 }
346 
347 static void emitInstBits(raw_ostream &IS, raw_ostream &SS, const APInt &Bits,
348                          unsigned &Index) {
349   if (!Bits.getNumWords()) {
350     IS.indent(4) << "{/*NumBits*/0, /*Index*/0},";
351     return;
352   }
353 
354   IS.indent(4) << "{/*NumBits*/" << Bits.getBitWidth() << ", "
355                << "/*Index*/" << Index << "},";
356 
357   SS.indent(4);
358   for (unsigned I = 0; I < Bits.getNumWords(); ++I, ++Index)
359     SS << "UINT64_C(" << utostr(Bits.getRawData()[I]) << "),";
360 }
361 
362 void VarLenCodeEmitterGen::emitInstructionBaseValues(
363     raw_ostream &OS, ArrayRef<const CodeGenInstruction *> NumberedInstructions,
364     CodeGenTarget &Target, int HwMode) {
365   std::string IndexArray, StorageArray;
366   raw_string_ostream IS(IndexArray), SS(StorageArray);
367 
368   const CodeGenHwModes &HWM = Target.getHwModes();
369   if (HwMode == -1) {
370     IS << "  static const unsigned Index[][2] = {\n";
371     SS << "  static const uint64_t InstBits[] = {\n";
372   } else {
373     StringRef Name = HWM.getMode(HwMode).Name;
374     IS << "  static const unsigned Index_" << Name << "[][2] = {\n";
375     SS << "  static const uint64_t InstBits_" << Name << "[] = {\n";
376   }
377 
378   unsigned NumFixedValueWords = 0U;
379   for (const CodeGenInstruction *CGI : NumberedInstructions) {
380     Record *R = CGI->TheDef;
381 
382     if (R->getValueAsString("Namespace") == "TargetOpcode" ||
383         R->getValueAsBit("isPseudo")) {
384       IS.indent(4) << "{/*NumBits*/0, /*Index*/0},\n";
385       continue;
386     }
387 
388     Record *EncodingDef = R;
389     if (const RecordVal *RV = R->getValue("EncodingInfos")) {
390       if (auto *DI = dyn_cast_or_null<DefInit>(RV->getValue())) {
391         EncodingInfoByHwMode EBM(DI->getDef(), HWM);
392         if (EBM.hasMode(HwMode))
393           EncodingDef = EBM.get(HwMode);
394       }
395     }
396 
397     auto It = VarLenInsts.find(EncodingDef);
398     if (It == VarLenInsts.end())
399       PrintFatalError(EncodingDef, "VarLenInst not found for this record");
400     const VarLenInst &VLI = It->second;
401 
402     unsigned i = 0U, BitWidth = VLI.size();
403 
404     // Start by filling in fixed values.
405     APInt Value(BitWidth, 0);
406     auto SI = VLI.begin(), SE = VLI.end();
407     // Scan through all the segments that have fixed-bits values.
408     while (i < BitWidth && SI != SE) {
409       unsigned SegmentNumBits = SI->BitWidth;
410       if (const auto *BI = dyn_cast<BitsInit>(SI->Value)) {
411         for (unsigned Idx = 0U; Idx != SegmentNumBits; ++Idx) {
412           auto *B = cast<BitInit>(BI->getBit(Idx));
413           Value.setBitVal(i + Idx, B->getValue());
414         }
415       }
416       if (const auto *BI = dyn_cast<BitInit>(SI->Value))
417         Value.setBitVal(i, BI->getValue());
418 
419       i += SegmentNumBits;
420       ++SI;
421     }
422 
423     emitInstBits(IS, SS, Value, NumFixedValueWords);
424     IS << '\t' << "// " << R->getName() << "\n";
425     if (Value.getNumWords())
426       SS << '\t' << "// " << R->getName() << "\n";
427   }
428   IS.indent(4) << "{/*NumBits*/0, /*Index*/0}\n  };\n";
429   SS.indent(4) << "UINT64_C(0)\n  };\n";
430 
431   OS << IS.str() << SS.str();
432 }
433 
434 std::string VarLenCodeEmitterGen::getInstructionCase(Record *R,
435                                                      CodeGenTarget &Target) {
436   std::string Case;
437   if (const RecordVal *RV = R->getValue("EncodingInfos")) {
438     if (auto *DI = dyn_cast_or_null<DefInit>(RV->getValue())) {
439       const CodeGenHwModes &HWM = Target.getHwModes();
440       EncodingInfoByHwMode EBM(DI->getDef(), HWM);
441       Case += "      switch (HwMode) {\n";
442       Case += "      default: llvm_unreachable(\"Unhandled HwMode\");\n";
443       for (auto &KV : EBM) {
444         Case += "      case " + itostr(KV.first) + ": {\n";
445         Case += getInstructionCaseForEncoding(R, KV.second, Target);
446         Case += "      break;\n";
447         Case += "      }\n";
448       }
449       Case += "      }\n";
450       return Case;
451     }
452   }
453   return getInstructionCaseForEncoding(R, R, Target);
454 }
455 
456 std::string VarLenCodeEmitterGen::getInstructionCaseForEncoding(
457     Record *R, Record *EncodingDef, CodeGenTarget &Target) {
458   auto It = VarLenInsts.find(EncodingDef);
459   if (It == VarLenInsts.end())
460     PrintFatalError(EncodingDef, "Parsed encoding record not found");
461   const VarLenInst &VLI = It->second;
462   size_t BitWidth = VLI.size();
463 
464   CodeGenInstruction &CGI = Target.getInstruction(R);
465 
466   std::string Case;
467   raw_string_ostream SS(Case);
468   // Resize the scratch buffer.
469   if (BitWidth && !VLI.isFixedValueOnly())
470     SS.indent(6) << "Scratch = Scratch.zextOrSelf(" << BitWidth << ");\n";
471   // Populate based value.
472   SS.indent(6) << "Inst = getInstBits(opcode);\n";
473 
474   // Process each segment in VLI.
475   size_t Offset = 0U;
476   for (const auto &ES : VLI) {
477     unsigned NumBits = ES.BitWidth;
478     const Init *Val = ES.Value;
479     // If it's a StringInit or DagInit, it's a reference to an operand
480     // or part of an operand.
481     if (isa<StringInit>(Val) || isa<DagInit>(Val)) {
482       StringRef OperandName;
483       unsigned LoBit = 0U;
484       if (const auto *SV = dyn_cast<StringInit>(Val)) {
485         OperandName = SV->getValue();
486       } else {
487         // Normalized: (slice <operand name>, <high bit>, <low bit>)
488         const auto *DV = cast<DagInit>(Val);
489         OperandName = cast<StringInit>(DV->getArg(0))->getValue();
490         LoBit = static_cast<unsigned>(cast<IntInit>(DV->getArg(2))->getValue());
491       }
492 
493       auto OpIdx = CGI.Operands.ParseOperandName(OperandName);
494       unsigned FlatOpIdx = CGI.Operands.getFlattenedOperandNumber(OpIdx);
495       StringRef CustomEncoder = CGI.Operands[OpIdx.first].EncoderMethodName;
496       if (ES.CustomEncoder.size())
497         CustomEncoder = ES.CustomEncoder;
498 
499       SS.indent(6) << "Scratch.clearAllBits();\n";
500       SS.indent(6) << "// op: " << OperandName.drop_front(1) << "\n";
501       if (CustomEncoder.empty())
502         SS.indent(6) << "getMachineOpValue(MI, MI.getOperand("
503                      << utostr(FlatOpIdx) << ")";
504       else
505         SS.indent(6) << CustomEncoder << "(MI, /*OpIdx=*/" << utostr(FlatOpIdx);
506 
507       SS << ", /*Pos=*/" << utostr(Offset) << ", Scratch, Fixups, STI);\n";
508 
509       SS.indent(6) << "Inst.insertBits("
510                    << "Scratch.extractBits(" << utostr(NumBits) << ", "
511                    << utostr(LoBit) << ")"
512                    << ", " << Offset << ");\n";
513     }
514     Offset += NumBits;
515   }
516 
517   StringRef PostEmitter = R->getValueAsString("PostEncoderMethod");
518   if (!PostEmitter.empty())
519     SS.indent(6) << "Inst = " << PostEmitter << "(MI, Inst, STI);\n";
520 
521   return Case;
522 }
523 
524 namespace llvm {
525 
526 void emitVarLenCodeEmitter(RecordKeeper &R, raw_ostream &OS) {
527   VarLenCodeEmitterGen(R).run(OS);
528 }
529 
530 } // end namespace llvm
531