1 //===- TranslateToCpp.cpp - Translating to C++ calls ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/IndentedOstream.h"
#include "mlir/Target/Cpp/CppEmitter.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/ScopedHashTable.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"

#include <algorithm>
#include <stack>
24 
25 #define DEBUG_TYPE "translate-to-cpp"
26 
27 using namespace mlir;
28 using namespace mlir::emitc;
29 using llvm::formatv;
30 
31 /// Convenience functions to produce interleaved output with functions returning
32 /// a LogicalResult. This is different than those in STLExtras as functions used
33 /// on each element doesn't return a string.
34 template <typename ForwardIterator, typename UnaryFunctor,
35           typename NullaryFunctor>
36 inline LogicalResult
37 interleaveWithError(ForwardIterator begin, ForwardIterator end,
38                     UnaryFunctor eachFn, NullaryFunctor betweenFn) {
39   if (begin == end)
40     return success();
41   if (failed(eachFn(*begin)))
42     return failure();
43   ++begin;
44   for (; begin != end; ++begin) {
45     betweenFn();
46     if (failed(eachFn(*begin)))
47       return failure();
48   }
49   return success();
50 }
51 
52 template <typename Container, typename UnaryFunctor, typename NullaryFunctor>
53 inline LogicalResult interleaveWithError(const Container &c,
54                                          UnaryFunctor eachFn,
55                                          NullaryFunctor betweenFn) {
56   return interleaveWithError(c.begin(), c.end(), eachFn, betweenFn);
57 }
58 
59 template <typename Container, typename UnaryFunctor>
60 inline LogicalResult interleaveCommaWithError(const Container &c,
61                                               raw_ostream &os,
62                                               UnaryFunctor eachFn) {
63   return interleaveWithError(c.begin(), c.end(), eachFn, [&]() { os << ", "; });
64 }
65 
66 namespace {
67 /// Emitter that uses dialect specific emitters to emit C++ code.
68 struct CppEmitter {
69   explicit CppEmitter(raw_ostream &os, bool declareVariablesAtTop);
70 
71   /// Emits attribute or returns failure.
72   LogicalResult emitAttribute(Location loc, Attribute attr);
73 
74   /// Emits operation 'op' with/without training semicolon or returns failure.
75   LogicalResult emitOperation(Operation &op, bool trailingSemicolon);
76 
77   /// Emits type 'type' or returns failure.
78   LogicalResult emitType(Location loc, Type type);
79 
80   /// Emits array of types as a std::tuple of the emitted types.
81   /// - emits void for an empty array;
82   /// - emits the type of the only element for arrays of size one;
83   /// - emits a std::tuple otherwise;
84   LogicalResult emitTypes(Location loc, ArrayRef<Type> types);
85 
86   /// Emits array of types as a std::tuple of the emitted types independently of
87   /// the array size.
88   LogicalResult emitTupleType(Location loc, ArrayRef<Type> types);
89 
90   /// Emits an assignment for a variable which has been declared previously.
91   LogicalResult emitVariableAssignment(OpResult result);
92 
93   /// Emits a variable declaration for a result of an operation.
94   LogicalResult emitVariableDeclaration(OpResult result,
95                                         bool trailingSemicolon);
96 
97   /// Emits the variable declaration and assignment prefix for 'op'.
98   /// - emits separate variable followed by std::tie for multi-valued operation;
99   /// - emits single type followed by variable for single result;
100   /// - emits nothing if no value produced by op;
101   /// Emits final '=' operator where a type is produced. Returns failure if
102   /// any result type could not be converted.
103   LogicalResult emitAssignPrefix(Operation &op);
104 
105   /// Emits a label for the block.
106   LogicalResult emitLabel(Block &block);
107 
108   /// Emits the operands and atttributes of the operation. All operands are
109   /// emitted first and then all attributes in alphabetical order.
110   LogicalResult emitOperandsAndAttributes(Operation &op,
111                                           ArrayRef<StringRef> exclude = {});
112 
113   /// Emits the operands of the operation. All operands are emitted in order.
114   LogicalResult emitOperands(Operation &op);
115 
116   /// Return the existing or a new name for a Value.
117   StringRef getOrCreateName(Value val);
118 
119   /// Return the existing or a new label of a Block.
120   StringRef getOrCreateName(Block &block);
121 
122   /// Whether to map an mlir integer to a unsigned integer in C++.
123   bool shouldMapToUnsigned(IntegerType::SignednessSemantics val);
124 
125   /// RAII helper function to manage entering/exiting C++ scopes.
126   struct Scope {
127     Scope(CppEmitter &emitter)
128         : valueMapperScope(emitter.valueMapper),
129           blockMapperScope(emitter.blockMapper), emitter(emitter) {
130       emitter.valueInScopeCount.push(emitter.valueInScopeCount.top());
131       emitter.labelInScopeCount.push(emitter.labelInScopeCount.top());
132     }
133     ~Scope() {
134       emitter.valueInScopeCount.pop();
135       emitter.labelInScopeCount.pop();
136     }
137 
138   private:
139     llvm::ScopedHashTableScope<Value, std::string> valueMapperScope;
140     llvm::ScopedHashTableScope<Block *, std::string> blockMapperScope;
141     CppEmitter &emitter;
142   };
143 
144   /// Returns wether the Value is assigned to a C++ variable in the scope.
145   bool hasValueInScope(Value val);
146 
147   // Returns whether a label is assigned to the block.
148   bool hasBlockLabel(Block &block);
149 
150   /// Returns the output stream.
151   raw_indented_ostream &ostream() { return os; };
152 
153   /// Returns if all variables for op results and basic block arguments need to
154   /// be declared at the beginning of a function.
155   bool shouldDeclareVariablesAtTop() { return declareVariablesAtTop; };
156 
157 private:
158   using ValueMapper = llvm::ScopedHashTable<Value, std::string>;
159   using BlockMapper = llvm::ScopedHashTable<Block *, std::string>;
160 
161   /// Output stream to emit to.
162   raw_indented_ostream os;
163 
164   /// Boolean to enforce that all variables for op results and block
165   /// arguments are declared at the beginning of the function. This also
166   /// includes results from ops located in nested regions.
167   bool declareVariablesAtTop;
168 
169   /// Map from value to name of C++ variable that contain the name.
170   ValueMapper valueMapper;
171 
172   /// Map from block to name of C++ label.
173   BlockMapper blockMapper;
174 
175   /// The number of values in the current scope. This is used to declare the
176   /// names of values in a scope.
177   std::stack<int64_t> valueInScopeCount;
178   std::stack<int64_t> labelInScopeCount;
179 };
180 } // namespace
181 
182 static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation,
183                                      Attribute value) {
184   OpResult result = operation->getResult(0);
185 
186   // Only emit an assignment as the variable was already declared when printing
187   // the FuncOp.
188   if (emitter.shouldDeclareVariablesAtTop()) {
189     // Skip the assignment if the emitc.constant has no value.
190     if (auto oAttr = value.dyn_cast<emitc::OpaqueAttr>()) {
191       if (oAttr.getValue().empty())
192         return success();
193     }
194 
195     if (failed(emitter.emitVariableAssignment(result)))
196       return failure();
197     return emitter.emitAttribute(operation->getLoc(), value);
198   }
199 
200   // Emit a variable declaration for an emitc.constant op without value.
201   if (auto oAttr = value.dyn_cast<emitc::OpaqueAttr>()) {
202     if (oAttr.getValue().empty())
203       // The semicolon gets printed by the emitOperation function.
204       return emitter.emitVariableDeclaration(result,
205                                              /*trailingSemicolon=*/false);
206   }
207 
208   // Emit a variable declaration.
209   if (failed(emitter.emitAssignPrefix(*operation)))
210     return failure();
211   return emitter.emitAttribute(operation->getLoc(), value);
212 }
213 
214 static LogicalResult printOperation(CppEmitter &emitter,
215                                     emitc::ConstantOp constantOp) {
216   Operation *operation = constantOp.getOperation();
217   Attribute value = constantOp.value();
218 
219   return printConstantOp(emitter, operation, value);
220 }
221 
222 static LogicalResult printOperation(CppEmitter &emitter,
223                                     arith::ConstantOp constantOp) {
224   Operation *operation = constantOp.getOperation();
225   Attribute value = constantOp.getValue();
226 
227   return printConstantOp(emitter, operation, value);
228 }
229 
230 static LogicalResult printOperation(CppEmitter &emitter,
231                                     mlir::ConstantOp constantOp) {
232   Operation *operation = constantOp.getOperation();
233   Attribute value = constantOp.getValue();
234 
235   return printConstantOp(emitter, operation, value);
236 }
237 
238 static LogicalResult printOperation(CppEmitter &emitter, BranchOp branchOp) {
239   raw_ostream &os = emitter.ostream();
240   Block &successor = *branchOp.getSuccessor();
241 
242   for (auto pair :
243        llvm::zip(branchOp.getOperands(), successor.getArguments())) {
244     Value &operand = std::get<0>(pair);
245     BlockArgument &argument = std::get<1>(pair);
246     os << emitter.getOrCreateName(argument) << " = "
247        << emitter.getOrCreateName(operand) << ";\n";
248   }
249 
250   os << "goto ";
251   if (!(emitter.hasBlockLabel(successor)))
252     return branchOp.emitOpError("unable to find label for successor block");
253   os << emitter.getOrCreateName(successor);
254   return success();
255 }
256 
257 static LogicalResult printOperation(CppEmitter &emitter,
258                                     CondBranchOp condBranchOp) {
259   raw_indented_ostream &os = emitter.ostream();
260   Block &trueSuccessor = *condBranchOp.getTrueDest();
261   Block &falseSuccessor = *condBranchOp.getFalseDest();
262 
263   os << "if (" << emitter.getOrCreateName(condBranchOp.getCondition())
264      << ") {\n";
265 
266   os.indent();
267 
268   // If condition is true.
269   for (auto pair : llvm::zip(condBranchOp.getTrueOperands(),
270                              trueSuccessor.getArguments())) {
271     Value &operand = std::get<0>(pair);
272     BlockArgument &argument = std::get<1>(pair);
273     os << emitter.getOrCreateName(argument) << " = "
274        << emitter.getOrCreateName(operand) << ";\n";
275   }
276 
277   os << "goto ";
278   if (!(emitter.hasBlockLabel(trueSuccessor))) {
279     return condBranchOp.emitOpError("unable to find label for successor block");
280   }
281   os << emitter.getOrCreateName(trueSuccessor) << ";\n";
282   os.unindent() << "} else {\n";
283   os.indent();
284   // If condition is false.
285   for (auto pair : llvm::zip(condBranchOp.getFalseOperands(),
286                              falseSuccessor.getArguments())) {
287     Value &operand = std::get<0>(pair);
288     BlockArgument &argument = std::get<1>(pair);
289     os << emitter.getOrCreateName(argument) << " = "
290        << emitter.getOrCreateName(operand) << ";\n";
291   }
292 
293   os << "goto ";
294   if (!(emitter.hasBlockLabel(falseSuccessor))) {
295     return condBranchOp.emitOpError()
296            << "unable to find label for successor block";
297   }
298   os << emitter.getOrCreateName(falseSuccessor) << ";\n";
299   os.unindent() << "}";
300   return success();
301 }
302 
303 static LogicalResult printOperation(CppEmitter &emitter, mlir::CallOp callOp) {
304   if (failed(emitter.emitAssignPrefix(*callOp.getOperation())))
305     return failure();
306 
307   raw_ostream &os = emitter.ostream();
308   os << callOp.getCallee() << "(";
309   if (failed(emitter.emitOperands(*callOp.getOperation())))
310     return failure();
311   os << ")";
312   return success();
313 }
314 
315 static LogicalResult printOperation(CppEmitter &emitter, emitc::CallOp callOp) {
316   raw_ostream &os = emitter.ostream();
317   Operation &op = *callOp.getOperation();
318 
319   if (failed(emitter.emitAssignPrefix(op)))
320     return failure();
321   os << callOp.callee();
322 
323   auto emitArgs = [&](Attribute attr) -> LogicalResult {
324     if (auto t = attr.dyn_cast<IntegerAttr>()) {
325       // Index attributes are treated specially as operand index.
326       if (t.getType().isIndex()) {
327         int64_t idx = t.getInt();
328         if ((idx < 0) || (idx >= op.getNumOperands()))
329           return op.emitOpError("invalid operand index");
330         if (!emitter.hasValueInScope(op.getOperand(idx)))
331           return op.emitOpError("operand ")
332                  << idx << "'s value not defined in scope";
333         os << emitter.getOrCreateName(op.getOperand(idx));
334         return success();
335       }
336     }
337     if (failed(emitter.emitAttribute(op.getLoc(), attr)))
338       return failure();
339 
340     return success();
341   };
342 
343   if (callOp.template_args()) {
344     os << "<";
345     if (failed(interleaveCommaWithError(*callOp.template_args(), os, emitArgs)))
346       return failure();
347     os << ">";
348   }
349 
350   os << "(";
351 
352   LogicalResult emittedArgs =
353       callOp.args() ? interleaveCommaWithError(*callOp.args(), os, emitArgs)
354                     : emitter.emitOperands(op);
355   if (failed(emittedArgs))
356     return failure();
357   os << ")";
358   return success();
359 }
360 
361 static LogicalResult printOperation(CppEmitter &emitter,
362                                     emitc::ApplyOp applyOp) {
363   raw_ostream &os = emitter.ostream();
364   Operation &op = *applyOp.getOperation();
365 
366   if (failed(emitter.emitAssignPrefix(op)))
367     return failure();
368   os << applyOp.applicableOperator();
369   os << emitter.getOrCreateName(applyOp.getOperand());
370 
371   return success();
372 }
373 
374 static LogicalResult printOperation(CppEmitter &emitter,
375                                     emitc::IncludeOp includeOp) {
376   raw_ostream &os = emitter.ostream();
377 
378   os << "#include ";
379   if (includeOp.is_standard_include())
380     os << "<" << includeOp.include() << ">";
381   else
382     os << "\"" << includeOp.include() << "\"";
383 
384   return success();
385 }
386 
387 static LogicalResult printOperation(CppEmitter &emitter, scf::ForOp forOp) {
388 
389   raw_indented_ostream &os = emitter.ostream();
390 
391   OperandRange operands = forOp.getIterOperands();
392   Block::BlockArgListType iterArgs = forOp.getRegionIterArgs();
393   Operation::result_range results = forOp.getResults();
394 
395   if (!emitter.shouldDeclareVariablesAtTop()) {
396     for (OpResult result : results) {
397       if (failed(emitter.emitVariableDeclaration(result,
398                                                  /*trailingSemicolon=*/true)))
399         return failure();
400     }
401   }
402 
403   for (auto pair : llvm::zip(iterArgs, operands)) {
404     if (failed(emitter.emitType(forOp.getLoc(), std::get<0>(pair).getType())))
405       return failure();
406     os << " " << emitter.getOrCreateName(std::get<0>(pair)) << " = ";
407     os << emitter.getOrCreateName(std::get<1>(pair)) << ";";
408     os << "\n";
409   }
410 
411   os << "for (";
412   if (failed(
413           emitter.emitType(forOp.getLoc(), forOp.getInductionVar().getType())))
414     return failure();
415   os << " ";
416   os << emitter.getOrCreateName(forOp.getInductionVar());
417   os << " = ";
418   os << emitter.getOrCreateName(forOp.lowerBound());
419   os << "; ";
420   os << emitter.getOrCreateName(forOp.getInductionVar());
421   os << " < ";
422   os << emitter.getOrCreateName(forOp.upperBound());
423   os << "; ";
424   os << emitter.getOrCreateName(forOp.getInductionVar());
425   os << " += ";
426   os << emitter.getOrCreateName(forOp.step());
427   os << ") {\n";
428   os.indent();
429 
430   Region &forRegion = forOp.region();
431   auto regionOps = forRegion.getOps();
432 
433   // We skip the trailing yield op because this updates the result variables
434   // of the for op in the generated code. Instead we update the iterArgs at
435   // the end of a loop iteration and set the result variables after the for
436   // loop.
437   for (auto it = regionOps.begin(); std::next(it) != regionOps.end(); ++it) {
438     if (failed(emitter.emitOperation(*it, /*trailingSemicolon=*/true)))
439       return failure();
440   }
441 
442   Operation *yieldOp = forRegion.getBlocks().front().getTerminator();
443   // Copy yield operands into iterArgs at the end of a loop iteration.
444   for (auto pair : llvm::zip(iterArgs, yieldOp->getOperands())) {
445     BlockArgument iterArg = std::get<0>(pair);
446     Value operand = std::get<1>(pair);
447     os << emitter.getOrCreateName(iterArg) << " = "
448        << emitter.getOrCreateName(operand) << ";\n";
449   }
450 
451   os.unindent() << "}";
452 
453   // Copy iterArgs into results after the for loop.
454   for (auto pair : llvm::zip(results, iterArgs)) {
455     OpResult result = std::get<0>(pair);
456     BlockArgument iterArg = std::get<1>(pair);
457     os << "\n"
458        << emitter.getOrCreateName(result) << " = "
459        << emitter.getOrCreateName(iterArg) << ";";
460   }
461 
462   return success();
463 }
464 
465 static LogicalResult printOperation(CppEmitter &emitter, scf::IfOp ifOp) {
466   raw_indented_ostream &os = emitter.ostream();
467 
468   if (!emitter.shouldDeclareVariablesAtTop()) {
469     for (OpResult result : ifOp.getResults()) {
470       if (failed(emitter.emitVariableDeclaration(result,
471                                                  /*trailingSemicolon=*/true)))
472         return failure();
473     }
474   }
475 
476   os << "if (";
477   if (failed(emitter.emitOperands(*ifOp.getOperation())))
478     return failure();
479   os << ") {\n";
480   os.indent();
481 
482   Region &thenRegion = ifOp.thenRegion();
483   for (Operation &op : thenRegion.getOps()) {
484     // Note: This prints a superfluous semicolon if the terminating yield op has
485     // zero results.
486     if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/true)))
487       return failure();
488   }
489 
490   os.unindent() << "}";
491 
492   Region &elseRegion = ifOp.elseRegion();
493   if (!elseRegion.empty()) {
494     os << " else {\n";
495     os.indent();
496 
497     for (Operation &op : elseRegion.getOps()) {
498       // Note: This prints a superfluous semicolon if the terminating yield op
499       // has zero results.
500       if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/true)))
501         return failure();
502     }
503 
504     os.unindent() << "}";
505   }
506 
507   return success();
508 }
509 
510 static LogicalResult printOperation(CppEmitter &emitter, scf::YieldOp yieldOp) {
511   raw_ostream &os = emitter.ostream();
512   Operation &parentOp = *yieldOp.getOperation()->getParentOp();
513 
514   if (yieldOp.getNumOperands() != parentOp.getNumResults()) {
515     return yieldOp.emitError("number of operands does not to match the number "
516                              "of the parent op's results");
517   }
518 
519   if (failed(interleaveWithError(
520           llvm::zip(parentOp.getResults(), yieldOp.getOperands()),
521           [&](auto pair) -> LogicalResult {
522             auto result = std::get<0>(pair);
523             auto operand = std::get<1>(pair);
524             os << emitter.getOrCreateName(result) << " = ";
525 
526             if (!emitter.hasValueInScope(operand))
527               return yieldOp.emitError("operand value not in scope");
528             os << emitter.getOrCreateName(operand);
529             return success();
530           },
531           [&]() { os << ";\n"; })))
532     return failure();
533 
534   return success();
535 }
536 
537 static LogicalResult printOperation(CppEmitter &emitter, ReturnOp returnOp) {
538   raw_ostream &os = emitter.ostream();
539   os << "return";
540   switch (returnOp.getNumOperands()) {
541   case 0:
542     return success();
543   case 1:
544     os << " " << emitter.getOrCreateName(returnOp.getOperand(0));
545     return success(emitter.hasValueInScope(returnOp.getOperand(0)));
546   default:
547     os << " std::make_tuple(";
548     if (failed(emitter.emitOperandsAndAttributes(*returnOp.getOperation())))
549       return failure();
550     os << ")";
551     return success();
552   }
553 }
554 
555 static LogicalResult printOperation(CppEmitter &emitter, ModuleOp moduleOp) {
556   CppEmitter::Scope scope(emitter);
557 
558   for (Operation &op : moduleOp) {
559     if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/false)))
560       return failure();
561   }
562   return success();
563 }
564 
565 static LogicalResult printOperation(CppEmitter &emitter, FuncOp functionOp) {
566   // We need to declare variables at top if the function has multiple blocks.
567   if (!emitter.shouldDeclareVariablesAtTop() &&
568       functionOp.getBlocks().size() > 1) {
569     return functionOp.emitOpError(
570         "with multiple blocks needs variables declared at top");
571   }
572 
573   CppEmitter::Scope scope(emitter);
574   raw_indented_ostream &os = emitter.ostream();
575   if (failed(emitter.emitTypes(functionOp.getLoc(),
576                                functionOp.getType().getResults())))
577     return failure();
578   os << " " << functionOp.getName();
579 
580   os << "(";
581   if (failed(interleaveCommaWithError(
582           functionOp.getArguments(), os,
583           [&](BlockArgument arg) -> LogicalResult {
584             if (failed(emitter.emitType(functionOp.getLoc(), arg.getType())))
585               return failure();
586             os << " " << emitter.getOrCreateName(arg);
587             return success();
588           })))
589     return failure();
590   os << ") {\n";
591   os.indent();
592   if (emitter.shouldDeclareVariablesAtTop()) {
593     // Declare all variables that hold op results including those from nested
594     // regions.
595     WalkResult result =
596         functionOp.walk<WalkOrder::PreOrder>([&](Operation *op) -> WalkResult {
597           for (OpResult result : op->getResults()) {
598             if (failed(emitter.emitVariableDeclaration(
599                     result, /*trailingSemicolon=*/true))) {
600               return WalkResult(
601                   op->emitError("unable to declare result variable for op"));
602             }
603           }
604           return WalkResult::advance();
605         });
606     if (result.wasInterrupted())
607       return failure();
608   }
609 
610   Region::BlockListType &blocks = functionOp.getBlocks();
611   // Create label names for basic blocks.
612   for (Block &block : blocks) {
613     emitter.getOrCreateName(block);
614   }
615 
616   // Declare variables for basic block arguments.
617   for (auto it = std::next(blocks.begin()); it != blocks.end(); ++it) {
618     Block &block = *it;
619     for (BlockArgument &arg : block.getArguments()) {
620       if (emitter.hasValueInScope(arg))
621         return functionOp.emitOpError(" block argument #")
622                << arg.getArgNumber() << " is out of scope";
623       if (failed(
624               emitter.emitType(block.getParentOp()->getLoc(), arg.getType()))) {
625         return failure();
626       }
627       os << " " << emitter.getOrCreateName(arg) << ";\n";
628     }
629   }
630 
631   for (Block &block : blocks) {
632     // Only print a label if there is more than one block.
633     if (blocks.size() > 1) {
634       if (failed(emitter.emitLabel(block)))
635         return failure();
636     }
637     for (Operation &op : block.getOperations()) {
638       // When generating code for an scf.if or std.cond_br op no semicolon needs
639       // to be printed after the closing brace.
640       // When generating code for an scf.for op, printing a trailing semicolon
641       // is handled within the printOperation function.
642       bool trailingSemicolon = !isa<scf::IfOp, scf::ForOp, CondBranchOp>(op);
643 
644       if (failed(emitter.emitOperation(
645               op, /*trailingSemicolon=*/trailingSemicolon)))
646         return failure();
647     }
648   }
649   os.unindent() << "}\n";
650   return success();
651 }
652 
653 CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop)
654     : os(os), declareVariablesAtTop(declareVariablesAtTop) {
655   valueInScopeCount.push(0);
656   labelInScopeCount.push(0);
657 }
658 
659 /// Return the existing or a new name for a Value.
660 StringRef CppEmitter::getOrCreateName(Value val) {
661   if (!valueMapper.count(val))
662     valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top()));
663   return *valueMapper.begin(val);
664 }
665 
666 /// Return the existing or a new label for a Block.
667 StringRef CppEmitter::getOrCreateName(Block &block) {
668   if (!blockMapper.count(&block))
669     blockMapper.insert(&block, formatv("label{0}", ++labelInScopeCount.top()));
670   return *blockMapper.begin(&block);
671 }
672 
673 bool CppEmitter::shouldMapToUnsigned(IntegerType::SignednessSemantics val) {
674   switch (val) {
675   case IntegerType::Signless:
676     return false;
677   case IntegerType::Signed:
678     return false;
679   case IntegerType::Unsigned:
680     return true;
681   }
682   llvm_unreachable("Unexpected IntegerType::SignednessSemantics");
683 }
684 
685 bool CppEmitter::hasValueInScope(Value val) { return valueMapper.count(val); }
686 
687 bool CppEmitter::hasBlockLabel(Block &block) {
688   return blockMapper.count(&block);
689 }
690 
691 LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
692   auto printInt = [&](APInt val, bool isUnsigned) {
693     if (val.getBitWidth() == 1) {
694       if (val.getBoolValue())
695         os << "true";
696       else
697         os << "false";
698     } else {
699       SmallString<128> strValue;
700       val.toString(strValue, 10, !isUnsigned, false);
701       os << strValue;
702     }
703   };
704 
705   auto printFloat = [&](APFloat val) {
706     if (val.isFinite()) {
707       SmallString<128> strValue;
708       // Use default values of toString except don't truncate zeros.
709       val.toString(strValue, 0, 0, false);
710       switch (llvm::APFloatBase::SemanticsToEnum(val.getSemantics())) {
711       case llvm::APFloatBase::S_IEEEsingle:
712         os << "(float)";
713         break;
714       case llvm::APFloatBase::S_IEEEdouble:
715         os << "(double)";
716         break;
717       default:
718         break;
719       };
720       os << strValue;
721     } else if (val.isNaN()) {
722       os << "NAN";
723     } else if (val.isInfinity()) {
724       if (val.isNegative())
725         os << "-";
726       os << "INFINITY";
727     }
728   };
729 
730   // Print floating point attributes.
731   if (auto fAttr = attr.dyn_cast<FloatAttr>()) {
732     printFloat(fAttr.getValue());
733     return success();
734   }
735   if (auto dense = attr.dyn_cast<DenseFPElementsAttr>()) {
736     os << '{';
737     interleaveComma(dense, os, [&](APFloat val) { printFloat(val); });
738     os << '}';
739     return success();
740   }
741 
742   // Print integer attributes.
743   if (auto iAttr = attr.dyn_cast<IntegerAttr>()) {
744     if (auto iType = iAttr.getType().dyn_cast<IntegerType>()) {
745       printInt(iAttr.getValue(), shouldMapToUnsigned(iType.getSignedness()));
746       return success();
747     }
748     if (auto iType = iAttr.getType().dyn_cast<IndexType>()) {
749       printInt(iAttr.getValue(), false);
750       return success();
751     }
752   }
753   if (auto dense = attr.dyn_cast<DenseIntElementsAttr>()) {
754     if (auto iType = dense.getType()
755                          .cast<TensorType>()
756                          .getElementType()
757                          .dyn_cast<IntegerType>()) {
758       os << '{';
759       interleaveComma(dense, os, [&](APInt val) {
760         printInt(val, shouldMapToUnsigned(iType.getSignedness()));
761       });
762       os << '}';
763       return success();
764     }
765     if (auto iType = dense.getType()
766                          .cast<TensorType>()
767                          .getElementType()
768                          .dyn_cast<IndexType>()) {
769       os << '{';
770       interleaveComma(dense, os, [&](APInt val) { printInt(val, false); });
771       os << '}';
772       return success();
773     }
774   }
775 
776   // Print opaque attributes.
777   if (auto oAttr = attr.dyn_cast<emitc::OpaqueAttr>()) {
778     os << oAttr.getValue();
779     return success();
780   }
781 
782   // Print symbolic reference attributes.
783   if (auto sAttr = attr.dyn_cast<SymbolRefAttr>()) {
784     if (sAttr.getNestedReferences().size() > 1)
785       return emitError(loc, "attribute has more than 1 nested reference");
786     os << sAttr.getRootReference().getValue();
787     return success();
788   }
789 
790   // Print type attributes.
791   if (auto type = attr.dyn_cast<TypeAttr>())
792     return emitType(loc, type.getValue());
793 
794   return emitError(loc, "cannot emit attribute of type ") << attr.getType();
795 }
796 
797 LogicalResult CppEmitter::emitOperands(Operation &op) {
798   auto emitOperandName = [&](Value result) -> LogicalResult {
799     if (!hasValueInScope(result))
800       return op.emitOpError() << "operand value not in scope";
801     os << getOrCreateName(result);
802     return success();
803   };
804   return interleaveCommaWithError(op.getOperands(), os, emitOperandName);
805 }
806 
807 LogicalResult
808 CppEmitter::emitOperandsAndAttributes(Operation &op,
809                                       ArrayRef<StringRef> exclude) {
810   if (failed(emitOperands(op)))
811     return failure();
812   // Insert comma in between operands and non-filtered attributes if needed.
813   if (op.getNumOperands() > 0) {
814     for (NamedAttribute attr : op.getAttrs()) {
815       if (!llvm::is_contained(exclude, attr.first.strref())) {
816         os << ", ";
817         break;
818       }
819     }
820   }
821   // Emit attributes.
822   auto emitNamedAttribute = [&](NamedAttribute attr) -> LogicalResult {
823     if (llvm::is_contained(exclude, attr.first.strref()))
824       return success();
825     os << "/* " << attr.first << " */";
826     if (failed(emitAttribute(op.getLoc(), attr.second)))
827       return failure();
828     return success();
829   };
830   return interleaveCommaWithError(op.getAttrs(), os, emitNamedAttribute);
831 }
832 
833 LogicalResult CppEmitter::emitVariableAssignment(OpResult result) {
834   if (!hasValueInScope(result)) {
835     return result.getDefiningOp()->emitOpError(
836         "result variable for the operation has not been declared");
837   }
838   os << getOrCreateName(result) << " = ";
839   return success();
840 }
841 
842 LogicalResult CppEmitter::emitVariableDeclaration(OpResult result,
843                                                   bool trailingSemicolon) {
844   if (hasValueInScope(result)) {
845     return result.getDefiningOp()->emitError(
846         "result variable for the operation already declared");
847   }
848   if (failed(emitType(result.getOwner()->getLoc(), result.getType())))
849     return failure();
850   os << " " << getOrCreateName(result);
851   if (trailingSemicolon)
852     os << ";\n";
853   return success();
854 }
855 
856 LogicalResult CppEmitter::emitAssignPrefix(Operation &op) {
857   switch (op.getNumResults()) {
858   case 0:
859     break;
860   case 1: {
861     OpResult result = op.getResult(0);
862     if (shouldDeclareVariablesAtTop()) {
863       if (failed(emitVariableAssignment(result)))
864         return failure();
865     } else {
866       if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/false)))
867         return failure();
868       os << " = ";
869     }
870     break;
871   }
872   default:
873     if (!shouldDeclareVariablesAtTop()) {
874       for (OpResult result : op.getResults()) {
875         if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/true)))
876           return failure();
877       }
878     }
879     os << "std::tie(";
880     interleaveComma(op.getResults(), os,
881                     [&](Value result) { os << getOrCreateName(result); });
882     os << ") = ";
883   }
884   return success();
885 }
886 
887 LogicalResult CppEmitter::emitLabel(Block &block) {
888   if (!hasBlockLabel(block))
889     return block.getParentOp()->emitError("label for block not found");
890   // FIXME: Add feature in `raw_indented_ostream` to ignore indent for block
891   // label instead of using `getOStream`.
892   os.getOStream() << getOrCreateName(block) << ":\n";
893   return success();
894 }
895 
896 LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
897   LogicalResult status =
898       llvm::TypeSwitch<Operation *, LogicalResult>(&op)
899           // EmitC ops.
900           .Case<emitc::ApplyOp, emitc::CallOp, emitc::ConstantOp,
901                 emitc::IncludeOp>(
902               [&](auto op) { return printOperation(*this, op); })
903           // SCF ops.
904           .Case<scf::ForOp, scf::IfOp, scf::YieldOp>(
905               [&](auto op) { return printOperation(*this, op); })
906           // Standard ops.
907           .Case<BranchOp, mlir::CallOp, CondBranchOp, mlir::ConstantOp, FuncOp,
908                 ModuleOp, ReturnOp>(
909               [&](auto op) { return printOperation(*this, op); })
910           // Arithmetic ops.
911           .Case<arith::ConstantOp>(
912               [&](auto op) { return printOperation(*this, op); })
913           .Default([&](Operation *) {
914             return op.emitOpError("unable to find printer for op");
915           });
916 
917   if (failed(status))
918     return failure();
919   os << (trailingSemicolon ? ";\n" : "\n");
920   return success();
921 }
922 
923 LogicalResult CppEmitter::emitType(Location loc, Type type) {
924   if (auto iType = type.dyn_cast<IntegerType>()) {
925     switch (iType.getWidth()) {
926     case 1:
927       return (os << "bool"), success();
928     case 8:
929     case 16:
930     case 32:
931     case 64:
932       if (shouldMapToUnsigned(iType.getSignedness()))
933         return (os << "uint" << iType.getWidth() << "_t"), success();
934       else
935         return (os << "int" << iType.getWidth() << "_t"), success();
936     default:
937       return emitError(loc, "cannot emit integer type ") << type;
938     }
939   }
940   if (auto fType = type.dyn_cast<FloatType>()) {
941     switch (fType.getWidth()) {
942     case 32:
943       return (os << "float"), success();
944     case 64:
945       return (os << "double"), success();
946     default:
947       return emitError(loc, "cannot emit float type ") << type;
948     }
949   }
950   if (auto iType = type.dyn_cast<IndexType>())
951     return (os << "size_t"), success();
952   if (auto tType = type.dyn_cast<TensorType>()) {
953     if (!tType.hasRank())
954       return emitError(loc, "cannot emit unranked tensor type");
955     if (!tType.hasStaticShape())
956       return emitError(loc, "cannot emit tensor type with non static shape");
957     os << "Tensor<";
958     if (failed(emitType(loc, tType.getElementType())))
959       return failure();
960     auto shape = tType.getShape();
961     for (auto dimSize : shape) {
962       os << ", ";
963       os << dimSize;
964     }
965     os << ">";
966     return success();
967   }
968   if (auto tType = type.dyn_cast<TupleType>())
969     return emitTupleType(loc, tType.getTypes());
970   if (auto oType = type.dyn_cast<emitc::OpaqueType>()) {
971     os << oType.getValue();
972     return success();
973   }
974   return emitError(loc, "cannot emit type ") << type;
975 }
976 
977 LogicalResult CppEmitter::emitTypes(Location loc, ArrayRef<Type> types) {
978   switch (types.size()) {
979   case 0:
980     os << "void";
981     return success();
982   case 1:
983     return emitType(loc, types.front());
984   default:
985     return emitTupleType(loc, types);
986   }
987 }
988 
989 LogicalResult CppEmitter::emitTupleType(Location loc, ArrayRef<Type> types) {
990   os << "std::tuple<";
991   if (failed(interleaveCommaWithError(
992           types, os, [&](Type type) { return emitType(loc, type); })))
993     return failure();
994   os << ">";
995   return success();
996 }
997 
998 LogicalResult emitc::translateToCpp(Operation *op, raw_ostream &os,
999                                     bool declareVariablesAtTop) {
1000   CppEmitter emitter(os, declareVariablesAtTop);
1001   return emitter.emitOperation(*op, /*trailingSemicolon=*/false);
1002 }
1003