1 //===- TranslateToCpp.cpp - Translating to C++ calls ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
10 #include "mlir/Dialect/EmitC/IR/EmitC.h"
11 #include "mlir/Dialect/SCF/SCF.h"
12 #include "mlir/Dialect/StandardOps/IR/Ops.h"
13 #include "mlir/IR/BuiltinOps.h"
14 #include "mlir/IR/BuiltinTypes.h"
15 #include "mlir/IR/Dialect.h"
16 #include "mlir/IR/Operation.h"
17 #include "mlir/Support/IndentedOstream.h"
18 #include "mlir/Target/Cpp/CppEmitter.h"
19 #include "llvm/ADT/DenseMap.h"
20 #include "llvm/ADT/StringExtras.h"
21 #include "llvm/ADT/StringMap.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "llvm/Support/Debug.h"
24 #include "llvm/Support/FormatVariadic.h"
25 #include <utility>
26 
27 #define DEBUG_TYPE "translate-to-cpp"
28 
29 using namespace mlir;
30 using namespace mlir::emitc;
31 using llvm::formatv;
32 
/// Convenience functions to produce interleaved output with functions returning
/// a LogicalResult. These differ from the interleave helpers in STLExtras in
/// that the function applied to each element returns a LogicalResult, and
/// iteration stops at the first element for which it fails.
36 template <typename ForwardIterator, typename UnaryFunctor,
37           typename NullaryFunctor>
38 inline LogicalResult
39 interleaveWithError(ForwardIterator begin, ForwardIterator end,
40                     UnaryFunctor eachFn, NullaryFunctor betweenFn) {
41   if (begin == end)
42     return success();
43   if (failed(eachFn(*begin)))
44     return failure();
45   ++begin;
46   for (; begin != end; ++begin) {
47     betweenFn();
48     if (failed(eachFn(*begin)))
49       return failure();
50   }
51   return success();
52 }
53 
54 template <typename Container, typename UnaryFunctor, typename NullaryFunctor>
55 inline LogicalResult interleaveWithError(const Container &c,
56                                          UnaryFunctor eachFn,
57                                          NullaryFunctor betweenFn) {
58   return interleaveWithError(c.begin(), c.end(), eachFn, betweenFn);
59 }
60 
61 template <typename Container, typename UnaryFunctor>
62 inline LogicalResult interleaveCommaWithError(const Container &c,
63                                               raw_ostream &os,
64                                               UnaryFunctor eachFn) {
65   return interleaveWithError(c.begin(), c.end(), eachFn, [&]() { os << ", "; });
66 }
67 
namespace {
/// Emitter that uses dialect specific emitters to emit C++ code. Output is
/// written to an indented stream; values and blocks are given generated C++
/// names that are tracked per scope.
struct CppEmitter {
  explicit CppEmitter(raw_ostream &os, bool declareVariablesAtTop);

  /// Emits attribute or returns failure.
  LogicalResult emitAttribute(Location loc, Attribute attr);

  /// Emits operation 'op' with/without trailing semicolon or returns failure.
  LogicalResult emitOperation(Operation &op, bool trailingSemicolon);

  /// Emits type 'type' or returns failure.
  LogicalResult emitType(Location loc, Type type);

  /// Emits array of types as a std::tuple of the emitted types.
  /// - emits void for an empty array;
  /// - emits the type of the only element for arrays of size one;
  /// - emits a std::tuple otherwise;
  LogicalResult emitTypes(Location loc, ArrayRef<Type> types);

  /// Emits array of types as a std::tuple of the emitted types independently of
  /// the array size.
  LogicalResult emitTupleType(Location loc, ArrayRef<Type> types);

  /// Emits an assignment for a variable which has been declared previously.
  LogicalResult emitVariableAssignment(OpResult result);

  /// Emits a variable declaration for a result of an operation.
  LogicalResult emitVariableDeclaration(OpResult result,
                                        bool trailingSemicolon);

  /// Emits the variable declaration and assignment prefix for 'op'.
  /// - emits separate variable followed by std::tie for multi-valued operation;
  /// - emits single type followed by variable for single result;
  /// - emits nothing if no value produced by op;
  /// Emits final '=' operator where a type is produced. Returns failure if
  /// any result type could not be converted.
  LogicalResult emitAssignPrefix(Operation &op);

  /// Emits a label for the block.
  LogicalResult emitLabel(Block &block);

  /// Emits the operands and attributes of the operation. All operands are
  /// emitted first and then all attributes in alphabetical order.
  LogicalResult emitOperandsAndAttributes(Operation &op,
                                          ArrayRef<StringRef> exclude = {});

  /// Emits the operands of the operation. All operands are emitted in order.
  LogicalResult emitOperands(Operation &op);

  /// Return the existing or a new name for a Value.
  StringRef getOrCreateName(Value val);

