1 //===- TranslateToCpp.cpp - Translating to C++ calls ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
10 #include "mlir/Dialect/EmitC/IR/EmitC.h"
11 #include "mlir/Dialect/Func/IR/FuncOps.h"
12 #include "mlir/Dialect/SCF/SCF.h"
13 #include "mlir/IR/BuiltinOps.h"
14 #include "mlir/IR/BuiltinTypes.h"
15 #include "mlir/IR/Dialect.h"
16 #include "mlir/IR/Operation.h"
17 #include "mlir/Support/IndentedOstream.h"
18 #include "mlir/Target/Cpp/CppEmitter.h"
19 #include "llvm/ADT/DenseMap.h"
20 #include "llvm/ADT/StringExtras.h"
21 #include "llvm/ADT/StringMap.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "llvm/Support/Debug.h"
24 #include "llvm/Support/FormatVariadic.h"
25 #include <utility>
26 
27 #define DEBUG_TYPE "translate-to-cpp"
28 
29 using namespace mlir;
30 using namespace mlir::emitc;
31 using llvm::formatv;
32 
33 /// Convenience functions to produce interleaved output with functions returning
34 /// a LogicalResult. This is different than those in STLExtras as functions used
35 /// on each element doesn't return a string.
36 template <typename ForwardIterator, typename UnaryFunctor,
37           typename NullaryFunctor>
38 inline LogicalResult
39 interleaveWithError(ForwardIterator begin, ForwardIterator end,
40                     UnaryFunctor eachFn, NullaryFunctor betweenFn) {
41   if (begin == end)
42     return success();
43   if (failed(eachFn(*begin)))
44     return failure();
45   ++begin;
46   for (; begin != end; ++begin) {
47     betweenFn();
48     if (failed(eachFn(*begin)))
49       return failure();
50   }
51   return success();
52 }
53 
54 template <typename Container, typename UnaryFunctor, typename NullaryFunctor>
55 inline LogicalResult interleaveWithError(const Container &c,
56                                          UnaryFunctor eachFn,
57                                          NullaryFunctor betweenFn) {
58   return interleaveWithError(c.begin(), c.end(), eachFn, betweenFn);
59 }
60 
61 template <typename Container, typename UnaryFunctor>
62 inline LogicalResult interleaveCommaWithError(const Container &c,
63                                               raw_ostream &os,
64                                               UnaryFunctor eachFn) {
65   return interleaveWithError(c.begin(), c.end(), eachFn, [&]() { os << ", "; });
66 }
67 
68 namespace {
69 /// Emitter that uses dialect specific emitters to emit C++ code.
70 struct CppEmitter {
71   explicit CppEmitter(raw_ostream &os, bool declareVariablesAtTop);
72 
73   /// Emits attribute or returns failure.
74   LogicalResult emitAttribute(Location loc, Attribute attr);
75 
76   /// Emits operation 'op' with/without training semicolon or returns failure.
77   LogicalResult emitOperation(Operation &op, bool trailingSemicolon);
78 
79   /// Emits type 'type' or returns failure.
80   LogicalResult emitType(Location loc, Type type);
81 
82   /// Emits array of types as a std::tuple of the emitted types.
83   /// - emits void for an empty array;
84   /// - emits the type of the only element for arrays of size one;
85   /// - emits a std::tuple otherwise;
86   LogicalResult emitTypes(Location loc, ArrayRef<Type> types);
87 
88   /// Emits array of types as a std::tuple of the emitted types independently of
89   /// the array size.
90   LogicalResult emitTupleType(Location loc, ArrayRef<Type> types);
91 
92   /// Emits an assignment for a variable which has been declared previously.
93   LogicalResult emitVariableAssignment(OpResult result);
94 
95   /// Emits a variable declaration for a result of an operation.
96   LogicalResult emitVariableDeclaration(OpResult result,
97                                         bool trailingSemicolon);
98 
99   /// Emits the variable declaration and assignment prefix for 'op'.
100   /// - emits separate variable followed by std::tie for multi-valued operation;
101   /// - emits single type followed by variable for single result;
102   /// - emits nothing if no value produced by op;
103   /// Emits final '=' operator where a type is produced. Returns failure if
104   /// any result type could not be converted.
105   LogicalResult emitAssignPrefix(Operation &op);
106 
107   /// Emits a label for the block.
108   LogicalResult emitLabel(Block &block);
109 
110   /// Emits the operands and atttributes of the operation. All operands are
111   /// emitted first and then all attributes in alphabetical order.
112   LogicalResult emitOperandsAndAttributes(Operation &op,
113                                           ArrayRef<StringRef> exclude = {});
114 
115   /// Emits the operands of the operation. All operands are emitted in order.
116   LogicalResult emitOperands(Operation &op);
117 
118   /// Return the existing or a new name for a Value.
119   StringRef getOrCreateName(Value val);
120 
121   /// Return the existing or a new label of a Block.
122   StringRef getOrCreateName(Block &block);
123 
124   /// Whether to map an mlir integer to a unsigned integer in C++.
125   bool shouldMapToUnsigned(IntegerType::SignednessSemantics val);
126 
127   /// RAII helper function to manage entering/exiting C++ scopes.
128   struct Scope {
129     Scope(CppEmitter &emitter)
130         : valueMapperScope(emitter.valueMapper),
131           blockMapperScope(emitter.blockMapper), emitter(emitter) {
132       emitter.valueInScopeCount.push(emitter.valueInScopeCount.top());
133       emitter.labelInScopeCount.push(emitter.labelInScopeCount.top());
134     }
135     ~Scope() {
136       emitter.valueInScopeCount.pop();
137       emitter.labelInScopeCount.pop();
138     }
139 
140   private:
141     llvm::ScopedHashTableScope<Value, std::string> valueMapperScope;
142     llvm::ScopedHashTableScope<Block *, std::string> blockMapperScope;
143     CppEmitter &emitter;
144   };
145 
146   /// Returns wether the Value is assigned to a C++ variable in the scope.
147   bool hasValueInScope(Value val);
148 
149   // Returns whether a label is assigned to the block.
150   bool hasBlockLabel(Block &block);
151 
152   /// Returns the output stream.
153   raw_indented_ostream &ostream() { return os; };
154 
155   /// Returns if all variables for op results and basic block arguments need to
156   /// be declared at the beginning of a function.
157   bool shouldDeclareVariablesAtTop() { return declareVariablesAtTop; };
158 
159 private:
160   using ValueMapper = llvm::ScopedHashTable<Value, std::string>;
161   using BlockMapper = llvm::ScopedHashTable<Block *, std::string>;
162 
163   /// Output stream to emit to.
164   raw_indented_ostream os;
165 
166   /// Boolean to enforce that all variables for op results and block
167   /// arguments are declared at the beginning of the function. This also
168   /// includes results from ops located in nested regions.
169   bool declareVariablesAtTop;
170 
171   /// Map from value to name of C++ variable that contain the name.
172   ValueMapper valueMapper;
173 
174   /// Map from block to name of C++ label.
175   BlockMapper blockMapper;
176 
177   /// The number of values in the current scope. This is used to declare the
178   /// names of values in a scope.
179   std::stack<int64_t> valueInScopeCount;
180   std::stack<int64_t> labelInScopeCount;
181 };
182 } // namespace
183 
184 static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation,
185                                      Attribute value) {
186   OpResult result = operation->getResult(0);
187 
188   // Only emit an assignment as the variable was already declared when printing
189   // the FuncOp.
190   if (emitter.shouldDeclareVariablesAtTop()) {
191     // Skip the assignment if the emitc.constant has no value.
192     if (auto oAttr = value.dyn_cast<emitc::OpaqueAttr>()) {
193       if (oAttr.getValue().empty())
194         return success();
195     }
196 
197     if (failed(emitter.emitVariableAssignment(result)))
198       return failure();
199     return emitter.emitAttribute(operation->getLoc(), value);
200   }
201 
202   // Emit a variable declaration for an emitc.constant op without value.
203   if (auto oAttr = value.dyn_cast<emitc::OpaqueAttr>()) {
204     if (oAttr.getValue().empty())
205       // The semicolon gets printed by the emitOperation function.
206       return emitter.emitVariableDeclaration(result,
207                                              /*trailingSemicolon=*/false);
208   }
209 
210   // Emit a variable declaration.
211   if (failed(emitter.emitAssignPrefix(*operation)))
212     return failure();
213   return emitter.emitAttribute(operation->getLoc(), value);
214 }
215 
216 static LogicalResult printOperation(CppEmitter &emitter,
217                                     emitc::ConstantOp constantOp) {
218   Operation *operation = constantOp.getOperation();
219   Attribute value = constantOp.value();
220 
221   return printConstantOp(emitter, operation, value);
222 }
223 
224 static LogicalResult printOperation(CppEmitter &emitter,
225                                     emitc::VariableOp variableOp) {
226   Operation *operation = variableOp.getOperation();
227   Attribute value = variableOp.value();
228 
229   return printConstantOp(emitter, operation, value);
230 }
231 
232 static LogicalResult printOperation(CppEmitter &emitter,
233                                     arith::ConstantOp constantOp) {
234   Operation *operation = constantOp.getOperation();
235   Attribute value = constantOp.getValue();
236 
237   return printConstantOp(emitter, operation, value);
238 }
239 
240 static LogicalResult printOperation(CppEmitter &emitter,
241                                     func::ConstantOp constantOp) {
242   Operation *operation = constantOp.getOperation();
243   Attribute value = constantOp.getValueAttr();
244 
245   return printConstantOp(emitter, operation, value);
246 }
247 
248 static LogicalResult printOperation(CppEmitter &emitter,
249                                     cf::BranchOp branchOp) {
250   raw_ostream &os = emitter.ostream();
251   Block &successor = *branchOp.getSuccessor();
252 
253   for (auto pair :
254        llvm::zip(branchOp.getOperands(), successor.getArguments())) {
255     Value &operand = std::get<0>(pair);
256     BlockArgument &argument = std::get<1>(pair);
257     os << emitter.getOrCreateName(argument) << " = "
258        << emitter.getOrCreateName(operand) << ";\n";
259   }
260 
261   os << "goto ";
262   if (!(emitter.hasBlockLabel(successor)))
263     return branchOp.emitOpError("unable to find label for successor block");
264   os << emitter.getOrCreateName(successor);
265   return success();
266 }
267 
268 static LogicalResult printOperation(CppEmitter &emitter,
269                                     cf::CondBranchOp condBranchOp) {
270   raw_indented_ostream &os = emitter.ostream();
271   Block &trueSuccessor = *condBranchOp.getTrueDest();
272   Block &falseSuccessor = *condBranchOp.getFalseDest();
273 
274   os << "if (" << emitter.getOrCreateName(condBranchOp.getCondition())
275      << ") {\n";
276 
277   os.indent();
278 
279   // If condition is true.
280   for (auto pair : llvm::zip(condBranchOp.getTrueOperands(),
281                              trueSuccessor.getArguments())) {
282     Value &operand = std::get<0>(pair);
283     BlockArgument &argument = std::get<1>(pair);
284     os << emitter.getOrCreateName(argument) << " = "
285        << emitter.getOrCreateName(operand) << ";\n";
286   }
287 
288   os << "goto ";
289   if (!(emitter.hasBlockLabel(trueSuccessor))) {
290     return condBranchOp.emitOpError("unable to find label for successor block");
291   }
292   os << emitter.getOrCreateName(trueSuccessor) << ";\n";
293   os.unindent() << "} else {\n";
294   os.indent();
295   // If condition is false.
296   for (auto pair : llvm::zip(condBranchOp.getFalseOperands(),
297                              falseSuccessor.getArguments())) {
298     Value &operand = std::get<0>(pair);
299     BlockArgument &argument = std::get<1>(pair);
300     os << emitter.getOrCreateName(argument) << " = "
301        << emitter.getOrCreateName(operand) << ";\n";
302   }
303 
304   os << "goto ";
305   if (!(emitter.hasBlockLabel(falseSuccessor))) {
306     return condBranchOp.emitOpError()
307            << "unable to find label for successor block";
308   }
309   os << emitter.getOrCreateName(falseSuccessor) << ";\n";
310   os.unindent() << "}";
311   return success();
312 }
313 
314 static LogicalResult printOperation(CppEmitter &emitter, func::CallOp callOp) {
315   if (failed(emitter.emitAssignPrefix(*callOp.getOperation())))
316     return failure();
317 
318   raw_ostream &os = emitter.ostream();
319   os << callOp.getCallee() << "(";
320   if (failed(emitter.emitOperands(*callOp.getOperation())))
321     return failure();
322   os << ")";
323   return success();
324 }
325 
326 static LogicalResult printOperation(CppEmitter &emitter, emitc::CallOp callOp) {
327   raw_ostream &os = emitter.ostream();
328   Operation &op = *callOp.getOperation();
329 
330   if (failed(emitter.emitAssignPrefix(op)))
331     return failure();
332   os << callOp.callee();
333 
334   auto emitArgs = [&](Attribute attr) -> LogicalResult {
335     if (auto t = attr.dyn_cast<IntegerAttr>()) {
336       // Index attributes are treated specially as operand index.
337       if (t.getType().isIndex()) {
338         int64_t idx = t.getInt();
339         if ((idx < 0) || (idx >= op.getNumOperands()))
340           return op.emitOpError("invalid operand index");
341         if (!emitter.hasValueInScope(op.getOperand(idx)))
342           return op.emitOpError("operand ")
343                  << idx << "'s value not defined in scope";
344         os << emitter.getOrCreateName(op.getOperand(idx));
345         return success();
346       }
347     }
348     if (failed(emitter.emitAttribute(op.getLoc(), attr)))
349       return failure();
350 
351     return success();
352   };
353 
354   if (callOp.template_args()) {
355     os << "<";
356     if (failed(interleaveCommaWithError(*callOp.template_args(), os, emitArgs)))
357       return failure();
358     os << ">";
359   }
360 
361   os << "(";
362 
363   LogicalResult emittedArgs =
364       callOp.args() ? interleaveCommaWithError(*callOp.args(), os, emitArgs)
365                     : emitter.emitOperands(op);
366   if (failed(emittedArgs))
367     return failure();
368   os << ")";
369   return success();
370 }
371 
372 static LogicalResult printOperation(CppEmitter &emitter,
373                                     emitc::ApplyOp applyOp) {
374   raw_ostream &os = emitter.ostream();
375   Operation &op = *applyOp.getOperation();
376 
377   if (failed(emitter.emitAssignPrefix(op)))
378     return failure();
379   os << applyOp.applicableOperator();
380   os << emitter.getOrCreateName(applyOp.getOperand());
381 
382   return success();
383 }
384 
385 static LogicalResult printOperation(CppEmitter &emitter,
386                                     emitc::IncludeOp includeOp) {
387   raw_ostream &os = emitter.ostream();
388 
389   os << "#include ";
390   if (includeOp.is_standard_include())
391     os << "<" << includeOp.include() << ">";
392   else
393     os << "\"" << includeOp.include() << "\"";
394 
395   return success();
396 }
397 
398 static LogicalResult printOperation(CppEmitter &emitter, scf::ForOp forOp) {
399 
400   raw_indented_ostream &os = emitter.ostream();
401 
402   OperandRange operands = forOp.getIterOperands();
403   Block::BlockArgListType iterArgs = forOp.getRegionIterArgs();
404   Operation::result_range results = forOp.getResults();
405 
406   if (!emitter.shouldDeclareVariablesAtTop()) {
407     for (OpResult result : results) {
408       if (failed(emitter.emitVariableDeclaration(result,
409                                                  /*trailingSemicolon=*/true)))
410         return failure();
411     }
412   }
413 
414   for (auto pair : llvm::zip(iterArgs, operands)) {
415     if (failed(emitter.emitType(forOp.getLoc(), std::get<0>(pair).getType())))
416       return failure();
417     os << " " << emitter.getOrCreateName(std::get<0>(pair)) << " = ";
418     os << emitter.getOrCreateName(std::get<1>(pair)) << ";";
419     os << "\n";
420   }
421 
422   os << "for (";
423   if (failed(
424           emitter.emitType(forOp.getLoc(), forOp.getInductionVar().getType())))
425     return failure();
426   os << " ";
427   os << emitter.getOrCreateName(forOp.getInductionVar());
428   os << " = ";
429   os << emitter.getOrCreateName(forOp.getLowerBound());
430   os << "; ";
431   os << emitter.getOrCreateName(forOp.getInductionVar());
432   os << " < ";
433   os << emitter.getOrCreateName(forOp.getUpperBound());
434   os << "; ";
435   os << emitter.getOrCreateName(forOp.getInductionVar());
436   os << " += ";
437   os << emitter.getOrCreateName(forOp.getStep());
438   os << ") {\n";
439   os.indent();
440 
441   Region &forRegion = forOp.getRegion();
442   auto regionOps = forRegion.getOps();
443 
444   // We skip the trailing yield op because this updates the result variables
445   // of the for op in the generated code. Instead we update the iterArgs at
446   // the end of a loop iteration and set the result variables after the for
447   // loop.
448   for (auto it = regionOps.begin(); std::next(it) != regionOps.end(); ++it) {
449     if (failed(emitter.emitOperation(*it, /*trailingSemicolon=*/true)))
450       return failure();
451   }
452 
453   Operation *yieldOp = forRegion.getBlocks().front().getTerminator();
454   // Copy yield operands into iterArgs at the end of a loop iteration.
455   for (auto pair : llvm::zip(iterArgs, yieldOp->getOperands())) {
456     BlockArgument iterArg = std::get<0>(pair);
457     Value operand = std::get<1>(pair);
458     os << emitter.getOrCreateName(iterArg) << " = "
459        << emitter.getOrCreateName(operand) << ";\n";
460   }
461 
462   os.unindent() << "}";
463 
464   // Copy iterArgs into results after the for loop.
465   for (auto pair : llvm::zip(results, iterArgs)) {
466     OpResult result = std::get<0>(pair);
467     BlockArgument iterArg = std::get<1>(pair);
468     os << "\n"
469        << emitter.getOrCreateName(result) << " = "
470        << emitter.getOrCreateName(iterArg) << ";";
471   }
472 
473   return success();
474 }
475 
476 static LogicalResult printOperation(CppEmitter &emitter, scf::IfOp ifOp) {
477   raw_indented_ostream &os = emitter.ostream();
478 
479   if (!emitter.shouldDeclareVariablesAtTop()) {
480     for (OpResult result : ifOp.getResults()) {
481       if (failed(emitter.emitVariableDeclaration(result,
482                                                  /*trailingSemicolon=*/true)))
483         return failure();
484     }
485   }
486 
487   os << "if (";
488   if (failed(emitter.emitOperands(*ifOp.getOperation())))
489     return failure();
490   os << ") {\n";
491   os.indent();
492 
493   Region &thenRegion = ifOp.getThenRegion();
494   for (Operation &op : thenRegion.getOps()) {
495     // Note: This prints a superfluous semicolon if the terminating yield op has
496     // zero results.
497     if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/true)))
498       return failure();
499   }
500 
501   os.unindent() << "}";
502 
503   Region &elseRegion = ifOp.getElseRegion();
504   if (!elseRegion.empty()) {
505     os << " else {\n";
506     os.indent();
507 
508     for (Operation &op : elseRegion.getOps()) {
509       // Note: This prints a superfluous semicolon if the terminating yield op
510       // has zero results.
511       if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/true)))
512         return failure();
513     }
514 
515     os.unindent() << "}";
516   }
517 
518   return success();
519 }
520 
521 static LogicalResult printOperation(CppEmitter &emitter, scf::YieldOp yieldOp) {
522   raw_ostream &os = emitter.ostream();
523   Operation &parentOp = *yieldOp.getOperation()->getParentOp();
524 
525   if (yieldOp.getNumOperands() != parentOp.getNumResults()) {
526     return yieldOp.emitError("number of operands does not to match the number "
527                              "of the parent op's results");
528   }
529 
530   if (failed(interleaveWithError(
531           llvm::zip(parentOp.getResults(), yieldOp.getOperands()),
532           [&](auto pair) -> LogicalResult {
533             auto result = std::get<0>(pair);
534             auto operand = std::get<1>(pair);
535             os << emitter.getOrCreateName(result) << " = ";
536 
537             if (!emitter.hasValueInScope(operand))
538               return yieldOp.emitError("operand value not in scope");
539             os << emitter.getOrCreateName(operand);
540             return success();
541           },
542           [&]() { os << ";\n"; })))
543     return failure();
544 
545   return success();
546 }
547 
548 static LogicalResult printOperation(CppEmitter &emitter,
549                                     func::ReturnOp returnOp) {
550   raw_ostream &os = emitter.ostream();
551   os << "return";
552   switch (returnOp.getNumOperands()) {
553   case 0:
554     return success();
555   case 1:
556     os << " " << emitter.getOrCreateName(returnOp.getOperand(0));
557     return success(emitter.hasValueInScope(returnOp.getOperand(0)));
558   default:
559     os << " std::make_tuple(";
560     if (failed(emitter.emitOperandsAndAttributes(*returnOp.getOperation())))
561       return failure();
562     os << ")";
563     return success();
564   }
565 }
566 
567 static LogicalResult printOperation(CppEmitter &emitter, ModuleOp moduleOp) {
568   CppEmitter::Scope scope(emitter);
569 
570   for (Operation &op : moduleOp) {
571     if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/false)))
572       return failure();
573   }
574   return success();
575 }
576 
577 static LogicalResult printOperation(CppEmitter &emitter,
578                                     func::FuncOp functionOp) {
579   // We need to declare variables at top if the function has multiple blocks.
580   if (!emitter.shouldDeclareVariablesAtTop() &&
581       functionOp.getBlocks().size() > 1) {
582     return functionOp.emitOpError(
583         "with multiple blocks needs variables declared at top");
584   }
585 
586   CppEmitter::Scope scope(emitter);
587   raw_indented_ostream &os = emitter.ostream();
588   if (failed(emitter.emitTypes(functionOp.getLoc(),
589                                functionOp.getFunctionType().getResults())))
590     return failure();
591   os << " " << functionOp.getName();
592 
593   os << "(";
594   if (failed(interleaveCommaWithError(
595           functionOp.getArguments(), os,
596           [&](BlockArgument arg) -> LogicalResult {
597             if (failed(emitter.emitType(functionOp.getLoc(), arg.getType())))
598               return failure();
599             os << " " << emitter.getOrCreateName(arg);
600             return success();
601           })))
602     return failure();
603   os << ") {\n";
604   os.indent();
605   if (emitter.shouldDeclareVariablesAtTop()) {
606     // Declare all variables that hold op results including those from nested
607     // regions.
608     WalkResult result =
609         functionOp.walk<WalkOrder::PreOrder>([&](Operation *op) -> WalkResult {
610           for (OpResult result : op->getResults()) {
611             if (failed(emitter.emitVariableDeclaration(
612                     result, /*trailingSemicolon=*/true))) {
613               return WalkResult(
614                   op->emitError("unable to declare result variable for op"));
615             }
616           }
617           return WalkResult::advance();
618         });
619     if (result.wasInterrupted())
620       return failure();
621   }
622 
623   Region::BlockListType &blocks = functionOp.getBlocks();
624   // Create label names for basic blocks.
625   for (Block &block : blocks) {
626     emitter.getOrCreateName(block);
627   }
628 
629   // Declare variables for basic block arguments.
630   for (auto it = std::next(blocks.begin()); it != blocks.end(); ++it) {
631     Block &block = *it;
632     for (BlockArgument &arg : block.getArguments()) {
633       if (emitter.hasValueInScope(arg))
634         return functionOp.emitOpError(" block argument #")
635                << arg.getArgNumber() << " is out of scope";
636       if (failed(
637               emitter.emitType(block.getParentOp()->getLoc(), arg.getType()))) {
638         return failure();
639       }
640       os << " " << emitter.getOrCreateName(arg) << ";\n";
641     }
642   }
643 
644   for (Block &block : blocks) {
645     // Only print a label if the block has predecessors.
646     if (!block.hasNoPredecessors()) {
647       if (failed(emitter.emitLabel(block)))
648         return failure();
649     }
650     for (Operation &op : block.getOperations()) {
651       // When generating code for an scf.if or cf.cond_br op no semicolon needs
652       // to be printed after the closing brace.
653       // When generating code for an scf.for op, printing a trailing semicolon
654       // is handled within the printOperation function.
655       bool trailingSemicolon =
656           !isa<scf::IfOp, scf::ForOp, cf::CondBranchOp>(op);
657 
658       if (failed(emitter.emitOperation(
659               op, /*trailingSemicolon=*/trailingSemicolon)))
660         return failure();
661     }
662   }
663   os.unindent() << "}\n";
664   return success();
665 }
666 
667 CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop)
668     : os(os), declareVariablesAtTop(declareVariablesAtTop) {
669   valueInScopeCount.push(0);
670   labelInScopeCount.push(0);
671 }
672 
673 /// Return the existing or a new name for a Value.
674 StringRef CppEmitter::getOrCreateName(Value val) {
675   if (!valueMapper.count(val))
676     valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top()));
677   return *valueMapper.begin(val);
678 }
679 
680 /// Return the existing or a new label for a Block.
681 StringRef CppEmitter::getOrCreateName(Block &block) {
682   if (!blockMapper.count(&block))
683     blockMapper.insert(&block, formatv("label{0}", ++labelInScopeCount.top()));
684   return *blockMapper.begin(&block);
685 }
686 
687 bool CppEmitter::shouldMapToUnsigned(IntegerType::SignednessSemantics val) {
688   switch (val) {
689   case IntegerType::Signless:
690     return false;
691   case IntegerType::Signed:
692     return false;
693   case IntegerType::Unsigned:
694     return true;
695   }
696   llvm_unreachable("Unexpected IntegerType::SignednessSemantics");
697 }
698 
699 bool CppEmitter::hasValueInScope(Value val) { return valueMapper.count(val); }
700 
701 bool CppEmitter::hasBlockLabel(Block &block) {
702   return blockMapper.count(&block);
703 }
704 
705 LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
706   auto printInt = [&](const APInt &val, bool isUnsigned) {
707     if (val.getBitWidth() == 1) {
708       if (val.getBoolValue())
709         os << "true";
710       else
711         os << "false";
712     } else {
713       SmallString<128> strValue;
714       val.toString(strValue, 10, !isUnsigned, false);
715       os << strValue;
716     }
717   };
718 
719   auto printFloat = [&](const APFloat &val) {
720     if (val.isFinite()) {
721       SmallString<128> strValue;
722       // Use default values of toString except don't truncate zeros.
723       val.toString(strValue, 0, 0, false);
724       switch (llvm::APFloatBase::SemanticsToEnum(val.getSemantics())) {
725       case llvm::APFloatBase::S_IEEEsingle:
726         os << "(float)";
727         break;
728       case llvm::APFloatBase::S_IEEEdouble:
729         os << "(double)";
730         break;
731       default:
732         break;
733       };
734       os << strValue;
735     } else if (val.isNaN()) {
736       os << "NAN";
737     } else if (val.isInfinity()) {
738       if (val.isNegative())
739         os << "-";
740       os << "INFINITY";
741     }
742   };
743 
744   // Print floating point attributes.
745   if (auto fAttr = attr.dyn_cast<FloatAttr>()) {
746     printFloat(fAttr.getValue());
747     return success();
748   }
749   if (auto dense = attr.dyn_cast<DenseFPElementsAttr>()) {
750     os << '{';
751     interleaveComma(dense, os, [&](const APFloat &val) { printFloat(val); });
752     os << '}';
753     return success();
754   }
755 
756   // Print integer attributes.
757   if (auto iAttr = attr.dyn_cast<IntegerAttr>()) {
758     if (auto iType = iAttr.getType().dyn_cast<IntegerType>()) {
759       printInt(iAttr.getValue(), shouldMapToUnsigned(iType.getSignedness()));
760       return success();
761     }
762     if (auto iType = iAttr.getType().dyn_cast<IndexType>()) {
763       printInt(iAttr.getValue(), false);
764       return success();
765     }
766   }
767   if (auto dense = attr.dyn_cast<DenseIntElementsAttr>()) {
768     if (auto iType = dense.getType()
769                          .cast<TensorType>()
770                          .getElementType()
771                          .dyn_cast<IntegerType>()) {
772       os << '{';
773       interleaveComma(dense, os, [&](const APInt &val) {
774         printInt(val, shouldMapToUnsigned(iType.getSignedness()));
775       });
776       os << '}';
777       return success();
778     }
779     if (auto iType = dense.getType()
780                          .cast<TensorType>()
781                          .getElementType()
782                          .dyn_cast<IndexType>()) {
783       os << '{';
784       interleaveComma(dense, os,
785                       [&](const APInt &val) { printInt(val, false); });
786       os << '}';
787       return success();
788     }
789   }
790 
791   // Print opaque attributes.
792   if (auto oAttr = attr.dyn_cast<emitc::OpaqueAttr>()) {
793     os << oAttr.getValue();
794     return success();
795   }
796 
797   // Print symbolic reference attributes.
798   if (auto sAttr = attr.dyn_cast<SymbolRefAttr>()) {
799     if (sAttr.getNestedReferences().size() > 1)
800       return emitError(loc, "attribute has more than 1 nested reference");
801     os << sAttr.getRootReference().getValue();
802     return success();
803   }
804 
805   // Print type attributes.
806   if (auto type = attr.dyn_cast<TypeAttr>())
807     return emitType(loc, type.getValue());
808 
809   return emitError(loc, "cannot emit attribute of type ") << attr.getType();
810 }
811 
812 LogicalResult CppEmitter::emitOperands(Operation &op) {
813   auto emitOperandName = [&](Value result) -> LogicalResult {
814     if (!hasValueInScope(result))
815       return op.emitOpError() << "operand value not in scope";
816     os << getOrCreateName(result);
817     return success();
818   };
819   return interleaveCommaWithError(op.getOperands(), os, emitOperandName);
820 }
821 
822 LogicalResult
823 CppEmitter::emitOperandsAndAttributes(Operation &op,
824                                       ArrayRef<StringRef> exclude) {
825   if (failed(emitOperands(op)))
826     return failure();
827   // Insert comma in between operands and non-filtered attributes if needed.
828   if (op.getNumOperands() > 0) {
829     for (NamedAttribute attr : op.getAttrs()) {
830       if (!llvm::is_contained(exclude, attr.getName().strref())) {
831         os << ", ";
832         break;
833       }
834     }
835   }
836   // Emit attributes.
837   auto emitNamedAttribute = [&](NamedAttribute attr) -> LogicalResult {
838     if (llvm::is_contained(exclude, attr.getName().strref()))
839       return success();
840     os << "/* " << attr.getName().getValue() << " */";
841     if (failed(emitAttribute(op.getLoc(), attr.getValue())))
842       return failure();
843     return success();
844   };
845   return interleaveCommaWithError(op.getAttrs(), os, emitNamedAttribute);
846 }
847 
848 LogicalResult CppEmitter::emitVariableAssignment(OpResult result) {
849   if (!hasValueInScope(result)) {
850     return result.getDefiningOp()->emitOpError(
851         "result variable for the operation has not been declared");
852   }
853   os << getOrCreateName(result) << " = ";
854   return success();
855 }
856 
857 LogicalResult CppEmitter::emitVariableDeclaration(OpResult result,
858                                                   bool trailingSemicolon) {
859   if (hasValueInScope(result)) {
860     return result.getDefiningOp()->emitError(
861         "result variable for the operation already declared");
862   }
863   if (failed(emitType(result.getOwner()->getLoc(), result.getType())))
864     return failure();
865   os << " " << getOrCreateName(result);
866   if (trailingSemicolon)
867     os << ";\n";
868   return success();
869 }
870 
871 LogicalResult CppEmitter::emitAssignPrefix(Operation &op) {
872   switch (op.getNumResults()) {
873   case 0:
874     break;
875   case 1: {
876     OpResult result = op.getResult(0);
877     if (shouldDeclareVariablesAtTop()) {
878       if (failed(emitVariableAssignment(result)))
879         return failure();
880     } else {
881       if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/false)))
882         return failure();
883       os << " = ";
884     }
885     break;
886   }
887   default:
888     if (!shouldDeclareVariablesAtTop()) {
889       for (OpResult result : op.getResults()) {
890         if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/true)))
891           return failure();
892       }
893     }
894     os << "std::tie(";
895     interleaveComma(op.getResults(), os,
896                     [&](Value result) { os << getOrCreateName(result); });
897     os << ") = ";
898   }
899   return success();
900 }
901 
902 LogicalResult CppEmitter::emitLabel(Block &block) {
903   if (!hasBlockLabel(block))
904     return block.getParentOp()->emitError("label for block not found");
905   // FIXME: Add feature in `raw_indented_ostream` to ignore indent for block
906   // label instead of using `getOStream`.
907   os.getOStream() << getOrCreateName(block) << ":\n";
908   return success();
909 }
910 
911 LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
912   LogicalResult status =
913       llvm::TypeSwitch<Operation *, LogicalResult>(&op)
914           // Builtin ops.
915           .Case<ModuleOp>([&](auto op) { return printOperation(*this, op); })
916           // CF ops.
917           .Case<cf::BranchOp, cf::CondBranchOp>(
918               [&](auto op) { return printOperation(*this, op); })
919           // EmitC ops.
920           .Case<emitc::ApplyOp, emitc::CallOp, emitc::ConstantOp,
921                 emitc::IncludeOp, emitc::VariableOp>(
922               [&](auto op) { return printOperation(*this, op); })
923           // Func ops.
924           .Case<func::CallOp, func::ConstantOp, func::FuncOp, func::ReturnOp>(
925               [&](auto op) { return printOperation(*this, op); })
926           // SCF ops.
927           .Case<scf::ForOp, scf::IfOp, scf::YieldOp>(
928               [&](auto op) { return printOperation(*this, op); })
929           // Arithmetic ops.
930           .Case<arith::ConstantOp>(
931               [&](auto op) { return printOperation(*this, op); })
932           .Default([&](Operation *) {
933             return op.emitOpError("unable to find printer for op");
934           });
935 
936   if (failed(status))
937     return failure();
938   os << (trailingSemicolon ? ";\n" : "\n");
939   return success();
940 }
941 
942 LogicalResult CppEmitter::emitType(Location loc, Type type) {
943   if (auto iType = type.dyn_cast<IntegerType>()) {
944     switch (iType.getWidth()) {
945     case 1:
946       return (os << "bool"), success();
947     case 8:
948     case 16:
949     case 32:
950     case 64:
951       if (shouldMapToUnsigned(iType.getSignedness()))
952         return (os << "uint" << iType.getWidth() << "_t"), success();
953       else
954         return (os << "int" << iType.getWidth() << "_t"), success();
955     default:
956       return emitError(loc, "cannot emit integer type ") << type;
957     }
958   }
959   if (auto fType = type.dyn_cast<FloatType>()) {
960     switch (fType.getWidth()) {
961     case 32:
962       return (os << "float"), success();
963     case 64:
964       return (os << "double"), success();
965     default:
966       return emitError(loc, "cannot emit float type ") << type;
967     }
968   }
969   if (auto iType = type.dyn_cast<IndexType>())
970     return (os << "size_t"), success();
971   if (auto tType = type.dyn_cast<TensorType>()) {
972     if (!tType.hasRank())
973       return emitError(loc, "cannot emit unranked tensor type");
974     if (!tType.hasStaticShape())
975       return emitError(loc, "cannot emit tensor type with non static shape");
976     os << "Tensor<";
977     if (failed(emitType(loc, tType.getElementType())))
978       return failure();
979     auto shape = tType.getShape();
980     for (auto dimSize : shape) {
981       os << ", ";
982       os << dimSize;
983     }
984     os << ">";
985     return success();
986   }
987   if (auto tType = type.dyn_cast<TupleType>())
988     return emitTupleType(loc, tType.getTypes());
989   if (auto oType = type.dyn_cast<emitc::OpaqueType>()) {
990     os << oType.getValue();
991     return success();
992   }
993   if (auto pType = type.dyn_cast<emitc::PointerType>()) {
994     if (failed(emitType(loc, pType.getPointee())))
995       return failure();
996     os << "*";
997     return success();
998   }
999   return emitError(loc, "cannot emit type ") << type;
1000 }
1001 
1002 LogicalResult CppEmitter::emitTypes(Location loc, ArrayRef<Type> types) {
1003   switch (types.size()) {
1004   case 0:
1005     os << "void";
1006     return success();
1007   case 1:
1008     return emitType(loc, types.front());
1009   default:
1010     return emitTupleType(loc, types);
1011   }
1012 }
1013 
1014 LogicalResult CppEmitter::emitTupleType(Location loc, ArrayRef<Type> types) {
1015   os << "std::tuple<";
1016   if (failed(interleaveCommaWithError(
1017           types, os, [&](Type type) { return emitType(loc, type); })))
1018     return failure();
1019   os << ">";
1020   return success();
1021 }
1022 
1023 LogicalResult emitc::translateToCpp(Operation *op, raw_ostream &os,
1024                                     bool declareVariablesAtTop) {
1025   CppEmitter emitter(os, declareVariablesAtTop);
1026   return emitter.emitOperation(*op, /*trailingSemicolon=*/false);
1027 }
1028