1 //===- TranslateToCpp.cpp - Translating to C++ calls ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/IndentedOstream.h"
#include "mlir/Target/Cpp/CppEmitter.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopedHashTable.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
#include <stack>
24 
25 #define DEBUG_TYPE "translate-to-cpp"
26 
27 using namespace mlir;
28 using namespace mlir::emitc;
29 using llvm::formatv;
30 
31 /// Convenience functions to produce interleaved output with functions returning
32 /// a LogicalResult. This is different than those in STLExtras as functions used
33 /// on each element doesn't return a string.
34 template <typename ForwardIterator, typename UnaryFunctor,
35           typename NullaryFunctor>
36 inline LogicalResult
37 interleaveWithError(ForwardIterator begin, ForwardIterator end,
38                     UnaryFunctor eachFn, NullaryFunctor betweenFn) {
39   if (begin == end)
40     return success();
41   if (failed(eachFn(*begin)))
42     return failure();
43   ++begin;
44   for (; begin != end; ++begin) {
45     betweenFn();
46     if (failed(eachFn(*begin)))
47       return failure();
48   }
49   return success();
50 }
51 
52 template <typename Container, typename UnaryFunctor, typename NullaryFunctor>
53 inline LogicalResult interleaveWithError(const Container &c,
54                                          UnaryFunctor eachFn,
55                                          NullaryFunctor betweenFn) {
56   return interleaveWithError(c.begin(), c.end(), eachFn, betweenFn);
57 }
58 
59 template <typename Container, typename UnaryFunctor>
60 inline LogicalResult interleaveCommaWithError(const Container &c,
61                                               raw_ostream &os,
62                                               UnaryFunctor eachFn) {
63   return interleaveWithError(c.begin(), c.end(), eachFn, [&]() { os << ", "; });
64 }
65 
66 namespace {
67 /// Emitter that uses dialect specific emitters to emit C++ code.
68 struct CppEmitter {
69   explicit CppEmitter(raw_ostream &os, bool declareVariablesAtTop);
70 
71   /// Emits attribute or returns failure.
72   LogicalResult emitAttribute(Location loc, Attribute attr);
73 
74   /// Emits operation 'op' with/without training semicolon or returns failure.
75   LogicalResult emitOperation(Operation &op, bool trailingSemicolon);
76 
77   /// Emits type 'type' or returns failure.
78   LogicalResult emitType(Location loc, Type type);
79 
80   /// Emits array of types as a std::tuple of the emitted types.
81   /// - emits void for an empty array;
82   /// - emits the type of the only element for arrays of size one;
83   /// - emits a std::tuple otherwise;
84   LogicalResult emitTypes(Location loc, ArrayRef<Type> types);
85 
86   /// Emits array of types as a std::tuple of the emitted types independently of
87   /// the array size.
88   LogicalResult emitTupleType(Location loc, ArrayRef<Type> types);
89 
90   /// Emits an assignment for a variable which has been declared previously.
91   LogicalResult emitVariableAssignment(OpResult result);
92 
93   /// Emits a variable declaration for a result of an operation.
94   LogicalResult emitVariableDeclaration(OpResult result,
95                                         bool trailingSemicolon);
96 
97   /// Emits the variable declaration and assignment prefix for 'op'.
98   /// - emits separate variable followed by std::tie for multi-valued operation;
99   /// - emits single type followed by variable for single result;
100   /// - emits nothing if no value produced by op;
101   /// Emits final '=' operator where a type is produced. Returns failure if
102   /// any result type could not be converted.
103   LogicalResult emitAssignPrefix(Operation &op);
104 
105   /// Emits a label for the block.
106   LogicalResult emitLabel(Block &block);
107 
108   /// Emits the operands and atttributes of the operation. All operands are
109   /// emitted first and then all attributes in alphabetical order.
110   LogicalResult emitOperandsAndAttributes(Operation &op,
111                                           ArrayRef<StringRef> exclude = {});
112 
113   /// Emits the operands of the operation. All operands are emitted in order.
114   LogicalResult emitOperands(Operation &op);
115 
116   /// Return the existing or a new name for a Value.
117   StringRef getOrCreateName(Value val);
118 
119   /// Return the existing or a new label of a Block.
120   StringRef getOrCreateName(Block &block);
121 
122   /// Whether to map an mlir integer to a unsigned integer in C++.
123   bool shouldMapToUnsigned(IntegerType::SignednessSemantics val);
124 
125   /// RAII helper function to manage entering/exiting C++ scopes.
126   struct Scope {
127     Scope(CppEmitter &emitter)
128         : valueMapperScope(emitter.valueMapper),
129           blockMapperScope(emitter.blockMapper), emitter(emitter) {
130       emitter.valueInScopeCount.push(emitter.valueInScopeCount.top());
131       emitter.labelInScopeCount.push(emitter.labelInScopeCount.top());
132     }
133     ~Scope() {
134       emitter.valueInScopeCount.pop();
135       emitter.labelInScopeCount.pop();
136     }
137 
138   private:
139     llvm::ScopedHashTableScope<Value, std::string> valueMapperScope;
140     llvm::ScopedHashTableScope<Block *, std::string> blockMapperScope;
141     CppEmitter &emitter;
142   };
143 
144   /// Returns wether the Value is assigned to a C++ variable in the scope.
145   bool hasValueInScope(Value val);
146 
147   // Returns whether a label is assigned to the block.
148   bool hasBlockLabel(Block &block);
149 
150   /// Returns the output stream.
151   raw_indented_ostream &ostream() { return os; };
152 
153   /// Returns if all variables for op results and basic block arguments need to
154   /// be declared at the beginning of a function.
155   bool shouldDeclareVariablesAtTop() { return declareVariablesAtTop; };
156 
157 private:
158   using ValueMapper = llvm::ScopedHashTable<Value, std::string>;
159   using BlockMapper = llvm::ScopedHashTable<Block *, std::string>;
160 
161   /// Output stream to emit to.
162   raw_indented_ostream os;
163 
164   /// Boolean to enforce that all variables for op results and block
165   /// arguments are declared at the beginning of the function. This also
166   /// includes results from ops located in nested regions.
167   bool declareVariablesAtTop;
168 
169   /// Map from value to name of C++ variable that contain the name.
170   ValueMapper valueMapper;
171 
172   /// Map from block to name of C++ label.
173   BlockMapper blockMapper;
174 
175   /// The number of values in the current scope. This is used to declare the
176   /// names of values in a scope.
177   std::stack<int64_t> valueInScopeCount;
178   std::stack<int64_t> labelInScopeCount;
179 };
180 } // namespace
181 
182 static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation,
183                                      Attribute value) {
184   OpResult result = operation->getResult(0);
185 
186   // Only emit an assignment as the variable was already declared when printing
187   // the FuncOp.
188   if (emitter.shouldDeclareVariablesAtTop()) {
189     // Skip the assignment if the emitc.constant has no value.
190     if (auto oAttr = value.dyn_cast<emitc::OpaqueAttr>()) {
191       if (oAttr.getValue().empty())
192         return success();
193     }
194 
195     if (failed(emitter.emitVariableAssignment(result)))
196       return failure();
197     return emitter.emitAttribute(operation->getLoc(), value);
198   }
199 
200   // Emit a variable declaration for an emitc.constant op without value.
201   if (auto oAttr = value.dyn_cast<emitc::OpaqueAttr>()) {
202     if (oAttr.getValue().empty())
203       // The semicolon gets printed by the emitOperation function.
204       return emitter.emitVariableDeclaration(result,
205                                              /*trailingSemicolon=*/false);
206   }
207 
208   // Emit a variable declaration.
209   if (failed(emitter.emitAssignPrefix(*operation)))
210     return failure();
211   return emitter.emitAttribute(operation->getLoc(), value);
212 }
213 
214 static LogicalResult printOperation(CppEmitter &emitter,
215                                     emitc::ConstantOp constantOp) {
216   Operation *operation = constantOp.getOperation();
217   Attribute value = constantOp.value();
218 
219   return printConstantOp(emitter, operation, value);
220 }
221 
222 static LogicalResult printOperation(CppEmitter &emitter,
223                                     mlir::ConstantOp constantOp) {
224   Operation *operation = constantOp.getOperation();
225   Attribute value = constantOp.value();
226 
227   return printConstantOp(emitter, operation, value);
228 }
229 
230 static LogicalResult printOperation(CppEmitter &emitter, BranchOp branchOp) {
231   raw_ostream &os = emitter.ostream();
232   Block &successor = *branchOp.getSuccessor();
233 
234   for (auto pair :
235        llvm::zip(branchOp.getOperands(), successor.getArguments())) {
236     Value &operand = std::get<0>(pair);
237     BlockArgument &argument = std::get<1>(pair);
238     os << emitter.getOrCreateName(argument) << " = "
239        << emitter.getOrCreateName(operand) << ";\n";
240   }
241 
242   os << "goto ";
243   if (!(emitter.hasBlockLabel(successor)))
244     return branchOp.emitOpError("unable to find label for successor block");
245   os << emitter.getOrCreateName(successor);
246   return success();
247 }
248 
249 static LogicalResult printOperation(CppEmitter &emitter,
250                                     CondBranchOp condBranchOp) {
251   raw_ostream &os = emitter.ostream();
252   Block &trueSuccessor = *condBranchOp.getTrueDest();
253   Block &falseSuccessor = *condBranchOp.getFalseDest();
254 
255   os << "if (" << emitter.getOrCreateName(condBranchOp.getCondition())
256      << ") {\n";
257 
258   // If condition is true.
259   for (auto pair : llvm::zip(condBranchOp.getTrueOperands(),
260                              trueSuccessor.getArguments())) {
261     Value &operand = std::get<0>(pair);
262     BlockArgument &argument = std::get<1>(pair);
263     os << emitter.getOrCreateName(argument) << " = "
264        << emitter.getOrCreateName(operand) << ";\n";
265   }
266 
267   os << "goto ";
268   if (!(emitter.hasBlockLabel(trueSuccessor))) {
269     return condBranchOp.emitOpError("unable to find label for successor block");
270   }
271   os << emitter.getOrCreateName(trueSuccessor) << ";\n";
272   os << "} else {\n";
273   // If condition is false.
274   for (auto pair : llvm::zip(condBranchOp.getFalseOperands(),
275                              falseSuccessor.getArguments())) {
276     Value &operand = std::get<0>(pair);
277     BlockArgument &argument = std::get<1>(pair);
278     os << emitter.getOrCreateName(argument) << " = "
279        << emitter.getOrCreateName(operand) << ";\n";
280   }
281 
282   os << "goto ";
283   if (!(emitter.hasBlockLabel(falseSuccessor))) {
284     return condBranchOp.emitOpError()
285            << "unable to find label for successor block";
286   }
287   os << emitter.getOrCreateName(falseSuccessor) << ";\n";
288   os << "}";
289   return success();
290 }
291 
292 static LogicalResult printOperation(CppEmitter &emitter, mlir::CallOp callOp) {
293   if (failed(emitter.emitAssignPrefix(*callOp.getOperation())))
294     return failure();
295 
296   raw_ostream &os = emitter.ostream();
297   os << callOp.getCallee() << "(";
298   if (failed(emitter.emitOperands(*callOp.getOperation())))
299     return failure();
300   os << ")";
301   return success();
302 }
303 
304 static LogicalResult printOperation(CppEmitter &emitter, emitc::CallOp callOp) {
305   raw_ostream &os = emitter.ostream();
306   Operation &op = *callOp.getOperation();
307 
308   if (failed(emitter.emitAssignPrefix(op)))
309     return failure();
310   os << callOp.callee();
311 
312   auto emitArgs = [&](Attribute attr) -> LogicalResult {
313     if (auto t = attr.dyn_cast<IntegerAttr>()) {
314       // Index attributes are treated specially as operand index.
315       if (t.getType().isIndex()) {
316         int64_t idx = t.getInt();
317         if ((idx < 0) || (idx >= op.getNumOperands()))
318           return op.emitOpError("invalid operand index");
319         if (!emitter.hasValueInScope(op.getOperand(idx)))
320           return op.emitOpError("operand ")
321                  << idx << "'s value not defined in scope";
322         os << emitter.getOrCreateName(op.getOperand(idx));
323         return success();
324       }
325     }
326     if (failed(emitter.emitAttribute(op.getLoc(), attr)))
327       return failure();
328 
329     return success();
330   };
331 
332   if (callOp.template_args()) {
333     os << "<";
334     if (failed(interleaveCommaWithError(*callOp.template_args(), os, emitArgs)))
335       return failure();
336     os << ">";
337   }
338 
339   os << "(";
340 
341   LogicalResult emittedArgs =
342       callOp.args() ? interleaveCommaWithError(*callOp.args(), os, emitArgs)
343                     : emitter.emitOperands(op);
344   if (failed(emittedArgs))
345     return failure();
346   os << ")";
347   return success();
348 }
349 
350 static LogicalResult printOperation(CppEmitter &emitter,
351                                     emitc::ApplyOp applyOp) {
352   raw_ostream &os = emitter.ostream();
353   Operation &op = *applyOp.getOperation();
354 
355   if (failed(emitter.emitAssignPrefix(op)))
356     return failure();
357   os << applyOp.applicableOperator();
358   os << emitter.getOrCreateName(applyOp.getOperand());
359 
360   return success();
361 }
362 
363 static LogicalResult printOperation(CppEmitter &emitter,
364                                     emitc::IncludeOp includeOp) {
365   raw_ostream &os = emitter.ostream();
366 
367   os << "#include ";
368   if (includeOp.is_standard_include())
369     os << "<" << includeOp.include() << ">";
370   else
371     os << "\"" << includeOp.include() << "\"";
372 
373   return success();
374 }
375 
/// Prints an scf.for op as a C++ `for` loop. The generated code is:
///  1. unless variables are declared at the top, declarations of the
///     variables that hold the op results;
///  2. one initialized variable per region iter_arg, set from the matching
///     iter operand;
///  3. `for (T iv = lb; iv < ub; iv += step) { ... }`, whose body contains
///     every region op except the last one (the trailing yield). The body is
///     followed by copies of the yielded values into the iter_arg variables;
///  4. after the closing brace, copies of the iter_arg variables into the
///     result variables, each already terminated by ';'.
/// Because of (4), the function-body printer passes trailingSemicolon=false
/// for scf.for.
static LogicalResult printOperation(CppEmitter &emitter, scf::ForOp forOp) {

  raw_indented_ostream &os = emitter.ostream();

  OperandRange operands = forOp.getIterOperands();
  Block::BlockArgListType iterArgs = forOp.getRegionIterArgs();
  Operation::result_range results = forOp.getResults();

  // Result variables must exist before the loop because they are assigned
  // after it (see the end of this function).
  if (!emitter.shouldDeclareVariablesAtTop()) {
    for (OpResult result : results) {
      if (failed(emitter.emitVariableDeclaration(result,
                                                 /*trailingSemicolon=*/true)))
        return failure();
    }
  }

  // Declare one variable per iter_arg, initialized from the iter operand.
  // These are always declared here: the declare-at-top walk only covers op
  // results, not block arguments.
  for (auto pair : llvm::zip(iterArgs, operands)) {
    if (failed(emitter.emitType(forOp.getLoc(), std::get<0>(pair).getType())))
      return failure();
    os << " " << emitter.getOrCreateName(std::get<0>(pair)) << " = ";
    os << emitter.getOrCreateName(std::get<1>(pair)) << ";";
    os << "\n";
  }

  // NOTE(review): lowerBound/upperBound/step (and the iter operands above) are
  // not checked with hasValueInScope. getOrCreateName would silently create a
  // fresh name for an undefined value.
  os << "for (";
  if (failed(
          emitter.emitType(forOp.getLoc(), forOp.getInductionVar().getType())))
    return failure();
  os << " ";
  os << emitter.getOrCreateName(forOp.getInductionVar());
  os << " = ";
  os << emitter.getOrCreateName(forOp.lowerBound());
  os << "; ";
  os << emitter.getOrCreateName(forOp.getInductionVar());
  os << " < ";
  os << emitter.getOrCreateName(forOp.upperBound());
  os << "; ";
  os << emitter.getOrCreateName(forOp.getInductionVar());
  os << " += ";
  os << emitter.getOrCreateName(forOp.step());
  os << ") {\n";
  os.indent();

  Region &forRegion = forOp.region();
  auto regionOps = forRegion.getOps();

  // We skip the trailing yield op because this updates the result variables
  // of the for op in the generated code. Instead we update the iterArgs at
  // the end of a loop iteration and set the result variables after the for
  // loop.
  // The std::next(it) condition stops before the last op of the region. It
  // assumes the region is non-empty and that its last op is the yield
  // (presumably guaranteed by scf.for's single-block, terminated body).
  for (auto it = regionOps.begin(); std::next(it) != regionOps.end(); ++it) {
    if (failed(emitter.emitOperation(*it, /*trailingSemicolon=*/true)))
      return failure();
  }

  Operation *yieldOp = forRegion.getBlocks().front().getTerminator();
  // Copy yield operands into iterArgs at the end of a loop iteration.
  for (auto pair : llvm::zip(iterArgs, yieldOp->getOperands())) {
    BlockArgument iterArg = std::get<0>(pair);
    Value operand = std::get<1>(pair);
    os << emitter.getOrCreateName(iterArg) << " = "
       << emitter.getOrCreateName(operand) << ";\n";
  }

  os.unindent() << "}";

  // Copy iterArgs into results after the for loop.
  // Each assignment is prefixed with a newline and carries its own ';'.
  for (auto pair : llvm::zip(results, iterArgs)) {
    OpResult result = std::get<0>(pair);
    BlockArgument iterArg = std::get<1>(pair);
    os << "\n"
       << emitter.getOrCreateName(result) << " = "
       << emitter.getOrCreateName(iterArg) << ";";
  }

  return success();
}
453 
454 static LogicalResult printOperation(CppEmitter &emitter, scf::IfOp ifOp) {
455   raw_indented_ostream &os = emitter.ostream();
456 
457   if (!emitter.shouldDeclareVariablesAtTop()) {
458     for (OpResult result : ifOp.getResults()) {
459       if (failed(emitter.emitVariableDeclaration(result,
460                                                  /*trailingSemicolon=*/true)))
461         return failure();
462     }
463   }
464 
465   os << "if (";
466   if (failed(emitter.emitOperands(*ifOp.getOperation())))
467     return failure();
468   os << ") {\n";
469   os.indent();
470 
471   Region &thenRegion = ifOp.thenRegion();
472   for (Operation &op : thenRegion.getOps()) {
473     // Note: This prints a superfluous semicolon if the terminating yield op has
474     // zero results.
475     if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/true)))
476       return failure();
477   }
478 
479   os.unindent() << "}";
480 
481   Region &elseRegion = ifOp.elseRegion();
482   if (!elseRegion.empty()) {
483     os << " else {\n";
484     os.indent();
485 
486     for (Operation &op : elseRegion.getOps()) {
487       // Note: This prints a superfluous semicolon if the terminating yield op
488       // has zero results.
489       if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/true)))
490         return failure();
491     }
492 
493     os.unindent() << "}";
494   }
495 
496   return success();
497 }
498 
499 static LogicalResult printOperation(CppEmitter &emitter, scf::YieldOp yieldOp) {
500   raw_ostream &os = emitter.ostream();
501   Operation &parentOp = *yieldOp.getOperation()->getParentOp();
502 
503   if (yieldOp.getNumOperands() != parentOp.getNumResults()) {
504     return yieldOp.emitError("number of operands does not to match the number "
505                              "of the parent op's results");
506   }
507 
508   if (failed(interleaveWithError(
509           llvm::zip(parentOp.getResults(), yieldOp.getOperands()),
510           [&](auto pair) -> LogicalResult {
511             auto result = std::get<0>(pair);
512             auto operand = std::get<1>(pair);
513             os << emitter.getOrCreateName(result) << " = ";
514 
515             if (!emitter.hasValueInScope(operand))
516               return yieldOp.emitError("operand value not in scope");
517             os << emitter.getOrCreateName(operand);
518             return success();
519           },
520           [&]() { os << ";\n"; })))
521     return failure();
522 
523   return success();
524 }
525 
526 static LogicalResult printOperation(CppEmitter &emitter, ReturnOp returnOp) {
527   raw_ostream &os = emitter.ostream();
528   os << "return";
529   switch (returnOp.getNumOperands()) {
530   case 0:
531     return success();
532   case 1:
533     os << " " << emitter.getOrCreateName(returnOp.getOperand(0));
534     return success(emitter.hasValueInScope(returnOp.getOperand(0)));
535   default:
536     os << " std::make_tuple(";
537     if (failed(emitter.emitOperandsAndAttributes(*returnOp.getOperation())))
538       return failure();
539     os << ")";
540     return success();
541   }
542 }
543 
544 static LogicalResult printOperation(CppEmitter &emitter, ModuleOp moduleOp) {
545   CppEmitter::Scope scope(emitter);
546 
547   for (Operation &op : moduleOp) {
548     if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/false)))
549       return failure();
550   }
551   return success();
552 }
553 
/// Prints a func op as a C++ function definition inside a new naming scope.
///
/// The result type comes from emitTypes on the function's result types. The
/// parameters come from the entry block arguments. Inside the body:
///  - with declareVariablesAtTop, every op result in the function (including
///    nested regions) is declared up front via a pre-order walk;
///  - every block gets a label name, and arguments of non-entry blocks are
///    declared as local variables;
///  - labels are only printed when the function has more than one block.
/// Functions with multiple blocks are rejected unless declareVariablesAtTop is
/// set (presumably because gotos could otherwise jump over declarations).
static LogicalResult printOperation(CppEmitter &emitter, FuncOp functionOp) {
  // We need to declare variables at top if the function has multiple blocks.
  if (!emitter.shouldDeclareVariablesAtTop() &&
      functionOp.getBlocks().size() > 1) {
    return functionOp.emitOpError(
        "with multiple blocks needs variables declared at top");
  }

  CppEmitter::Scope scope(emitter);
  raw_indented_ostream &os = emitter.ostream();
  if (failed(emitter.emitTypes(functionOp.getLoc(),
                               functionOp.getType().getResults())))
    return failure();
  os << " " << functionOp.getName();

  // NOTE(review): getArguments() and std::next(blocks.begin()) below look
  // unsafe for a function without a body (external declaration). Confirm such
  // functions never reach this printer.
  os << "(";
  if (failed(interleaveCommaWithError(
          functionOp.getArguments(), os,
          [&](BlockArgument arg) -> LogicalResult {
            if (failed(emitter.emitType(functionOp.getLoc(), arg.getType())))
              return failure();
            os << " " << emitter.getOrCreateName(arg);
            return success();
          })))
    return failure();
  os << ") {\n";
  os.indent();
  if (emitter.shouldDeclareVariablesAtTop()) {
    // Declare all variables that hold op results including those from nested
    // regions.
    WalkResult result =
        functionOp.walk<WalkOrder::PreOrder>([&](Operation *op) -> WalkResult {
          for (OpResult result : op->getResults()) {
            if (failed(emitter.emitVariableDeclaration(
                    result, /*trailingSemicolon=*/true))) {
              return WalkResult(
                  op->emitError("unable to declare result variable for op"));
            }
          }
          return WalkResult::advance();
        });
    if (result.wasInterrupted())
      return failure();
  }

  Region::BlockListType &blocks = functionOp.getBlocks();
  // Create label names for basic blocks.
  // This happens before any op is printed so that branches to later blocks
  // can already find their successor's label.
  for (Block &block : blocks) {
    emitter.getOrCreateName(block);
  }

  // Declare variables for basic block arguments.
  // The entry block is skipped: its arguments are the function parameters.
  for (auto it = std::next(blocks.begin()); it != blocks.end(); ++it) {
    Block &block = *it;
    for (BlockArgument &arg : block.getArguments()) {
      // NOTE(review): this fires when the argument already has a name, but the
      // diagnostic says "is out of scope" (and starts with a leading space).
      // The wording looks inverted; confirm intent.
      if (emitter.hasValueInScope(arg))
        return functionOp.emitOpError(" block argument #")
               << arg.getArgNumber() << " is out of scope";
      if (failed(
              emitter.emitType(block.getParentOp()->getLoc(), arg.getType()))) {
        return failure();
      }
      os << " " << emitter.getOrCreateName(arg) << ";\n";
    }
  }

  for (Block &block : blocks) {
    // Only print a label if there is more than one block.
    if (blocks.size() > 1) {
      if (failed(emitter.emitLabel(block)))
        return failure();
    }
    for (Operation &op : block.getOperations()) {
      // When generating code for an scf.if or std.cond_br op no semicolon needs
      // to be printed after the closing brace.
      // When generating code for an scf.for op, printing a trailing semicolon
      // is handled within the printOperation function.
      bool trailingSemicolon = !isa<scf::IfOp, scf::ForOp, CondBranchOp>(op);

      if (failed(emitter.emitOperation(
              op, /*trailingSemicolon=*/trailingSemicolon)))
        return failure();
    }
  }
  os.unindent() << "}\n";
  return success();
}
641 
642 CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop)
643     : os(os), declareVariablesAtTop(declareVariablesAtTop) {
644   valueInScopeCount.push(0);
645   labelInScopeCount.push(0);
646 }
647 
648 /// Return the existing or a new name for a Value.
649 StringRef CppEmitter::getOrCreateName(Value val) {
650   if (!valueMapper.count(val))
651     valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top()));
652   return *valueMapper.begin(val);
653 }
654 
655 /// Return the existing or a new label for a Block.
656 StringRef CppEmitter::getOrCreateName(Block &block) {
657   if (!blockMapper.count(&block))
658     blockMapper.insert(&block, formatv("label{0}", ++labelInScopeCount.top()));
659   return *blockMapper.begin(&block);
660 }
661 
662 bool CppEmitter::shouldMapToUnsigned(IntegerType::SignednessSemantics val) {
663   switch (val) {
664   case IntegerType::Signless:
665     return false;
666   case IntegerType::Signed:
667     return false;
668   case IntegerType::Unsigned:
669     return true;
670   }
671 }
672 
673 bool CppEmitter::hasValueInScope(Value val) { return valueMapper.count(val); }
674 
675 bool CppEmitter::hasBlockLabel(Block &block) {
676   return blockMapper.count(&block);
677 }
678 
679 LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
680   auto printInt = [&](APInt val, bool isUnsigned) {
681     if (val.getBitWidth() == 1) {
682       if (val.getBoolValue())
683         os << "true";
684       else
685         os << "false";
686     } else {
687       SmallString<128> strValue;
688       val.toString(strValue, 10, !isUnsigned, false);
689       os << strValue;
690     }
691   };
692 
693   auto printFloat = [&](APFloat val) {
694     if (val.isFinite()) {
695       SmallString<128> strValue;
696       // Use default values of toString except don't truncate zeros.
697       val.toString(strValue, 0, 0, false);
698       switch (llvm::APFloatBase::SemanticsToEnum(val.getSemantics())) {
699       case llvm::APFloatBase::S_IEEEsingle:
700         os << "(float)";
701         break;
702       case llvm::APFloatBase::S_IEEEdouble:
703         os << "(double)";
704         break;
705       default:
706         break;
707       };
708       os << strValue;
709     } else if (val.isNaN()) {
710       os << "NAN";
711     } else if (val.isInfinity()) {
712       if (val.isNegative())
713         os << "-";
714       os << "INFINITY";
715     }
716   };
717 
718   // Print floating point attributes.
719   if (auto fAttr = attr.dyn_cast<FloatAttr>()) {
720     printFloat(fAttr.getValue());
721     return success();
722   }
723   if (auto dense = attr.dyn_cast<DenseFPElementsAttr>()) {
724     os << '{';
725     interleaveComma(dense, os, [&](APFloat val) { printFloat(val); });
726     os << '}';
727     return success();
728   }
729 
730   // Print integer attributes.
731   if (auto iAttr = attr.dyn_cast<IntegerAttr>()) {
732     if (auto iType = iAttr.getType().dyn_cast<IntegerType>()) {
733       printInt(iAttr.getValue(), shouldMapToUnsigned(iType.getSignedness()));
734       return success();
735     }
736     if (auto iType = iAttr.getType().dyn_cast<IndexType>()) {
737       printInt(iAttr.getValue(), false);
738       return success();
739     }
740   }
741   if (auto dense = attr.dyn_cast<DenseIntElementsAttr>()) {
742     if (auto iType = dense.getType()
743                          .cast<TensorType>()
744                          .getElementType()
745                          .dyn_cast<IntegerType>()) {
746       os << '{';
747       interleaveComma(dense, os, [&](APInt val) {
748         printInt(val, shouldMapToUnsigned(iType.getSignedness()));
749       });
750       os << '}';
751       return success();
752     }
753     if (auto iType = dense.getType()
754                          .cast<TensorType>()
755                          .getElementType()
756                          .dyn_cast<IndexType>()) {
757       os << '{';
758       interleaveComma(dense, os, [&](APInt val) { printInt(val, false); });
759       os << '}';
760       return success();
761     }
762   }
763 
764   // Print opaque attributes.
765   if (auto oAttr = attr.dyn_cast<emitc::OpaqueAttr>()) {
766     os << oAttr.getValue();
767     return success();
768   }
769 
770   // Print symbolic reference attributes.
771   if (auto sAttr = attr.dyn_cast<SymbolRefAttr>()) {
772     if (sAttr.getNestedReferences().size() > 1)
773       return emitError(loc, "attribute has more than 1 nested reference");
774     os << sAttr.getRootReference().getValue();
775     return success();
776   }
777 
778   // Print type attributes.
779   if (auto type = attr.dyn_cast<TypeAttr>())
780     return emitType(loc, type.getValue());
781 
782   return emitError(loc, "cannot emit attribute of type ") << attr.getType();
783 }
784 
785 LogicalResult CppEmitter::emitOperands(Operation &op) {
786   auto emitOperandName = [&](Value result) -> LogicalResult {
787     if (!hasValueInScope(result))
788       return op.emitOpError() << "operand value not in scope";
789     os << getOrCreateName(result);
790     return success();
791   };
792   return interleaveCommaWithError(op.getOperands(), os, emitOperandName);
793 }
794 
795 LogicalResult
796 CppEmitter::emitOperandsAndAttributes(Operation &op,
797                                       ArrayRef<StringRef> exclude) {
798   if (failed(emitOperands(op)))
799     return failure();
800   // Insert comma in between operands and non-filtered attributes if needed.
801   if (op.getNumOperands() > 0) {
802     for (NamedAttribute attr : op.getAttrs()) {
803       if (!llvm::is_contained(exclude, attr.first.strref())) {
804         os << ", ";
805         break;
806       }
807     }
808   }
809   // Emit attributes.
810   auto emitNamedAttribute = [&](NamedAttribute attr) -> LogicalResult {
811     if (llvm::is_contained(exclude, attr.first.strref()))
812       return success();
813     os << "/* " << attr.first << " */";
814     if (failed(emitAttribute(op.getLoc(), attr.second)))
815       return failure();
816     return success();
817   };
818   return interleaveCommaWithError(op.getAttrs(), os, emitNamedAttribute);
819 }
820 
821 LogicalResult CppEmitter::emitVariableAssignment(OpResult result) {
822   if (!hasValueInScope(result)) {
823     return result.getDefiningOp()->emitOpError(
824         "result variable for the operation has not been declared");
825   }
826   os << getOrCreateName(result) << " = ";
827   return success();
828 }
829 
830 LogicalResult CppEmitter::emitVariableDeclaration(OpResult result,
831                                                   bool trailingSemicolon) {
832   if (hasValueInScope(result)) {
833     return result.getDefiningOp()->emitError(
834         "result variable for the operation already declared");
835   }
836   if (failed(emitType(result.getOwner()->getLoc(), result.getType())))
837     return failure();
838   os << " " << getOrCreateName(result);
839   if (trailingSemicolon)
840     os << ";\n";
841   return success();
842 }
843 
844 LogicalResult CppEmitter::emitAssignPrefix(Operation &op) {
845   switch (op.getNumResults()) {
846   case 0:
847     break;
848   case 1: {
849     OpResult result = op.getResult(0);
850     if (shouldDeclareVariablesAtTop()) {
851       if (failed(emitVariableAssignment(result)))
852         return failure();
853     } else {
854       if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/false)))
855         return failure();
856       os << " = ";
857     }
858     break;
859   }
860   default:
861     if (!shouldDeclareVariablesAtTop()) {
862       for (OpResult result : op.getResults()) {
863         if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/true)))
864           return failure();
865       }
866     }
867     os << "std::tie(";
868     interleaveComma(op.getResults(), os,
869                     [&](Value result) { os << getOrCreateName(result); });
870     os << ") = ";
871   }
872   return success();
873 }
874 
875 LogicalResult CppEmitter::emitLabel(Block &block) {
876   if (!hasBlockLabel(block))
877     return block.getParentOp()->emitError("label for block not found");
878   os << getOrCreateName(block) << ":\n";
879   return success();
880 }
881 
882 LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
883   LogicalResult status =
884       llvm::TypeSwitch<Operation *, LogicalResult>(&op)
885           // EmitC ops.
886           .Case<emitc::ApplyOp, emitc::CallOp, emitc::ConstantOp,
887                 emitc::IncludeOp>(
888               [&](auto op) { return printOperation(*this, op); })
889           // SCF ops.
890           .Case<scf::ForOp, scf::IfOp, scf::YieldOp>(
891               [&](auto op) { return printOperation(*this, op); })
892           // Standard ops.
893           .Case<BranchOp, mlir::CallOp, CondBranchOp, mlir::ConstantOp, FuncOp,
894                 ModuleOp, ReturnOp>(
895               [&](auto op) { return printOperation(*this, op); })
896           .Default([&](Operation *) {
897             return op.emitOpError("unable to find printer for op");
898           });
899 
900   if (failed(status))
901     return failure();
902   os << (trailingSemicolon ? ";\n" : "\n");
903   return success();
904 }
905 
906 LogicalResult CppEmitter::emitType(Location loc, Type type) {
907   if (auto iType = type.dyn_cast<IntegerType>()) {
908     switch (iType.getWidth()) {
909     case 1:
910       return (os << "bool"), success();
911     case 8:
912     case 16:
913     case 32:
914     case 64:
915       if (shouldMapToUnsigned(iType.getSignedness()))
916         return (os << "uint" << iType.getWidth() << "_t"), success();
917       else
918         return (os << "int" << iType.getWidth() << "_t"), success();
919     default:
920       return emitError(loc, "cannot emit integer type ") << type;
921     }
922   }
923   if (auto fType = type.dyn_cast<FloatType>()) {
924     switch (fType.getWidth()) {
925     case 32:
926       return (os << "float"), success();
927     case 64:
928       return (os << "double"), success();
929     default:
930       return emitError(loc, "cannot emit float type ") << type;
931     }
932   }
933   if (auto iType = type.dyn_cast<IndexType>())
934     return (os << "size_t"), success();
935   if (auto tType = type.dyn_cast<TensorType>()) {
936     if (!tType.hasRank())
937       return emitError(loc, "cannot emit unranked tensor type");
938     if (!tType.hasStaticShape())
939       return emitError(loc, "cannot emit tensor type with non static shape");
940     os << "Tensor<";
941     if (failed(emitType(loc, tType.getElementType())))
942       return failure();
943     auto shape = tType.getShape();
944     for (auto dimSize : shape) {
945       os << ", ";
946       os << dimSize;
947     }
948     os << ">";
949     return success();
950   }
951   if (auto tType = type.dyn_cast<TupleType>())
952     return emitTupleType(loc, tType.getTypes());
953   if (auto oType = type.dyn_cast<emitc::OpaqueType>()) {
954     os << oType.getValue();
955     return success();
956   }
957   return emitError(loc, "cannot emit type ") << type;
958 }
959 
960 LogicalResult CppEmitter::emitTypes(Location loc, ArrayRef<Type> types) {
961   switch (types.size()) {
962   case 0:
963     os << "void";
964     return success();
965   case 1:
966     return emitType(loc, types.front());
967   default:
968     return emitTupleType(loc, types);
969   }
970 }
971 
972 LogicalResult CppEmitter::emitTupleType(Location loc, ArrayRef<Type> types) {
973   os << "std::tuple<";
974   if (failed(interleaveCommaWithError(
975           types, os, [&](Type type) { return emitType(loc, type); })))
976     return failure();
977   os << ">";
978   return success();
979 }
980 
981 LogicalResult emitc::translateToCpp(Operation *op, raw_ostream &os,
982                                     bool declareVariablesAtTop) {
983   CppEmitter emitter(os, declareVariablesAtTop);
984   return emitter.emitOperation(*op, /*trailingSemicolon=*/false);
985 }
986