  /// Return the existing or a new label of a Block.
  StringRef getOrCreateName(Block &block);

  /// Whether to map an mlir integer to an unsigned integer in C++.
  bool shouldMapToUnsigned(IntegerType::SignednessSemantics val);

  /// RAII helper to manage entering/exiting C++ scopes. Construction opens new
  /// value and block name-mapping scopes and pushes a copy of the current
  /// value/label counters, so names created inside continue the numbering of
  /// the enclosing scope. Destruction pops the counters and closes the
  /// mapping scopes.
  struct Scope {
    Scope(CppEmitter &emitter)
        : valueMapperScope(emitter.valueMapper),
          blockMapperScope(emitter.blockMapper), emitter(emitter) {
      emitter.valueInScopeCount.push(emitter.valueInScopeCount.top());
      emitter.labelInScopeCount.push(emitter.labelInScopeCount.top());
    }
    ~Scope() {
      emitter.valueInScopeCount.pop();
      emitter.labelInScopeCount.pop();
    }

  private:
    llvm::ScopedHashTableScope<Value, std::string> valueMapperScope;
    llvm::ScopedHashTableScope<Block *, std::string> blockMapperScope;
    CppEmitter &emitter;
  };

  /// Returns whether the Value is assigned to a C++ variable in the scope.
  bool hasValueInScope(Value val);

  /// Returns whether a label is assigned to the block.
  bool hasBlockLabel(Block &block);

  /// Returns the output stream.
  raw_indented_ostream &ostream() { return os; };

  /// Returns whether all variables for op results and basic block arguments
  /// need to be declared at the beginning of a function.
  bool shouldDeclareVariablesAtTop() { return declareVariablesAtTop; };

private:
  using ValueMapper = llvm::ScopedHashTable<Value, std::string>;
  using BlockMapper = llvm::ScopedHashTable<Block *, std::string>;

  /// Output stream to emit to.
  raw_indented_ostream os;

  /// Boolean to enforce that all variables for op results and block
  /// arguments are declared at the beginning of the function. This also
  /// includes results from ops located in nested regions.
  bool declareVariablesAtTop;

  /// Map from value to the name of the C++ variable that holds it.
  ValueMapper valueMapper;

  /// Map from block to name of C++ label.
  BlockMapper blockMapper;

