1 //===- TranslateToCpp.cpp - Translating to C++ calls ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/IndentedOstream.h"
#include "mlir/Target/Cpp/CppEmitter.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/ScopedHashTable.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#include <stack>
#include <utility>
26 
27 #define DEBUG_TYPE "translate-to-cpp"
28 
29 using namespace mlir;
30 using namespace mlir::emitc;
31 using llvm::formatv;
32 
33 /// Convenience functions to produce interleaved output with functions returning
34 /// a LogicalResult. This is different than those in STLExtras as functions used
35 /// on each element doesn't return a string.
36 template <typename ForwardIterator, typename UnaryFunctor,
37           typename NullaryFunctor>
38 inline LogicalResult
39 interleaveWithError(ForwardIterator begin, ForwardIterator end,
40                     UnaryFunctor eachFn, NullaryFunctor betweenFn) {
41   if (begin == end)
42     return success();
43   if (failed(eachFn(*begin)))
44     return failure();
45   ++begin;
46   for (; begin != end; ++begin) {
47     betweenFn();
48     if (failed(eachFn(*begin)))
49       return failure();
50   }
51   return success();
52 }
53 
54 template <typename Container, typename UnaryFunctor, typename NullaryFunctor>
55 inline LogicalResult interleaveWithError(const Container &c,
56                                          UnaryFunctor eachFn,
57                                          NullaryFunctor betweenFn) {
58   return interleaveWithError(c.begin(), c.end(), eachFn, betweenFn);
59 }
60 
61 template <typename Container, typename UnaryFunctor>
62 inline LogicalResult interleaveCommaWithError(const Container &c,
63                                               raw_ostream &os,
64                                               UnaryFunctor eachFn) {
65   return interleaveWithError(c.begin(), c.end(), eachFn, [&]() { os << ", "; });
66 }
67 
68 namespace {
69 /// Emitter that uses dialect specific emitters to emit C++ code.
70 struct CppEmitter {
71   explicit CppEmitter(raw_ostream &os, bool declareVariablesAtTop);
72 
73   /// Emits attribute or returns failure.
74   LogicalResult emitAttribute(Location loc, Attribute attr);
75 
76   /// Emits operation 'op' with/without training semicolon or returns failure.
77   LogicalResult emitOperation(Operation &op, bool trailingSemicolon);
78 
79   /// Emits type 'type' or returns failure.
80   LogicalResult emitType(Location loc, Type type);
81 
82   /// Emits array of types as a std::tuple of the emitted types.
83   /// - emits void for an empty array;
84   /// - emits the type of the only element for arrays of size one;
85   /// - emits a std::tuple otherwise;
86   LogicalResult emitTypes(Location loc, ArrayRef<Type> types);
87 
88   /// Emits array of types as a std::tuple of the emitted types independently of
89   /// the array size.
90   LogicalResult emitTupleType(Location loc, ArrayRef<Type> types);
91 
92   /// Emits an assignment for a variable which has been declared previously.
93   LogicalResult emitVariableAssignment(OpResult result);
94 
95   /// Emits a variable declaration for a result of an operation.
96   LogicalResult emitVariableDeclaration(OpResult result,
97                                         bool trailingSemicolon);
98 
99   /// Emits the variable declaration and assignment prefix for 'op'.
100   /// - emits separate variable followed by std::tie for multi-valued operation;
101   /// - emits single type followed by variable for single result;
102   /// - emits nothing if no value produced by op;
103   /// Emits final '=' operator where a type is produced. Returns failure if
104   /// any result type could not be converted.
105   LogicalResult emitAssignPrefix(Operation &op);
106 
107   /// Emits a label for the block.
108   LogicalResult emitLabel(Block &block);
109 
110   /// Emits the operands and atttributes of the operation. All operands are
111   /// emitted first and then all attributes in alphabetical order.
112   LogicalResult emitOperandsAndAttributes(Operation &op,
113                                           ArrayRef<StringRef> exclude = {});
114 
115   /// Emits the operands of the operation. All operands are emitted in order.
116   LogicalResult emitOperands(Operation &op);
117 
118   /// Return the existing or a new name for a Value.
119   StringRef getOrCreateName(Value val);
120 
121   /// Return the existing or a new label of a Block.
122   StringRef getOrCreateName(Block &block);
123 
124   /// Whether to map an mlir integer to a unsigned integer in C++.
125   bool shouldMapToUnsigned(IntegerType::SignednessSemantics val);
126 
127   /// RAII helper function to manage entering/exiting C++ scopes.
128   struct Scope {
129     Scope(CppEmitter &emitter)
130         : valueMapperScope(emitter.valueMapper),
131           blockMapperScope(emitter.blockMapper), emitter(emitter) {
132       emitter.valueInScopeCount.push(emitter.valueInScopeCount.top());
133       emitter.labelInScopeCount.push(emitter.labelInScopeCount.top());
134     }
135     ~Scope() {
136       emitter.valueInScopeCount.pop();
137       emitter.labelInScopeCount.pop();
138     }
139 
140   private:
141     llvm::ScopedHashTableScope<Value, std::string> valueMapperScope;
142     llvm::ScopedHashTableScope<Block *, std::string> blockMapperScope;
143     CppEmitter &emitter;
144   };
145 
146   /// Returns wether the Value is assigned to a C++ variable in the scope.
147   bool hasValueInScope(Value val);
148 
149   // Returns whether a label is assigned to the block.
150   bool hasBlockLabel(Block &block);
151 
152   /// Returns the output stream.
153   raw_indented_ostream &ostream() { return os; };
154 
155   /// Returns if all variables for op results and basic block arguments need to
156   /// be declared at the beginning of a function.
157   bool shouldDeclareVariablesAtTop() { return declareVariablesAtTop; };
158 
159 private:
160   using ValueMapper = llvm::ScopedHashTable<Value, std::string>;
161   using BlockMapper = llvm::ScopedHashTable<Block *, std::string>;
162 
163   /// Output stream to emit to.
164   raw_indented_ostream os;
165 
166   /// Boolean to enforce that all variables for op results and block
167   /// arguments are declared at the beginning of the function. This also
168   /// includes results from ops located in nested regions.
169   bool declareVariablesAtTop;
170 
171   /// Map from value to name of C++ variable that contain the name.
172   ValueMapper valueMapper;
173 
174   /// Map from block to name of C++ label.
175   BlockMapper blockMapper;
176 
177   /// The number of values in the current scope. This is used to declare the
178   /// names of values in a scope.
179   std::stack<int64_t> valueInScopeCount;
180   std::stack<int64_t> labelInScopeCount;
181 };
182 } // namespace
183 
184 static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation,
185                                      Attribute value) {
186   OpResult result = operation->getResult(0);
187 
188   // Only emit an assignment as the variable was already declared when printing
189   // the FuncOp.
190   if (emitter.shouldDeclareVariablesAtTop()) {
191     // Skip the assignment if the emitc.constant has no value.
192     if (auto oAttr = value.dyn_cast<emitc::OpaqueAttr>()) {
193       if (oAttr.getValue().empty())
194         return success();
195     }
196 
197     if (failed(emitter.emitVariableAssignment(result)))
198       return failure();
199     return emitter.emitAttribute(operation->getLoc(), value);
200   }
201 
202   // Emit a variable declaration for an emitc.constant op without value.
203   if (auto oAttr = value.dyn_cast<emitc::OpaqueAttr>()) {
204     if (oAttr.getValue().empty())
205       // The semicolon gets printed by the emitOperation function.
206       return emitter.emitVariableDeclaration(result,
207                                              /*trailingSemicolon=*/false);
208   }
209 
210   // Emit a variable declaration.
211   if (failed(emitter.emitAssignPrefix(*operation)))
212     return failure();
213   return emitter.emitAttribute(operation->getLoc(), value);
214 }
215 
216 static LogicalResult printOperation(CppEmitter &emitter,
217                                     emitc::ConstantOp constantOp) {
218   Operation *operation = constantOp.getOperation();
219   Attribute value = constantOp.value();
220 
221   return printConstantOp(emitter, operation, value);
222 }
223 
224 static LogicalResult printOperation(CppEmitter &emitter,
225                                     arith::ConstantOp constantOp) {
226   Operation *operation = constantOp.getOperation();
227   Attribute value = constantOp.getValue();
228 
229   return printConstantOp(emitter, operation, value);
230 }
231 
232 static LogicalResult printOperation(CppEmitter &emitter,
233                                     mlir::ConstantOp constantOp) {
234   Operation *operation = constantOp.getOperation();
235   Attribute value = constantOp.getValue();
236 
237   return printConstantOp(emitter, operation, value);
238 }
239 
240 static LogicalResult printOperation(CppEmitter &emitter, BranchOp branchOp) {
241   raw_ostream &os = emitter.ostream();
242   Block &successor = *branchOp.getSuccessor();
243 
244   for (auto pair :
245        llvm::zip(branchOp.getOperands(), successor.getArguments())) {
246     Value &operand = std::get<0>(pair);
247     BlockArgument &argument = std::get<1>(pair);
248     os << emitter.getOrCreateName(argument) << " = "
249        << emitter.getOrCreateName(operand) << ";\n";
250   }
251 
252   os << "goto ";
253   if (!(emitter.hasBlockLabel(successor)))
254     return branchOp.emitOpError("unable to find label for successor block");
255   os << emitter.getOrCreateName(successor);
256   return success();
257 }
258 
259 static LogicalResult printOperation(CppEmitter &emitter,
260                                     CondBranchOp condBranchOp) {
261   raw_indented_ostream &os = emitter.ostream();
262   Block &trueSuccessor = *condBranchOp.getTrueDest();
263   Block &falseSuccessor = *condBranchOp.getFalseDest();
264 
265   os << "if (" << emitter.getOrCreateName(condBranchOp.getCondition())
266      << ") {\n";
267 
268   os.indent();
269 
270   // If condition is true.
271   for (auto pair : llvm::zip(condBranchOp.getTrueOperands(),
272                              trueSuccessor.getArguments())) {
273     Value &operand = std::get<0>(pair);
274     BlockArgument &argument = std::get<1>(pair);
275     os << emitter.getOrCreateName(argument) << " = "
276        << emitter.getOrCreateName(operand) << ";\n";
277   }
278 
279   os << "goto ";
280   if (!(emitter.hasBlockLabel(trueSuccessor))) {
281     return condBranchOp.emitOpError("unable to find label for successor block");
282   }
283   os << emitter.getOrCreateName(trueSuccessor) << ";\n";
284   os.unindent() << "} else {\n";
285   os.indent();
286   // If condition is false.
287   for (auto pair : llvm::zip(condBranchOp.getFalseOperands(),
288                              falseSuccessor.getArguments())) {
289     Value &operand = std::get<0>(pair);
290     BlockArgument &argument = std::get<1>(pair);
291     os << emitter.getOrCreateName(argument) << " = "
292        << emitter.getOrCreateName(operand) << ";\n";
293   }
294 
295   os << "goto ";
296   if (!(emitter.hasBlockLabel(falseSuccessor))) {
297     return condBranchOp.emitOpError()
298            << "unable to find label for successor block";
299   }
300   os << emitter.getOrCreateName(falseSuccessor) << ";\n";
301   os.unindent() << "}";
302   return success();
303 }
304 
305 static LogicalResult printOperation(CppEmitter &emitter, mlir::CallOp callOp) {
306   if (failed(emitter.emitAssignPrefix(*callOp.getOperation())))
307     return failure();
308 
309   raw_ostream &os = emitter.ostream();
310   os << callOp.getCallee() << "(";
311   if (failed(emitter.emitOperands(*callOp.getOperation())))
312     return failure();
313   os << ")";
314   return success();
315 }
316 
317 static LogicalResult printOperation(CppEmitter &emitter, emitc::CallOp callOp) {
318   raw_ostream &os = emitter.ostream();
319   Operation &op = *callOp.getOperation();
320 
321   if (failed(emitter.emitAssignPrefix(op)))
322     return failure();
323   os << callOp.callee();
324 
325   auto emitArgs = [&](Attribute attr) -> LogicalResult {
326     if (auto t = attr.dyn_cast<IntegerAttr>()) {
327       // Index attributes are treated specially as operand index.
328       if (t.getType().isIndex()) {
329         int64_t idx = t.getInt();
330         if ((idx < 0) || (idx >= op.getNumOperands()))
331           return op.emitOpError("invalid operand index");
332         if (!emitter.hasValueInScope(op.getOperand(idx)))
333           return op.emitOpError("operand ")
334                  << idx << "'s value not defined in scope";
335         os << emitter.getOrCreateName(op.getOperand(idx));
336         return success();
337       }
338     }
339     if (failed(emitter.emitAttribute(op.getLoc(), attr)))
340       return failure();
341 
342     return success();
343   };
344 
345   if (callOp.template_args()) {
346     os << "<";
347     if (failed(interleaveCommaWithError(*callOp.template_args(), os, emitArgs)))
348       return failure();
349     os << ">";
350   }
351 
352   os << "(";
353 
354   LogicalResult emittedArgs =
355       callOp.args() ? interleaveCommaWithError(*callOp.args(), os, emitArgs)
356                     : emitter.emitOperands(op);
357   if (failed(emittedArgs))
358     return failure();
359   os << ")";
360   return success();
361 }
362 
363 static LogicalResult printOperation(CppEmitter &emitter,
364                                     emitc::ApplyOp applyOp) {
365   raw_ostream &os = emitter.ostream();
366   Operation &op = *applyOp.getOperation();
367 
368   if (failed(emitter.emitAssignPrefix(op)))
369     return failure();
370   os << applyOp.applicableOperator();
371   os << emitter.getOrCreateName(applyOp.getOperand());
372 
373   return success();
374 }
375 
376 static LogicalResult printOperation(CppEmitter &emitter,
377                                     emitc::IncludeOp includeOp) {
378   raw_ostream &os = emitter.ostream();
379 
380   os << "#include ";
381   if (includeOp.is_standard_include())
382     os << "<" << includeOp.include() << ">";
383   else
384     os << "\"" << includeOp.include() << "\"";
385 
386   return success();
387 }
388 
389 static LogicalResult printOperation(CppEmitter &emitter, scf::ForOp forOp) {
390 
391   raw_indented_ostream &os = emitter.ostream();
392 
393   OperandRange operands = forOp.getIterOperands();
394   Block::BlockArgListType iterArgs = forOp.getRegionIterArgs();
395   Operation::result_range results = forOp.getResults();
396 
397   if (!emitter.shouldDeclareVariablesAtTop()) {
398     for (OpResult result : results) {
399       if (failed(emitter.emitVariableDeclaration(result,
400                                                  /*trailingSemicolon=*/true)))
401         return failure();
402     }
403   }
404 
405   for (auto pair : llvm::zip(iterArgs, operands)) {
406     if (failed(emitter.emitType(forOp.getLoc(), std::get<0>(pair).getType())))
407       return failure();
408     os << " " << emitter.getOrCreateName(std::get<0>(pair)) << " = ";
409     os << emitter.getOrCreateName(std::get<1>(pair)) << ";";
410     os << "\n";
411   }
412 
413   os << "for (";
414   if (failed(
415           emitter.emitType(forOp.getLoc(), forOp.getInductionVar().getType())))
416     return failure();
417   os << " ";
418   os << emitter.getOrCreateName(forOp.getInductionVar());
419   os << " = ";
420   os << emitter.getOrCreateName(forOp.getLowerBound());
421   os << "; ";
422   os << emitter.getOrCreateName(forOp.getInductionVar());
423   os << " < ";
424   os << emitter.getOrCreateName(forOp.getUpperBound());
425   os << "; ";
426   os << emitter.getOrCreateName(forOp.getInductionVar());
427   os << " += ";
428   os << emitter.getOrCreateName(forOp.getStep());
429   os << ") {\n";
430   os.indent();
431 
432   Region &forRegion = forOp.getRegion();
433   auto regionOps = forRegion.getOps();
434 
435   // We skip the trailing yield op because this updates the result variables
436   // of the for op in the generated code. Instead we update the iterArgs at
437   // the end of a loop iteration and set the result variables after the for
438   // loop.
439   for (auto it = regionOps.begin(); std::next(it) != regionOps.end(); ++it) {
440     if (failed(emitter.emitOperation(*it, /*trailingSemicolon=*/true)))
441       return failure();
442   }
443 
444   Operation *yieldOp = forRegion.getBlocks().front().getTerminator();
445   // Copy yield operands into iterArgs at the end of a loop iteration.
446   for (auto pair : llvm::zip(iterArgs, yieldOp->getOperands())) {
447     BlockArgument iterArg = std::get<0>(pair);
448     Value operand = std::get<1>(pair);
449     os << emitter.getOrCreateName(iterArg) << " = "
450        << emitter.getOrCreateName(operand) << ";\n";
451   }
452 
453   os.unindent() << "}";
454 
455   // Copy iterArgs into results after the for loop.
456   for (auto pair : llvm::zip(results, iterArgs)) {
457     OpResult result = std::get<0>(pair);
458     BlockArgument iterArg = std::get<1>(pair);
459     os << "\n"
460        << emitter.getOrCreateName(result) << " = "
461        << emitter.getOrCreateName(iterArg) << ";";
462   }
463 
464   return success();
465 }
466 
467 static LogicalResult printOperation(CppEmitter &emitter, scf::IfOp ifOp) {
468   raw_indented_ostream &os = emitter.ostream();
469 
470   if (!emitter.shouldDeclareVariablesAtTop()) {
471     for (OpResult result : ifOp.getResults()) {
472       if (failed(emitter.emitVariableDeclaration(result,
473                                                  /*trailingSemicolon=*/true)))
474         return failure();
475     }
476   }
477 
478   os << "if (";
479   if (failed(emitter.emitOperands(*ifOp.getOperation())))
480     return failure();
481   os << ") {\n";
482   os.indent();
483 
484   Region &thenRegion = ifOp.getThenRegion();
485   for (Operation &op : thenRegion.getOps()) {
486     // Note: This prints a superfluous semicolon if the terminating yield op has
487     // zero results.
488     if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/true)))
489       return failure();
490   }
491 
492   os.unindent() << "}";
493 
494   Region &elseRegion = ifOp.getElseRegion();
495   if (!elseRegion.empty()) {
496     os << " else {\n";
497     os.indent();
498 
499     for (Operation &op : elseRegion.getOps()) {
500       // Note: This prints a superfluous semicolon if the terminating yield op
501       // has zero results.
502       if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/true)))
503         return failure();
504     }
505 
506     os.unindent() << "}";
507   }
508 
509   return success();
510 }
511 
512 static LogicalResult printOperation(CppEmitter &emitter, scf::YieldOp yieldOp) {
513   raw_ostream &os = emitter.ostream();
514   Operation &parentOp = *yieldOp.getOperation()->getParentOp();
515 
516   if (yieldOp.getNumOperands() != parentOp.getNumResults()) {
517     return yieldOp.emitError("number of operands does not to match the number "
518                              "of the parent op's results");
519   }
520 
521   if (failed(interleaveWithError(
522           llvm::zip(parentOp.getResults(), yieldOp.getOperands()),
523           [&](auto pair) -> LogicalResult {
524             auto result = std::get<0>(pair);
525             auto operand = std::get<1>(pair);
526             os << emitter.getOrCreateName(result) << " = ";
527 
528             if (!emitter.hasValueInScope(operand))
529               return yieldOp.emitError("operand value not in scope");
530             os << emitter.getOrCreateName(operand);
531             return success();
532           },
533           [&]() { os << ";\n"; })))
534     return failure();
535 
536   return success();
537 }
538 
539 static LogicalResult printOperation(CppEmitter &emitter, ReturnOp returnOp) {
540   raw_ostream &os = emitter.ostream();
541   os << "return";
542   switch (returnOp.getNumOperands()) {
543   case 0:
544     return success();
545   case 1:
546     os << " " << emitter.getOrCreateName(returnOp.getOperand(0));
547     return success(emitter.hasValueInScope(returnOp.getOperand(0)));
548   default:
549     os << " std::make_tuple(";
550     if (failed(emitter.emitOperandsAndAttributes(*returnOp.getOperation())))
551       return failure();
552     os << ")";
553     return success();
554   }
555 }
556 
557 static LogicalResult printOperation(CppEmitter &emitter, ModuleOp moduleOp) {
558   CppEmitter::Scope scope(emitter);
559 
560   for (Operation &op : moduleOp) {
561     if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/false)))
562       return failure();
563   }
564   return success();
565 }
566 
567 static LogicalResult printOperation(CppEmitter &emitter, FuncOp functionOp) {
568   // We need to declare variables at top if the function has multiple blocks.
569   if (!emitter.shouldDeclareVariablesAtTop() &&
570       functionOp.getBlocks().size() > 1) {
571     return functionOp.emitOpError(
572         "with multiple blocks needs variables declared at top");
573   }
574 
575   CppEmitter::Scope scope(emitter);
576   raw_indented_ostream &os = emitter.ostream();
577   if (failed(emitter.emitTypes(functionOp.getLoc(),
578                                functionOp.getType().getResults())))
579     return failure();
580   os << " " << functionOp.getName();
581 
582   os << "(";
583   if (failed(interleaveCommaWithError(
584           functionOp.getArguments(), os,
585           [&](BlockArgument arg) -> LogicalResult {
586             if (failed(emitter.emitType(functionOp.getLoc(), arg.getType())))
587               return failure();
588             os << " " << emitter.getOrCreateName(arg);
589             return success();
590           })))
591     return failure();
592   os << ") {\n";
593   os.indent();
594   if (emitter.shouldDeclareVariablesAtTop()) {
595     // Declare all variables that hold op results including those from nested
596     // regions.
597     WalkResult result =
598         functionOp.walk<WalkOrder::PreOrder>([&](Operation *op) -> WalkResult {
599           for (OpResult result : op->getResults()) {
600             if (failed(emitter.emitVariableDeclaration(
601                     result, /*trailingSemicolon=*/true))) {
602               return WalkResult(
603                   op->emitError("unable to declare result variable for op"));
604             }
605           }
606           return WalkResult::advance();
607         });
608     if (result.wasInterrupted())
609       return failure();
610   }
611 
612   Region::BlockListType &blocks = functionOp.getBlocks();
613   // Create label names for basic blocks.
614   for (Block &block : blocks) {
615     emitter.getOrCreateName(block);
616   }
617 
618   // Declare variables for basic block arguments.
619   for (auto it = std::next(blocks.begin()); it != blocks.end(); ++it) {
620     Block &block = *it;
621     for (BlockArgument &arg : block.getArguments()) {
622       if (emitter.hasValueInScope(arg))
623         return functionOp.emitOpError(" block argument #")
624                << arg.getArgNumber() << " is out of scope";
625       if (failed(
626               emitter.emitType(block.getParentOp()->getLoc(), arg.getType()))) {
627         return failure();
628       }
629       os << " " << emitter.getOrCreateName(arg) << ";\n";
630     }
631   }
632 
633   for (Block &block : blocks) {
634     // Only print a label if there is more than one block.
635     if (blocks.size() > 1) {
636       if (failed(emitter.emitLabel(block)))
637         return failure();
638     }
639     for (Operation &op : block.getOperations()) {
640       // When generating code for an scf.if or std.cond_br op no semicolon needs
641       // to be printed after the closing brace.
642       // When generating code for an scf.for op, printing a trailing semicolon
643       // is handled within the printOperation function.
644       bool trailingSemicolon = !isa<scf::IfOp, scf::ForOp, CondBranchOp>(op);
645 
646       if (failed(emitter.emitOperation(
647               op, /*trailingSemicolon=*/trailingSemicolon)))
648         return failure();
649     }
650   }
651   os.unindent() << "}\n";
652   return success();
653 }
654 
655 CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop)
656     : os(os), declareVariablesAtTop(declareVariablesAtTop) {
657   valueInScopeCount.push(0);
658   labelInScopeCount.push(0);
659 }
660 
661 /// Return the existing or a new name for a Value.
662 StringRef CppEmitter::getOrCreateName(Value val) {
663   if (!valueMapper.count(val))
664     valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top()));
665   return *valueMapper.begin(val);
666 }
667 
668 /// Return the existing or a new label for a Block.
669 StringRef CppEmitter::getOrCreateName(Block &block) {
670   if (!blockMapper.count(&block))
671     blockMapper.insert(&block, formatv("label{0}", ++labelInScopeCount.top()));
672   return *blockMapper.begin(&block);
673 }
674 
675 bool CppEmitter::shouldMapToUnsigned(IntegerType::SignednessSemantics val) {
676   switch (val) {
677   case IntegerType::Signless:
678     return false;
679   case IntegerType::Signed:
680     return false;
681   case IntegerType::Unsigned:
682     return true;
683   }
684   llvm_unreachable("Unexpected IntegerType::SignednessSemantics");
685 }
686 
687 bool CppEmitter::hasValueInScope(Value val) { return valueMapper.count(val); }
688 
689 bool CppEmitter::hasBlockLabel(Block &block) {
690   return blockMapper.count(&block);
691 }
692 
693 LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
694   auto printInt = [&](const APInt &val, bool isUnsigned) {
695     if (val.getBitWidth() == 1) {
696       if (val.getBoolValue())
697         os << "true";
698       else
699         os << "false";
700     } else {
701       SmallString<128> strValue;
702       val.toString(strValue, 10, !isUnsigned, false);
703       os << strValue;
704     }
705   };
706 
707   auto printFloat = [&](const APFloat &val) {
708     if (val.isFinite()) {
709       SmallString<128> strValue;
710       // Use default values of toString except don't truncate zeros.
711       val.toString(strValue, 0, 0, false);
712       switch (llvm::APFloatBase::SemanticsToEnum(val.getSemantics())) {
713       case llvm::APFloatBase::S_IEEEsingle:
714         os << "(float)";
715         break;
716       case llvm::APFloatBase::S_IEEEdouble:
717         os << "(double)";
718         break;
719       default:
720         break;
721       };
722       os << strValue;
723     } else if (val.isNaN()) {
724       os << "NAN";
725     } else if (val.isInfinity()) {
726       if (val.isNegative())
727         os << "-";
728       os << "INFINITY";
729     }
730   };
731 
732   // Print floating point attributes.
733   if (auto fAttr = attr.dyn_cast<FloatAttr>()) {
734     printFloat(fAttr.getValue());
735     return success();
736   }
737   if (auto dense = attr.dyn_cast<DenseFPElementsAttr>()) {
738     os << '{';
739     interleaveComma(dense, os, [&](const APFloat &val) { printFloat(val); });
740     os << '}';
741     return success();
742   }
743 
744   // Print integer attributes.
745   if (auto iAttr = attr.dyn_cast<IntegerAttr>()) {
746     if (auto iType = iAttr.getType().dyn_cast<IntegerType>()) {
747       printInt(iAttr.getValue(), shouldMapToUnsigned(iType.getSignedness()));
748       return success();
749     }
750     if (auto iType = iAttr.getType().dyn_cast<IndexType>()) {
751       printInt(iAttr.getValue(), false);
752       return success();
753     }
754   }
755   if (auto dense = attr.dyn_cast<DenseIntElementsAttr>()) {
756     if (auto iType = dense.getType()
757                          .cast<TensorType>()
758                          .getElementType()
759                          .dyn_cast<IntegerType>()) {
760       os << '{';
761       interleaveComma(dense, os, [&](const APInt &val) {
762         printInt(val, shouldMapToUnsigned(iType.getSignedness()));
763       });
764       os << '}';
765       return success();
766     }
767     if (auto iType = dense.getType()
768                          .cast<TensorType>()
769                          .getElementType()
770                          .dyn_cast<IndexType>()) {
771       os << '{';
772       interleaveComma(dense, os,
773                       [&](const APInt &val) { printInt(val, false); });
774       os << '}';
775       return success();
776     }
777   }
778 
779   // Print opaque attributes.
780   if (auto oAttr = attr.dyn_cast<emitc::OpaqueAttr>()) {
781     os << oAttr.getValue();
782     return success();
783   }
784 
785   // Print symbolic reference attributes.
786   if (auto sAttr = attr.dyn_cast<SymbolRefAttr>()) {
787     if (sAttr.getNestedReferences().size() > 1)
788       return emitError(loc, "attribute has more than 1 nested reference");
789     os << sAttr.getRootReference().getValue();
790     return success();
791   }
792 
793   // Print type attributes.
794   if (auto type = attr.dyn_cast<TypeAttr>())
795     return emitType(loc, type.getValue());
796 
797   return emitError(loc, "cannot emit attribute of type ") << attr.getType();
798 }
799 
800 LogicalResult CppEmitter::emitOperands(Operation &op) {
801   auto emitOperandName = [&](Value result) -> LogicalResult {
802     if (!hasValueInScope(result))
803       return op.emitOpError() << "operand value not in scope";
804     os << getOrCreateName(result);
805     return success();
806   };
807   return interleaveCommaWithError(op.getOperands(), os, emitOperandName);
808 }
809 
810 LogicalResult
811 CppEmitter::emitOperandsAndAttributes(Operation &op,
812                                       ArrayRef<StringRef> exclude) {
813   if (failed(emitOperands(op)))
814     return failure();
815   // Insert comma in between operands and non-filtered attributes if needed.
816   if (op.getNumOperands() > 0) {
817     for (NamedAttribute attr : op.getAttrs()) {
818       if (!llvm::is_contained(exclude, attr.getName().strref())) {
819         os << ", ";
820         break;
821       }
822     }
823   }
824   // Emit attributes.
825   auto emitNamedAttribute = [&](NamedAttribute attr) -> LogicalResult {
826     if (llvm::is_contained(exclude, attr.getName().strref()))
827       return success();
828     os << "/* " << attr.getName().getValue() << " */";
829     if (failed(emitAttribute(op.getLoc(), attr.getValue())))
830       return failure();
831     return success();
832   };
833   return interleaveCommaWithError(op.getAttrs(), os, emitNamedAttribute);
834 }
835 
836 LogicalResult CppEmitter::emitVariableAssignment(OpResult result) {
837   if (!hasValueInScope(result)) {
838     return result.getDefiningOp()->emitOpError(
839         "result variable for the operation has not been declared");
840   }
841   os << getOrCreateName(result) << " = ";
842   return success();
843 }
844 
845 LogicalResult CppEmitter::emitVariableDeclaration(OpResult result,
846                                                   bool trailingSemicolon) {
847   if (hasValueInScope(result)) {
848     return result.getDefiningOp()->emitError(
849         "result variable for the operation already declared");
850   }
851   if (failed(emitType(result.getOwner()->getLoc(), result.getType())))
852     return failure();
853   os << " " << getOrCreateName(result);
854   if (trailingSemicolon)
855     os << ";\n";
856   return success();
857 }
858 
859 LogicalResult CppEmitter::emitAssignPrefix(Operation &op) {
860   switch (op.getNumResults()) {
861   case 0:
862     break;
863   case 1: {
864     OpResult result = op.getResult(0);
865     if (shouldDeclareVariablesAtTop()) {
866       if (failed(emitVariableAssignment(result)))
867         return failure();
868     } else {
869       if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/false)))
870         return failure();
871       os << " = ";
872     }
873     break;
874   }
875   default:
876     if (!shouldDeclareVariablesAtTop()) {
877       for (OpResult result : op.getResults()) {
878         if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/true)))
879           return failure();
880       }
881     }
882     os << "std::tie(";
883     interleaveComma(op.getResults(), os,
884                     [&](Value result) { os << getOrCreateName(result); });
885     os << ") = ";
886   }
887   return success();
888 }
889 
890 LogicalResult CppEmitter::emitLabel(Block &block) {
891   if (!hasBlockLabel(block))
892     return block.getParentOp()->emitError("label for block not found");
893   // FIXME: Add feature in `raw_indented_ostream` to ignore indent for block
894   // label instead of using `getOStream`.
895   os.getOStream() << getOrCreateName(block) << ":\n";
896   return success();
897 }
898 
899 LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
900   LogicalResult status =
901       llvm::TypeSwitch<Operation *, LogicalResult>(&op)
902           // EmitC ops.
903           .Case<emitc::ApplyOp, emitc::CallOp, emitc::ConstantOp,
904                 emitc::IncludeOp>(
905               [&](auto op) { return printOperation(*this, op); })
906           // SCF ops.
907           .Case<scf::ForOp, scf::IfOp, scf::YieldOp>(
908               [&](auto op) { return printOperation(*this, op); })
909           // Standard ops.
910           .Case<BranchOp, mlir::CallOp, CondBranchOp, mlir::ConstantOp, FuncOp,
911                 ModuleOp, ReturnOp>(
912               [&](auto op) { return printOperation(*this, op); })
913           // Arithmetic ops.
914           .Case<arith::ConstantOp>(
915               [&](auto op) { return printOperation(*this, op); })
916           .Default([&](Operation *) {
917             return op.emitOpError("unable to find printer for op");
918           });
919 
920   if (failed(status))
921     return failure();
922   os << (trailingSemicolon ? ";\n" : "\n");
923   return success();
924 }
925 
926 LogicalResult CppEmitter::emitType(Location loc, Type type) {
927   if (auto iType = type.dyn_cast<IntegerType>()) {
928     switch (iType.getWidth()) {
929     case 1:
930       return (os << "bool"), success();
931     case 8:
932     case 16:
933     case 32:
934     case 64:
935       if (shouldMapToUnsigned(iType.getSignedness()))
936         return (os << "uint" << iType.getWidth() << "_t"), success();
937       else
938         return (os << "int" << iType.getWidth() << "_t"), success();
939     default:
940       return emitError(loc, "cannot emit integer type ") << type;
941     }
942   }
943   if (auto fType = type.dyn_cast<FloatType>()) {
944     switch (fType.getWidth()) {
945     case 32:
946       return (os << "float"), success();
947     case 64:
948       return (os << "double"), success();
949     default:
950       return emitError(loc, "cannot emit float type ") << type;
951     }
952   }
953   if (auto iType = type.dyn_cast<IndexType>())
954     return (os << "size_t"), success();
955   if (auto tType = type.dyn_cast<TensorType>()) {
956     if (!tType.hasRank())
957       return emitError(loc, "cannot emit unranked tensor type");
958     if (!tType.hasStaticShape())
959       return emitError(loc, "cannot emit tensor type with non static shape");
960     os << "Tensor<";
961     if (failed(emitType(loc, tType.getElementType())))
962       return failure();
963     auto shape = tType.getShape();
964     for (auto dimSize : shape) {
965       os << ", ";
966       os << dimSize;
967     }
968     os << ">";
969     return success();
970   }
971   if (auto tType = type.dyn_cast<TupleType>())
972     return emitTupleType(loc, tType.getTypes());
973   if (auto oType = type.dyn_cast<emitc::OpaqueType>()) {
974     os << oType.getValue();
975     return success();
976   }
977   return emitError(loc, "cannot emit type ") << type;
978 }
979 
980 LogicalResult CppEmitter::emitTypes(Location loc, ArrayRef<Type> types) {
981   switch (types.size()) {
982   case 0:
983     os << "void";
984     return success();
985   case 1:
986     return emitType(loc, types.front());
987   default:
988     return emitTupleType(loc, types);
989   }
990 }
991 
992 LogicalResult CppEmitter::emitTupleType(Location loc, ArrayRef<Type> types) {
993   os << "std::tuple<";
994   if (failed(interleaveCommaWithError(
995           types, os, [&](Type type) { return emitType(loc, type); })))
996     return failure();
997   os << ">";
998   return success();
999 }
1000 
1001 LogicalResult emitc::translateToCpp(Operation *op, raw_ostream &os,
1002                                     bool declareVariablesAtTop) {
1003   CppEmitter emitter(os, declareVariablesAtTop);
1004   return emitter.emitOperation(*op, /*trailingSemicolon=*/false);
1005 }
1006