  /// Stacks of counters used to generate unique value names ("v<N>") and
  /// block labels ("label<N>"). The top entry belongs to the innermost scope
  /// and is incremented each time a new name is created.
  std::stack<int64_t> valueInScopeCount;
  std::stack<int64_t> labelInScopeCount;
};
} // namespace
183 
184 static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation,
185                                      Attribute value) {
186   OpResult result = operation->getResult(0);
187 
188   // Only emit an assignment as the variable was already declared when printing
189   // the FuncOp.
190   if (emitter.shouldDeclareVariablesAtTop()) {
191     // Skip the assignment if the emitc.constant has no value.
192     if (auto oAttr = value.dyn_cast<emitc::OpaqueAttr>()) {
193       if (oAttr.getValue().empty())
194         return success();
195     }
196 
197     if (failed(emitter.emitVariableAssignment(result)))
198       return failure();
199     return emitter.emitAttribute(operation->getLoc(), value);
200   }
201 
202   // Emit a variable declaration for an emitc.constant op without value.
203   if (auto oAttr = value.dyn_cast<emitc::OpaqueAttr>()) {
204     if (oAttr.getValue().empty())
205       // The semicolon gets printed by the emitOperation function.
206       return emitter.emitVariableDeclaration(result,
207                                              /*trailingSemicolon=*/false);
208   }
209 
210   // Emit a variable declaration.
211   if (failed(emitter.emitAssignPrefix(*operation)))
212     return failure();
213   return emitter.emitAttribute(operation->getLoc(), value);
214 }
215 
216 static LogicalResult printOperation(CppEmitter &emitter,
217                                     emitc::ConstantOp constantOp) {
218   Operation *operation = constantOp.getOperation();
219   Attribute value = constantOp.value();
220 
221   return printConstantOp(emitter, operation, value);
222 }
223 
224 static LogicalResult printOperation(CppEmitter &emitter,
225                                     arith::ConstantOp constantOp) {
226   Operation *operation = constantOp.getOperation();
227   Attribute value = constantOp.getValue();
228 
229   return printConstantOp(emitter, operation, value);
230 }
231 
232 static LogicalResult printOperation(CppEmitter &emitter,
233                                     mlir::ConstantOp constantOp) {
234   Operation *operation = constantOp.getOperation();
235   Attribute value = constantOp.getValueAttr();
236 
237   return printConstantOp(emitter, operation, value);
238 }
239 
240 static LogicalResult printOperation(CppEmitter &emitter,
241                                     cf::BranchOp branchOp) {
242   raw_ostream &os = emitter.ostream();
243   Block &successor = *branchOp.getSuccessor();
244 
245   for (auto pair :
246        llvm::zip(branchOp.getOperands(), successor.getArguments())) {
247     Value &operand = std::get<0>(pair);
248     BlockArgument &argument = std::get<1>(pair);
249     os << emitter.getOrCreateName(argument) << " = "
250        << emitter.getOrCreateName(operand) << ";\n";
251   }
252 
253   os << "goto ";
254   if (!(emitter.hasBlockLabel(successor)))
255     return branchOp.emitOpError("unable to find label for successor block");
256   os << emitter.getOrCreateName(successor);
257   return success();
258 }
259 
260 static LogicalResult printOperation(CppEmitter &emitter,
261                                     cf::CondBranchOp condBranchOp) {
262   raw_indented_ostream &os = emitter.ostream();
263   Block &trueSuccessor = *condBranchOp.getTrueDest();
264   Block &falseSuccessor = *condBranchOp.getFalseDest();
265 
266   os << "if (" << emitter.getOrCreateName(condBranchOp.getCondition())
267      << ") {\n";
268 
269   os.indent();
270 
271   // If condition is true.
272   for (auto pair : llvm::zip(condBranchOp.getTrueOperands(),
273                              trueSuccessor.getArguments())) {
274     Value &operand = std::get<0>(pair);
275     BlockArgument &argument = std::get<1>(pair);
276     os << emitter.getOrCreateName(argument) << " = "
277        << emitter.getOrCreateName(operand) << ";\n";
278   }
279 
280   os << "goto ";
281   if (!(emitter.hasBlockLabel(trueSuccessor))) {
282     return condBranchOp.emitOpError("unable to find label for successor block");
283   }
284   os << emitter.getOrCreateName(trueSuccessor) << ";\n";
285   os.unindent() << "} else {\n";
286   os.indent();
287   // If condition is false.
288   for (auto pair : llvm::zip(condBranchOp.getFalseOperands(),
289                              falseSuccessor.getArguments())) {
290     Value &operand = std::get<0>(pair);
291     BlockArgument &argument = std::get<1>(pair);
292     os << emitter.getOrCreateName(argument) << " = "
293        << emitter.getOrCreateName(operand) << ";\n";
294   }
295 
296   os << "goto ";
297   if (!(emitter.hasBlockLabel(falseSuccessor))) {
298     return condBranchOp.emitOpError()
299            << "unable to find label for successor block";
300   }
301   os << emitter.getOrCreateName(falseSuccessor) << ";\n";
302   os.unindent() << "}";
303   return success();
304 }
305 
306 static LogicalResult printOperation(CppEmitter &emitter, mlir::CallOp callOp) {
307   if (failed(emitter.emitAssignPrefix(*callOp.getOperation())))
308     return failure();
309 
310   raw_ostream &os = emitter.ostream();
311   os << callOp.getCallee() << "(";
312   if (failed(emitter.emitOperands(*callOp.getOperation())))
313     return failure();
314   os << ")";
315   return success();
316 }
317 
318 static LogicalResult printOperation(CppEmitter &emitter, emitc::CallOp callOp) {
319   raw_ostream &os = emitter.ostream();
320   Operation &op = *callOp.getOperation();
321 
322   if (failed(emitter.emitAssignPrefix(op)))
323     return failure();
324   os << callOp.callee();
325 
326   auto emitArgs = [&](Attribute attr) -> LogicalResult {
327     if (auto t = attr.dyn_cast<IntegerAttr>()) {
328       // Index attributes are treated specially as operand index.
329       if (t.getType().isIndex()) {
330         int64_t idx = t.getInt();
331         if ((idx < 0) || (idx >= op.getNumOperands()))
332           return op.emitOpError("invalid operand index");
333         if (!emitter.hasValueInScope(op.getOperand(idx)))
334           return op.emitOpError("operand ")
335                  << idx << "'s value not defined in scope";
336         os << emitter.getOrCreateName(op.getOperand(idx));
337         return success();
338       }
339     }
340     if (failed(emitter.emitAttribute(op.getLoc(), attr)))
341       return failure();
342 
343     return success();
344   };
345 
346   if (callOp.template_args()) {
347     os << "<";
348     if (failed(interleaveCommaWithError(*callOp.template_args(), os, emitArgs)))
349       return failure();
350     os << ">";
351   }
352 
353   os << "(";
354 
355   LogicalResult emittedArgs =
356       callOp.args() ? interleaveCommaWithError(*callOp.args(), os, emitArgs)
357                     : emitter.emitOperands(op);
358   if (failed(emittedArgs))
359     return failure();
360   os << ")";
361   return success();
362 }
363 
364 static LogicalResult printOperation(CppEmitter &emitter,
365                                     emitc::ApplyOp applyOp) {
366   raw_ostream &os = emitter.ostream();
367   Operation &op = *applyOp.getOperation();
368 
369   if (failed(emitter.emitAssignPrefix(op)))
370     return failure();
371   os << applyOp.applicableOperator();
372   os << emitter.getOrCreateName(applyOp.getOperand());
373 
374   return success();
375 }
376 
377 static LogicalResult printOperation(CppEmitter &emitter,
378                                     emitc::IncludeOp includeOp) {
379   raw_ostream &os = emitter.ostream();
380 
381   os << "#include ";
382   if (includeOp.is_standard_include())
383     os << "<" << includeOp.include() << ">";
384   else
385     os << "\"" << includeOp.include() << "\"";
386 
387   return success();
388 }
389 
/// Emits an scf.for loop as a C++ `for` statement.
///
/// Loop-carried values are modelled with plain variables:
/// - each region iter arg is declared before the loop and initialized from
///   the corresponding init operand;
/// - the yielded values are copied back into the iter args at the end of
///   every iteration;
/// - after the loop, the iter args are copied into the variables that hold
///   the op's results.
/// The result variables are declared here unless variables are declared at
/// the top of the function. The loop tests the induction variable with `<`
/// against the upper bound and advances it with `+=` by the step.
static LogicalResult printOperation(CppEmitter &emitter, scf::ForOp forOp) {

  raw_indented_ostream &os = emitter.ostream();

  OperandRange operands = forOp.getIterOperands();
  Block::BlockArgListType iterArgs = forOp.getRegionIterArgs();
  Operation::result_range results = forOp.getResults();

  if (!emitter.shouldDeclareVariablesAtTop()) {
    for (OpResult result : results) {
      if (failed(emitter.emitVariableDeclaration(result,
                                                 /*trailingSemicolon=*/true)))
        return failure();
    }
  }

  // Declare one variable per iter arg, initialized from its init operand.
  for (auto pair : llvm::zip(iterArgs, operands)) {
    if (failed(emitter.emitType(forOp.getLoc(), std::get<0>(pair).getType())))
      return failure();
    os << " " << emitter.getOrCreateName(std::get<0>(pair)) << " = ";
    os << emitter.getOrCreateName(std::get<1>(pair)) << ";";
    os << "\n";
  }

  os << "for (";
  if (failed(
          emitter.emitType(forOp.getLoc(), forOp.getInductionVar().getType())))
    return failure();
  os << " ";
  os << emitter.getOrCreateName(forOp.getInductionVar());
  os << " = ";
  os << emitter.getOrCreateName(forOp.getLowerBound());
  os << "; ";
  os << emitter.getOrCreateName(forOp.getInductionVar());
  os << " < ";
  os << emitter.getOrCreateName(forOp.getUpperBound());
  os << "; ";
  os << emitter.getOrCreateName(forOp.getInductionVar());
  os << " += ";
  os << emitter.getOrCreateName(forOp.getStep());
  os << ") {\n";
  os.indent();

  Region &forRegion = forOp.getRegion();
  auto regionOps = forRegion.getOps();

  // We skip the trailing yield op because this updates the result variables
  // of the for op in the generated code. Instead we update the iterArgs at
  // the end of a loop iteration and set the result variables after the for
  // loop.
  for (auto it = regionOps.begin(); std::next(it) != regionOps.end(); ++it) {
    if (failed(emitter.emitOperation(*it, /*trailingSemicolon=*/true)))
      return failure();
  }

  Operation *yieldOp = forRegion.getBlocks().front().getTerminator();
  // Copy yield operands into iterArgs at the end of a loop iteration.
  for (auto pair : llvm::zip(iterArgs, yieldOp->getOperands())) {
    BlockArgument iterArg = std::get<0>(pair);
    Value operand = std::get<1>(pair);
    os << emitter.getOrCreateName(iterArg) << " = "
       << emitter.getOrCreateName(operand) << ";\n";
  }

  os.unindent() << "}";

  // Copy iterArgs into results after the for loop. Each copy starts on a new
  // line; no newline is written after the last one here.
  for (auto pair : llvm::zip(results, iterArgs)) {
    OpResult result = std::get<0>(pair);
    BlockArgument iterArg = std::get<1>(pair);
    os << "\n"
       << emitter.getOrCreateName(result) << " = "
       << emitter.getOrCreateName(iterArg) << ";";
  }

  return success();
}
467 
468 static LogicalResult printOperation(CppEmitter &emitter, scf::IfOp ifOp) {
469   raw_indented_ostream &os = emitter.ostream();
470 
471   if (!emitter.shouldDeclareVariablesAtTop()) {
472     for (OpResult result : ifOp.getResults()) {
473       if (failed(emitter.emitVariableDeclaration(result,
474                                                  /*trailingSemicolon=*/true)))
475         return failure();
476     }
477   }
478 
479   os << "if (";
480   if (failed(emitter.emitOperands(*ifOp.getOperation())))
481     return failure();
482   os << ") {\n";
483   os.indent();
484 
485   Region &thenRegion = ifOp.getThenRegion();
486   for (Operation &op : thenRegion.getOps()) {
487     // Note: This prints a superfluous semicolon if the terminating yield op has
488     // zero results.
489     if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/true)))
490       return failure();
491   }
492 
493   os.unindent() << "}";
494 
495   Region &elseRegion = ifOp.getElseRegion();
496   if (!elseRegion.empty()) {
497     os << " else {\n";
498     os.indent();
499 
500     for (Operation &op : elseRegion.getOps()) {
501       // Note: This prints a superfluous semicolon if the terminating yield op
502       // has zero results.
503       if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/true)))
504         return failure();
505     }
506 
507     os.unindent() << "}";
508   }
509 
510   return success();
511 }
512 
513 static LogicalResult printOperation(CppEmitter &emitter, scf::YieldOp yieldOp) {
514   raw_ostream &os = emitter.ostream();
515   Operation &parentOp = *yieldOp.getOperation()->getParentOp();
516 
517   if (yieldOp.getNumOperands() != parentOp.getNumResults()) {
518     return yieldOp.emitError("number of operands does not to match the number "
519                              "of the parent op's results");
520   }
521 
522   if (failed(interleaveWithError(
523           llvm::zip(parentOp.getResults(), yieldOp.getOperands()),
524           [&](auto pair) -> LogicalResult {
525             auto result = std::get<0>(pair);
526             auto operand = std::get<1>(pair);
527             os << emitter.getOrCreateName(result) << " = ";
528 
529             if (!emitter.hasValueInScope(operand))
530               return yieldOp.emitError("operand value not in scope");
531             os << emitter.getOrCreateName(operand);
532             return success();
533           },
534           [&]() { os << ";\n"; })))
535     return failure();
536 
537   return success();
538 }
539 
540 static LogicalResult printOperation(CppEmitter &emitter, ReturnOp returnOp) {
541   raw_ostream &os = emitter.ostream();
542   os << "return";
543   switch (returnOp.getNumOperands()) {
544   case 0:
545     return success();
546   case 1:
547     os << " " << emitter.getOrCreateName(returnOp.getOperand(0));
548     return success(emitter.hasValueInScope(returnOp.getOperand(0)));
549   default:
550     os << " std::make_tuple(";
551     if (failed(emitter.emitOperandsAndAttributes(*returnOp.getOperation())))
552       return failure();
553     os << ")";
554     return success();
555   }
556 }
557 
558 static LogicalResult printOperation(CppEmitter &emitter, ModuleOp moduleOp) {
559   CppEmitter::Scope scope(emitter);
560 
561   for (Operation &op : moduleOp) {
562     if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/false)))
563       return failure();
564   }
565   return success();
566 }
567 
/// Emits a FuncOp as a C++ function definition inside a new naming scope.
///
/// The signature consists of the type produced by emitTypes for the result
/// types, the function name, and the typed, named arguments. In the body:
/// - when variables are declared at the top, a declaration for every op
///   result in the function (including results of ops in nested regions) is
///   emitted first;
/// - labels are created for all blocks and variables are declared for the
///   arguments of every non-entry block;
/// - each block is printed, preceded by its label if it has predecessors.
/// Functions with more than one block are rejected unless variables are
/// declared at the top.
static LogicalResult printOperation(CppEmitter &emitter, FuncOp functionOp) {
  // We need to declare variables at top if the function has multiple blocks.
  if (!emitter.shouldDeclareVariablesAtTop() &&
      functionOp.getBlocks().size() > 1) {
    return functionOp.emitOpError(
        "with multiple blocks needs variables declared at top");
  }

  CppEmitter::Scope scope(emitter);
  raw_indented_ostream &os = emitter.ostream();
  if (failed(emitter.emitTypes(functionOp.getLoc(),
                               functionOp.getType().getResults())))
    return failure();
  os << " " << functionOp.getName();

  os << "(";
  if (failed(interleaveCommaWithError(
          functionOp.getArguments(), os,
          [&](BlockArgument arg) -> LogicalResult {
            if (failed(emitter.emitType(functionOp.getLoc(), arg.getType())))
              return failure();
            os << " " << emitter.getOrCreateName(arg);
            return success();
          })))
    return failure();
  os << ") {\n";
  os.indent();
  if (emitter.shouldDeclareVariablesAtTop()) {
    // Declare all variables that hold op results including those from nested
    // regions.
    WalkResult result =
        functionOp.walk<WalkOrder::PreOrder>([&](Operation *op) -> WalkResult {
          for (OpResult result : op->getResults()) {
            if (failed(emitter.emitVariableDeclaration(
                    result, /*trailingSemicolon=*/true))) {
              return WalkResult(
                  op->emitError("unable to declare result variable for op"));
            }
          }
          return WalkResult::advance();
        });
    if (result.wasInterrupted())
      return failure();
  }

  Region::BlockListType &blocks = functionOp.getBlocks();
  // Create label names for basic blocks.
  for (Block &block : blocks) {
    emitter.getOrCreateName(block);
  }

  // Declare variables for basic block arguments. The entry block is skipped;
  // its arguments are the function arguments named in the signature above.
  // NOTE(review): for a function without a body `blocks` is empty and
  // std::next(blocks.begin()) advances the end iterator; confirm this is safe
  // for this list type or guard with an emptiness check.
  for (auto it = std::next(blocks.begin()); it != blocks.end(); ++it) {
    Block &block = *it;
    for (BlockArgument &arg : block.getArguments()) {
      // Rejects an argument that already has a name (it would be declared
      // twice). NOTE(review): the diagnostic says "out of scope" although it
      // fires when the argument is already in scope.
      if (emitter.hasValueInScope(arg))
        return functionOp.emitOpError(" block argument #")
               << arg.getArgNumber() << " is out of scope";
      if (failed(
              emitter.emitType(block.getParentOp()->getLoc(), arg.getType()))) {
        return failure();
      }
      os << " " << emitter.getOrCreateName(arg) << ";\n";
    }
  }

  for (Block &block : blocks) {
    // Only print a label if the block has predecessors.
    if (!block.hasNoPredecessors()) {
      if (failed(emitter.emitLabel(block)))
        return failure();
    }
    for (Operation &op : block.getOperations()) {
      // When generating code for an scf.if or cf.cond_br op no semicolon needs
      // to be printed after the closing brace.
      // When generating code for an scf.for op, printing a trailing semicolon
      // is handled within the printOperation function.
      bool trailingSemicolon =
          !isa<scf::IfOp, scf::ForOp, cf::CondBranchOp>(op);

      if (failed(emitter.emitOperation(
              op, /*trailingSemicolon=*/trailingSemicolon)))
        return failure();
    }
  }
  os.unindent() << "}\n";
  return success();
}
656 
657 CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop)
658     : os(os), declareVariablesAtTop(declareVariablesAtTop) {
659   valueInScopeCount.push(0);
660   labelInScopeCount.push(0);
661 }
662 
663 /// Return the existing or a new name for a Value.
664 StringRef CppEmitter::getOrCreateName(Value val) {
665   if (!valueMapper.count(val))
666     valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top()));
667   return *valueMapper.begin(val);
668 }
669 
670 /// Return the existing or a new label for a Block.
671 StringRef CppEmitter::getOrCreateName(Block &block) {
672   if (!blockMapper.count(&block))
673     blockMapper.insert(&block, formatv("label{0}", ++labelInScopeCount.top()));
674   return *blockMapper.begin(&block);
675 }
676 
677 bool CppEmitter::shouldMapToUnsigned(IntegerType::SignednessSemantics val) {
678   switch (val) {
679   case IntegerType::Signless:
680     return false;
681   case IntegerType::Signed:
682     return false;
683   case IntegerType::Unsigned:
684     return true;
685   }
686   llvm_unreachable("Unexpected IntegerType::SignednessSemantics");
687 }
688 
689 bool CppEmitter::hasValueInScope(Value val) { return valueMapper.count(val); }
690 
691 bool CppEmitter::hasBlockLabel(Block &block) {
692   return blockMapper.count(&block);
693 }
694 
695 LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
696   auto printInt = [&](const APInt &val, bool isUnsigned) {
697     if (val.getBitWidth() == 1) {
698       if (val.getBoolValue())
699         os << "true";
700       else
701         os << "false";
702     } else {
703       SmallString<128> strValue;
704       val.toString(strValue, 10, !isUnsigned, false);
705       os << strValue;
706     }
707   };
708 
709   auto printFloat = [&](const APFloat &val) {
710     if (val.isFinite()) {
711       SmallString<128> strValue;
712       // Use default values of toString except don't truncate zeros.
713       val.toString(strValue, 0, 0, false);
714       switch (llvm::APFloatBase::SemanticsToEnum(val.getSemantics())) {
715       case llvm::APFloatBase::S_IEEEsingle:
716         os << "(float)";
717         break;
718       case llvm::APFloatBase::S_IEEEdouble:
719         os << "(double)";
720         break;
721       default:
722         break;
723       };
724       os << strValue;
725     } else if (val.isNaN()) {
726       os << "NAN";
727     } else if (val.isInfinity()) {
728       if (val.isNegative())
729         os << "-";
730       os << "INFINITY";
731     }
732   };
733 
734   // Print floating point attributes.
735   if (auto fAttr = attr.dyn_cast<FloatAttr>()) {
736     printFloat(fAttr.getValue());
737     return success();
738   }
739   if (auto dense = attr.dyn_cast<DenseFPElementsAttr>()) {
740     os << '{';
741     interleaveComma(dense, os, [&](const APFloat &val) { printFloat(val); });
742     os << '}';
743     return success();
744   }
745 
746   // Print integer attributes.
747   if (auto iAttr = attr.dyn_cast<IntegerAttr>()) {
748     if (auto iType = iAttr.getType().dyn_cast<IntegerType>()) {
749       printInt(iAttr.getValue(), shouldMapToUnsigned(iType.getSignedness()));
750       return success();
751     }
752     if (auto iType = iAttr.getType().dyn_cast<IndexType>()) {
753       printInt(iAttr.getValue(), false);
754       return success();
755     }
756   }
757   if (auto dense = attr.dyn_cast<DenseIntElementsAttr>()) {
758     if (auto iType = dense.getType()
759                          .cast<TensorType>()
760                          .getElementType()
761                          .dyn_cast<IntegerType>()) {
762       os << '{';
763       interleaveComma(dense, os, [&](const APInt &val) {
764         printInt(val, shouldMapToUnsigned(iType.getSignedness()));
765       });
766       os << '}';
767       return success();
768     }
769     if (auto iType = dense.getType()
770                          .cast<TensorType>()
771                          .getElementType()
772                          .dyn_cast<IndexType>()) {
773       os << '{';
774       interleaveComma(dense, os,
775                       [&](const APInt &val) { printInt(val, false); });
776       os << '}';
777       return success();
778     }
779   }
780 
781   // Print opaque attributes.
782   if (auto oAttr = attr.dyn_cast<emitc::OpaqueAttr>()) {
783     os << oAttr.getValue();
784     return success();
785   }
786 
787   // Print symbolic reference attributes.
788   if (auto sAttr = attr.dyn_cast<SymbolRefAttr>()) {
789     if (sAttr.getNestedReferences().size() > 1)
790       return emitError(loc, "attribute has more than 1 nested reference");
791     os << sAttr.getRootReference().getValue();
792     return success();
793   }
794 
795   // Print type attributes.
796   if (auto type = attr.dyn_cast<TypeAttr>())
797     return emitType(loc, type.getValue());
798 
799   return emitError(loc, "cannot emit attribute of type ") << attr.getType();
800 }
801 
802 LogicalResult CppEmitter::emitOperands(Operation &op) {
803   auto emitOperandName = [&](Value result) -> LogicalResult {
804     if (!hasValueInScope(result))
805       return op.emitOpError() << "operand value not in scope";
806     os << getOrCreateName(result);
807     return success();
808   };
809   return interleaveCommaWithError(op.getOperands(), os, emitOperandName);
810 }
811 
812 LogicalResult
813 CppEmitter::emitOperandsAndAttributes(Operation &op,
814                                       ArrayRef<StringRef> exclude) {
815   if (failed(emitOperands(op)))
816     return failure();
817   // Insert comma in between operands and non-filtered attributes if needed.
818   if (op.getNumOperands() > 0) {
819     for (NamedAttribute attr : op.getAttrs()) {
820       if (!llvm::is_contained(exclude, attr.getName().strref())) {
821         os << ", ";
822         break;
823       }
824     }
825   }
826   // Emit attributes.
827   auto emitNamedAttribute = [&](NamedAttribute attr) -> LogicalResult {
828     if (llvm::is_contained(exclude, attr.getName().strref()))
829       return success();
830     os << "/* " << attr.getName().getValue() << " */";
831     if (failed(emitAttribute(op.getLoc(), attr.getValue())))
832       return failure();
833     return success();
834   };
835   return interleaveCommaWithError(op.getAttrs(), os, emitNamedAttribute);
836 }
837 
838 LogicalResult CppEmitter::emitVariableAssignment(OpResult result) {
839   if (!hasValueInScope(result)) {
840     return result.getDefiningOp()->emitOpError(
841         "result variable for the operation has not been declared");
842   }
843   os << getOrCreateName(result) << " = ";
844   return success();
845 }
846 
847 LogicalResult CppEmitter::emitVariableDeclaration(OpResult result,
848                                                   bool trailingSemicolon) {
849   if (hasValueInScope(result)) {
850     return result.getDefiningOp()->emitError(
851         "result variable for the operation already declared");
852   }
853   if (failed(emitType(result.getOwner()->getLoc(), result.getType())))
854     return failure();
855   os << " " << getOrCreateName(result);
856   if (trailingSemicolon)
857     os << ";\n";
858   return success();
859 }
860 
861 LogicalResult CppEmitter::emitAssignPrefix(Operation &op) {
862   switch (op.getNumResults()) {
863   case 0:
864     break;
865   case 1: {
866     OpResult result = op.getResult(0);
867     if (shouldDeclareVariablesAtTop()) {
868       if (failed(emitVariableAssignment(result)))
869         return failure();
870     } else {
871       if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/false)))
872         return failure();
873       os << " = ";
874     }
875     break;
876   }
877   default:
878     if (!shouldDeclareVariablesAtTop()) {
879       for (OpResult result : op.getResults()) {
880         if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/true)))
881           return failure();
882       }
883     }
884     os << "std::tie(";
885     interleaveComma(op.getResults(), os,
886                     [&](Value result) { os << getOrCreateName(result); });
887     os << ") = ";
888   }
889   return success();
890 }
891 
892 LogicalResult CppEmitter::emitLabel(Block &block) {
893   if (!hasBlockLabel(block))
894     return block.getParentOp()->emitError("label for block not found");
895   // FIXME: Add feature in `raw_indented_ostream` to ignore indent for block
896   // label instead of using `getOStream`.
897   os.getOStream() << getOrCreateName(block) << ":\n";
898   return success();
899 }
900 
901 LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
902   LogicalResult status =
903       llvm::TypeSwitch<Operation *, LogicalResult>(&op)
904           // EmitC ops.
905           .Case<emitc::ApplyOp, emitc::CallOp, emitc::ConstantOp,
906                 emitc::IncludeOp>(
907               [&](auto op) { return printOperation(*this, op); })
908           // SCF ops.
909           .Case<scf::ForOp, scf::IfOp, scf::YieldOp>(
910               [&](auto op) { return printOperation(*this, op); })
911           // Standard ops.
912           .Case<cf::BranchOp, mlir::CallOp, cf::CondBranchOp, mlir::ConstantOp,
913                 FuncOp, ModuleOp, ReturnOp>(
914               [&](auto op) { return printOperation(*this, op); })
915           // Arithmetic ops.
916           .Case<arith::ConstantOp>(
917               [&](auto op) { return printOperation(*this, op); })
918           .Default([&](Operation *) {
919             return op.emitOpError("unable to find printer for op");
920           });
921 
922   if (failed(status))
923     return failure();
924   os << (trailingSemicolon ? ";\n" : "\n");
925   return success();
926 }
927 
928 LogicalResult CppEmitter::emitType(Location loc, Type type) {
929   if (auto iType = type.dyn_cast<IntegerType>()) {
930     switch (iType.getWidth()) {
931     case 1:
932       return (os << "bool"), success();
933     case 8:
934     case 16:
935     case 32:
936     case 64:
937       if (shouldMapToUnsigned(iType.getSignedness()))
938         return (os << "uint" << iType.getWidth() << "_t"), success();
939       else
940         return (os << "int" << iType.getWidth() << "_t"), success();
941     default:
942       return emitError(loc, "cannot emit integer type ") << type;
943     }
944   }
945   if (auto fType = type.dyn_cast<FloatType>()) {
946     switch (fType.getWidth()) {
947     case 32:
948       return (os << "float"), success();
949     case 64:
950       return (os << "double"), success();
951     default:
952       return emitError(loc, "cannot emit float type ") << type;
953     }
954   }
955   if (auto iType = type.dyn_cast<IndexType>())
956     return (os << "size_t"), success();
957   if (auto tType = type.dyn_cast<TensorType>()) {
958     if (!tType.hasRank())
959       return emitError(loc, "cannot emit unranked tensor type");
960     if (!tType.hasStaticShape())
961       return emitError(loc, "cannot emit tensor type with non static shape");
962     os << "Tensor<";
963     if (failed(emitType(loc, tType.getElementType())))
964       return failure();
965     auto shape = tType.getShape();
966     for (auto dimSize : shape) {
967       os << ", ";
968       os << dimSize;
969     }
970     os << ">";
971     return success();
972   }
973   if (auto tType = type.dyn_cast<TupleType>())
974     return emitTupleType(loc, tType.getTypes());
975   if (auto oType = type.dyn_cast<emitc::OpaqueType>()) {
976     os << oType.getValue();
977     return success();
978   }
979   return emitError(loc, "cannot emit type ") << type;
980 }
981 
982 LogicalResult CppEmitter::emitTypes(Location loc, ArrayRef<Type> types) {
983   switch (types.size()) {
984   case 0:
985     os << "void";
986     return success();
987   case 1:
988     return emitType(loc, types.front());
989   default:
990     return emitTupleType(loc, types);
991   }
992 }
993 
994 LogicalResult CppEmitter::emitTupleType(Location loc, ArrayRef<Type> types) {
995   os << "std::tuple<";
996   if (failed(interleaveCommaWithError(
997           types, os, [&](Type type) { return emitType(loc, type); })))
998     return failure();
999   os << ">";
1000   return success();
1001 }
1002 
1003 LogicalResult emitc::translateToCpp(Operation *op, raw_ostream &os,
1004                                     bool declareVariablesAtTop) {
1005   CppEmitter emitter(os, declareVariablesAtTop);
1006   return emitter.emitOperation(*op, /*trailingSemicolon=*/false);
1007 }
1008