1 //===- TranslateToCpp.cpp - Translating to C++ calls ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
10 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
11 #include "mlir/Dialect/EmitC/IR/EmitC.h"
12 #include "mlir/Dialect/Func/IR/FuncOps.h"
13 #include "mlir/Dialect/SCF/IR/SCF.h"
14 #include "mlir/IR/BuiltinOps.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/Dialect.h"
17 #include "mlir/IR/Operation.h"
18 #include "mlir/Support/IndentedOstream.h"
19 #include "mlir/Target/Cpp/CppEmitter.h"
20 #include "llvm/ADT/DenseMap.h"
21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/ADT/StringMap.h"
23 #include "llvm/ADT/TypeSwitch.h"
24 #include "llvm/Support/Debug.h"
25 #include "llvm/Support/FormatVariadic.h"
26 #include <utility>
27 
28 #define DEBUG_TYPE "translate-to-cpp"
29 
30 using namespace mlir;
31 using namespace mlir::emitc;
32 using llvm::formatv;
33 
34 /// Convenience functions to produce interleaved output with functions returning
35 /// a LogicalResult. This is different than those in STLExtras as functions used
36 /// on each element doesn't return a string.
37 template <typename ForwardIterator, typename UnaryFunctor,
38           typename NullaryFunctor>
39 inline LogicalResult
interleaveWithError(ForwardIterator begin,ForwardIterator end,UnaryFunctor eachFn,NullaryFunctor betweenFn)40 interleaveWithError(ForwardIterator begin, ForwardIterator end,
41                     UnaryFunctor eachFn, NullaryFunctor betweenFn) {
42   if (begin == end)
43     return success();
44   if (failed(eachFn(*begin)))
45     return failure();
46   ++begin;
47   for (; begin != end; ++begin) {
48     betweenFn();
49     if (failed(eachFn(*begin)))
50       return failure();
51   }
52   return success();
53 }
54 
55 template <typename Container, typename UnaryFunctor, typename NullaryFunctor>
interleaveWithError(const Container & c,UnaryFunctor eachFn,NullaryFunctor betweenFn)56 inline LogicalResult interleaveWithError(const Container &c,
57                                          UnaryFunctor eachFn,
58                                          NullaryFunctor betweenFn) {
59   return interleaveWithError(c.begin(), c.end(), eachFn, betweenFn);
60 }
61 
62 template <typename Container, typename UnaryFunctor>
interleaveCommaWithError(const Container & c,raw_ostream & os,UnaryFunctor eachFn)63 inline LogicalResult interleaveCommaWithError(const Container &c,
64                                               raw_ostream &os,
65                                               UnaryFunctor eachFn) {
66   return interleaveWithError(c.begin(), c.end(), eachFn, [&]() { os << ", "; });
67 }
68 
69 namespace {
70 /// Emitter that uses dialect specific emitters to emit C++ code.
71 struct CppEmitter {
72   explicit CppEmitter(raw_ostream &os, bool declareVariablesAtTop);
73 
74   /// Emits attribute or returns failure.
75   LogicalResult emitAttribute(Location loc, Attribute attr);
76 
77   /// Emits operation 'op' with/without training semicolon or returns failure.
78   LogicalResult emitOperation(Operation &op, bool trailingSemicolon);
79 
80   /// Emits type 'type' or returns failure.
81   LogicalResult emitType(Location loc, Type type);
82 
83   /// Emits array of types as a std::tuple of the emitted types.
84   /// - emits void for an empty array;
85   /// - emits the type of the only element for arrays of size one;
86   /// - emits a std::tuple otherwise;
87   LogicalResult emitTypes(Location loc, ArrayRef<Type> types);
88 
89   /// Emits array of types as a std::tuple of the emitted types independently of
90   /// the array size.
91   LogicalResult emitTupleType(Location loc, ArrayRef<Type> types);
92 
93   /// Emits an assignment for a variable which has been declared previously.
94   LogicalResult emitVariableAssignment(OpResult result);
95 
96   /// Emits a variable declaration for a result of an operation.
97   LogicalResult emitVariableDeclaration(OpResult result,
98                                         bool trailingSemicolon);
99 
100   /// Emits the variable declaration and assignment prefix for 'op'.
101   /// - emits separate variable followed by std::tie for multi-valued operation;
102   /// - emits single type followed by variable for single result;
103   /// - emits nothing if no value produced by op;
104   /// Emits final '=' operator where a type is produced. Returns failure if
105   /// any result type could not be converted.
106   LogicalResult emitAssignPrefix(Operation &op);
107 
108   /// Emits a label for the block.
109   LogicalResult emitLabel(Block &block);
110 
111   /// Emits the operands and atttributes of the operation. All operands are
112   /// emitted first and then all attributes in alphabetical order.
113   LogicalResult emitOperandsAndAttributes(Operation &op,
114                                           ArrayRef<StringRef> exclude = {});
115 
116   /// Emits the operands of the operation. All operands are emitted in order.
117   LogicalResult emitOperands(Operation &op);
118 
119   /// Return the existing or a new name for a Value.
120   StringRef getOrCreateName(Value val);
121 
122   /// Return the existing or a new label of a Block.
123   StringRef getOrCreateName(Block &block);
124 
125   /// Whether to map an mlir integer to a unsigned integer in C++.
126   bool shouldMapToUnsigned(IntegerType::SignednessSemantics val);
127 
128   /// RAII helper function to manage entering/exiting C++ scopes.
129   struct Scope {
Scope__anon47891eb50211::CppEmitter::Scope130     Scope(CppEmitter &emitter)
131         : valueMapperScope(emitter.valueMapper),
132           blockMapperScope(emitter.blockMapper), emitter(emitter) {
133       emitter.valueInScopeCount.push(emitter.valueInScopeCount.top());
134       emitter.labelInScopeCount.push(emitter.labelInScopeCount.top());
135     }
~Scope__anon47891eb50211::CppEmitter::Scope136     ~Scope() {
137       emitter.valueInScopeCount.pop();
138       emitter.labelInScopeCount.pop();
139     }
140 
141   private:
142     llvm::ScopedHashTableScope<Value, std::string> valueMapperScope;
143     llvm::ScopedHashTableScope<Block *, std::string> blockMapperScope;
144     CppEmitter &emitter;
145   };
146 
147   /// Returns wether the Value is assigned to a C++ variable in the scope.
148   bool hasValueInScope(Value val);
149 
150   // Returns whether a label is assigned to the block.
151   bool hasBlockLabel(Block &block);
152 
153   /// Returns the output stream.
ostream__anon47891eb50211::CppEmitter154   raw_indented_ostream &ostream() { return os; };
155 
156   /// Returns if all variables for op results and basic block arguments need to
157   /// be declared at the beginning of a function.
shouldDeclareVariablesAtTop__anon47891eb50211::CppEmitter158   bool shouldDeclareVariablesAtTop() { return declareVariablesAtTop; };
159 
160 private:
161   using ValueMapper = llvm::ScopedHashTable<Value, std::string>;
162   using BlockMapper = llvm::ScopedHashTable<Block *, std::string>;
163 
164   /// Output stream to emit to.
165   raw_indented_ostream os;
166 
167   /// Boolean to enforce that all variables for op results and block
168   /// arguments are declared at the beginning of the function. This also
169   /// includes results from ops located in nested regions.
170   bool declareVariablesAtTop;
171 
172   /// Map from value to name of C++ variable that contain the name.
173   ValueMapper valueMapper;
174 
175   /// Map from block to name of C++ label.
176   BlockMapper blockMapper;
177 
178   /// The number of values in the current scope. This is used to declare the
179   /// names of values in a scope.
180   std::stack<int64_t> valueInScopeCount;
181   std::stack<int64_t> labelInScopeCount;
182 };
183 } // namespace
184 
printConstantOp(CppEmitter & emitter,Operation * operation,Attribute value)185 static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation,
186                                      Attribute value) {
187   OpResult result = operation->getResult(0);
188 
189   // Only emit an assignment as the variable was already declared when printing
190   // the FuncOp.
191   if (emitter.shouldDeclareVariablesAtTop()) {
192     // Skip the assignment if the emitc.constant has no value.
193     if (auto oAttr = value.dyn_cast<emitc::OpaqueAttr>()) {
194       if (oAttr.getValue().empty())
195         return success();
196     }
197 
198     if (failed(emitter.emitVariableAssignment(result)))
199       return failure();
200     return emitter.emitAttribute(operation->getLoc(), value);
201   }
202 
203   // Emit a variable declaration for an emitc.constant op without value.
204   if (auto oAttr = value.dyn_cast<emitc::OpaqueAttr>()) {
205     if (oAttr.getValue().empty())
206       // The semicolon gets printed by the emitOperation function.
207       return emitter.emitVariableDeclaration(result,
208                                              /*trailingSemicolon=*/false);
209   }
210 
211   // Emit a variable declaration.
212   if (failed(emitter.emitAssignPrefix(*operation)))
213     return failure();
214   return emitter.emitAttribute(operation->getLoc(), value);
215 }
216 
printOperation(CppEmitter & emitter,emitc::ConstantOp constantOp)217 static LogicalResult printOperation(CppEmitter &emitter,
218                                     emitc::ConstantOp constantOp) {
219   Operation *operation = constantOp.getOperation();
220   Attribute value = constantOp.getValue();
221 
222   return printConstantOp(emitter, operation, value);
223 }
224 
printOperation(CppEmitter & emitter,emitc::VariableOp variableOp)225 static LogicalResult printOperation(CppEmitter &emitter,
226                                     emitc::VariableOp variableOp) {
227   Operation *operation = variableOp.getOperation();
228   Attribute value = variableOp.getValue();
229 
230   return printConstantOp(emitter, operation, value);
231 }
232 
printOperation(CppEmitter & emitter,arith::ConstantOp constantOp)233 static LogicalResult printOperation(CppEmitter &emitter,
234                                     arith::ConstantOp constantOp) {
235   Operation *operation = constantOp.getOperation();
236   Attribute value = constantOp.getValue();
237 
238   return printConstantOp(emitter, operation, value);
239 }
240 
printOperation(CppEmitter & emitter,func::ConstantOp constantOp)241 static LogicalResult printOperation(CppEmitter &emitter,
242                                     func::ConstantOp constantOp) {
243   Operation *operation = constantOp.getOperation();
244   Attribute value = constantOp.getValueAttr();
245 
246   return printConstantOp(emitter, operation, value);
247 }
248 
printOperation(CppEmitter & emitter,cf::BranchOp branchOp)249 static LogicalResult printOperation(CppEmitter &emitter,
250                                     cf::BranchOp branchOp) {
251   raw_ostream &os = emitter.ostream();
252   Block &successor = *branchOp.getSuccessor();
253 
254   for (auto pair :
255        llvm::zip(branchOp.getOperands(), successor.getArguments())) {
256     Value &operand = std::get<0>(pair);
257     BlockArgument &argument = std::get<1>(pair);
258     os << emitter.getOrCreateName(argument) << " = "
259        << emitter.getOrCreateName(operand) << ";\n";
260   }
261 
262   os << "goto ";
263   if (!(emitter.hasBlockLabel(successor)))
264     return branchOp.emitOpError("unable to find label for successor block");
265   os << emitter.getOrCreateName(successor);
266   return success();
267 }
268 
printOperation(CppEmitter & emitter,cf::CondBranchOp condBranchOp)269 static LogicalResult printOperation(CppEmitter &emitter,
270                                     cf::CondBranchOp condBranchOp) {
271   raw_indented_ostream &os = emitter.ostream();
272   Block &trueSuccessor = *condBranchOp.getTrueDest();
273   Block &falseSuccessor = *condBranchOp.getFalseDest();
274 
275   os << "if (" << emitter.getOrCreateName(condBranchOp.getCondition())
276      << ") {\n";
277 
278   os.indent();
279 
280   // If condition is true.
281   for (auto pair : llvm::zip(condBranchOp.getTrueOperands(),
282                              trueSuccessor.getArguments())) {
283     Value &operand = std::get<0>(pair);
284     BlockArgument &argument = std::get<1>(pair);
285     os << emitter.getOrCreateName(argument) << " = "
286        << emitter.getOrCreateName(operand) << ";\n";
287   }
288 
289   os << "goto ";
290   if (!(emitter.hasBlockLabel(trueSuccessor))) {
291     return condBranchOp.emitOpError("unable to find label for successor block");
292   }
293   os << emitter.getOrCreateName(trueSuccessor) << ";\n";
294   os.unindent() << "} else {\n";
295   os.indent();
296   // If condition is false.
297   for (auto pair : llvm::zip(condBranchOp.getFalseOperands(),
298                              falseSuccessor.getArguments())) {
299     Value &operand = std::get<0>(pair);
300     BlockArgument &argument = std::get<1>(pair);
301     os << emitter.getOrCreateName(argument) << " = "
302        << emitter.getOrCreateName(operand) << ";\n";
303   }
304 
305   os << "goto ";
306   if (!(emitter.hasBlockLabel(falseSuccessor))) {
307     return condBranchOp.emitOpError()
308            << "unable to find label for successor block";
309   }
310   os << emitter.getOrCreateName(falseSuccessor) << ";\n";
311   os.unindent() << "}";
312   return success();
313 }
314 
printOperation(CppEmitter & emitter,func::CallOp callOp)315 static LogicalResult printOperation(CppEmitter &emitter, func::CallOp callOp) {
316   if (failed(emitter.emitAssignPrefix(*callOp.getOperation())))
317     return failure();
318 
319   raw_ostream &os = emitter.ostream();
320   os << callOp.getCallee() << "(";
321   if (failed(emitter.emitOperands(*callOp.getOperation())))
322     return failure();
323   os << ")";
324   return success();
325 }
326 
printOperation(CppEmitter & emitter,emitc::CallOp callOp)327 static LogicalResult printOperation(CppEmitter &emitter, emitc::CallOp callOp) {
328   raw_ostream &os = emitter.ostream();
329   Operation &op = *callOp.getOperation();
330 
331   if (failed(emitter.emitAssignPrefix(op)))
332     return failure();
333   os << callOp.getCallee();
334 
335   auto emitArgs = [&](Attribute attr) -> LogicalResult {
336     if (auto t = attr.dyn_cast<IntegerAttr>()) {
337       // Index attributes are treated specially as operand index.
338       if (t.getType().isIndex()) {
339         int64_t idx = t.getInt();
340         if ((idx < 0) || (idx >= op.getNumOperands()))
341           return op.emitOpError("invalid operand index");
342         if (!emitter.hasValueInScope(op.getOperand(idx)))
343           return op.emitOpError("operand ")
344                  << idx << "'s value not defined in scope";
345         os << emitter.getOrCreateName(op.getOperand(idx));
346         return success();
347       }
348     }
349     if (failed(emitter.emitAttribute(op.getLoc(), attr)))
350       return failure();
351 
352     return success();
353   };
354 
355   if (callOp.getTemplateArgs()) {
356     os << "<";
357     if (failed(
358             interleaveCommaWithError(*callOp.getTemplateArgs(), os, emitArgs)))
359       return failure();
360     os << ">";
361   }
362 
363   os << "(";
364 
365   LogicalResult emittedArgs =
366       callOp.getArgs()
367           ? interleaveCommaWithError(*callOp.getArgs(), os, emitArgs)
368           : emitter.emitOperands(op);
369   if (failed(emittedArgs))
370     return failure();
371   os << ")";
372   return success();
373 }
374 
printOperation(CppEmitter & emitter,emitc::ApplyOp applyOp)375 static LogicalResult printOperation(CppEmitter &emitter,
376                                     emitc::ApplyOp applyOp) {
377   raw_ostream &os = emitter.ostream();
378   Operation &op = *applyOp.getOperation();
379 
380   if (failed(emitter.emitAssignPrefix(op)))
381     return failure();
382   os << applyOp.getApplicableOperator();
383   os << emitter.getOrCreateName(applyOp.getOperand());
384 
385   return success();
386 }
387 
printOperation(CppEmitter & emitter,emitc::CastOp castOp)388 static LogicalResult printOperation(CppEmitter &emitter, emitc::CastOp castOp) {
389   raw_ostream &os = emitter.ostream();
390   Operation &op = *castOp.getOperation();
391 
392   if (failed(emitter.emitAssignPrefix(op)))
393     return failure();
394   os << "(";
395   if (failed(emitter.emitType(op.getLoc(), op.getResult(0).getType())))
396     return failure();
397   os << ") ";
398   os << emitter.getOrCreateName(castOp.getOperand());
399 
400   return success();
401 }
402 
printOperation(CppEmitter & emitter,emitc::IncludeOp includeOp)403 static LogicalResult printOperation(CppEmitter &emitter,
404                                     emitc::IncludeOp includeOp) {
405   raw_ostream &os = emitter.ostream();
406 
407   os << "#include ";
408   if (includeOp.getIsStandardInclude())
409     os << "<" << includeOp.getInclude() << ">";
410   else
411     os << "\"" << includeOp.getInclude() << "\"";
412 
413   return success();
414 }
415 
printOperation(CppEmitter & emitter,scf::ForOp forOp)416 static LogicalResult printOperation(CppEmitter &emitter, scf::ForOp forOp) {
417 
418   raw_indented_ostream &os = emitter.ostream();
419 
420   OperandRange operands = forOp.getIterOperands();
421   Block::BlockArgListType iterArgs = forOp.getRegionIterArgs();
422   Operation::result_range results = forOp.getResults();
423 
424   if (!emitter.shouldDeclareVariablesAtTop()) {
425     for (OpResult result : results) {
426       if (failed(emitter.emitVariableDeclaration(result,
427                                                  /*trailingSemicolon=*/true)))
428         return failure();
429     }
430   }
431 
432   for (auto pair : llvm::zip(iterArgs, operands)) {
433     if (failed(emitter.emitType(forOp.getLoc(), std::get<0>(pair).getType())))
434       return failure();
435     os << " " << emitter.getOrCreateName(std::get<0>(pair)) << " = ";
436     os << emitter.getOrCreateName(std::get<1>(pair)) << ";";
437     os << "\n";
438   }
439 
440   os << "for (";
441   if (failed(
442           emitter.emitType(forOp.getLoc(), forOp.getInductionVar().getType())))
443     return failure();
444   os << " ";
445   os << emitter.getOrCreateName(forOp.getInductionVar());
446   os << " = ";
447   os << emitter.getOrCreateName(forOp.getLowerBound());
448   os << "; ";
449   os << emitter.getOrCreateName(forOp.getInductionVar());
450   os << " < ";
451   os << emitter.getOrCreateName(forOp.getUpperBound());
452   os << "; ";
453   os << emitter.getOrCreateName(forOp.getInductionVar());
454   os << " += ";
455   os << emitter.getOrCreateName(forOp.getStep());
456   os << ") {\n";
457   os.indent();
458 
459   Region &forRegion = forOp.getRegion();
460   auto regionOps = forRegion.getOps();
461 
462   // We skip the trailing yield op because this updates the result variables
463   // of the for op in the generated code. Instead we update the iterArgs at
464   // the end of a loop iteration and set the result variables after the for
465   // loop.
466   for (auto it = regionOps.begin(); std::next(it) != regionOps.end(); ++it) {
467     if (failed(emitter.emitOperation(*it, /*trailingSemicolon=*/true)))
468       return failure();
469   }
470 
471   Operation *yieldOp = forRegion.getBlocks().front().getTerminator();
472   // Copy yield operands into iterArgs at the end of a loop iteration.
473   for (auto pair : llvm::zip(iterArgs, yieldOp->getOperands())) {
474     BlockArgument iterArg = std::get<0>(pair);
475     Value operand = std::get<1>(pair);
476     os << emitter.getOrCreateName(iterArg) << " = "
477        << emitter.getOrCreateName(operand) << ";\n";
478   }
479 
480   os.unindent() << "}";
481 
482   // Copy iterArgs into results after the for loop.
483   for (auto pair : llvm::zip(results, iterArgs)) {
484     OpResult result = std::get<0>(pair);
485     BlockArgument iterArg = std::get<1>(pair);
486     os << "\n"
487        << emitter.getOrCreateName(result) << " = "
488        << emitter.getOrCreateName(iterArg) << ";";
489   }
490 
491   return success();
492 }
493 
printOperation(CppEmitter & emitter,scf::IfOp ifOp)494 static LogicalResult printOperation(CppEmitter &emitter, scf::IfOp ifOp) {
495   raw_indented_ostream &os = emitter.ostream();
496 
497   if (!emitter.shouldDeclareVariablesAtTop()) {
498     for (OpResult result : ifOp.getResults()) {
499       if (failed(emitter.emitVariableDeclaration(result,
500                                                  /*trailingSemicolon=*/true)))
501         return failure();
502     }
503   }
504 
505   os << "if (";
506   if (failed(emitter.emitOperands(*ifOp.getOperation())))
507     return failure();
508   os << ") {\n";
509   os.indent();
510 
511   Region &thenRegion = ifOp.getThenRegion();
512   for (Operation &op : thenRegion.getOps()) {
513     // Note: This prints a superfluous semicolon if the terminating yield op has
514     // zero results.
515     if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/true)))
516       return failure();
517   }
518 
519   os.unindent() << "}";
520 
521   Region &elseRegion = ifOp.getElseRegion();
522   if (!elseRegion.empty()) {
523     os << " else {\n";
524     os.indent();
525 
526     for (Operation &op : elseRegion.getOps()) {
527       // Note: This prints a superfluous semicolon if the terminating yield op
528       // has zero results.
529       if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/true)))
530         return failure();
531     }
532 
533     os.unindent() << "}";
534   }
535 
536   return success();
537 }
538 
printOperation(CppEmitter & emitter,scf::YieldOp yieldOp)539 static LogicalResult printOperation(CppEmitter &emitter, scf::YieldOp yieldOp) {
540   raw_ostream &os = emitter.ostream();
541   Operation &parentOp = *yieldOp.getOperation()->getParentOp();
542 
543   if (yieldOp.getNumOperands() != parentOp.getNumResults()) {
544     return yieldOp.emitError("number of operands does not to match the number "
545                              "of the parent op's results");
546   }
547 
548   if (failed(interleaveWithError(
549           llvm::zip(parentOp.getResults(), yieldOp.getOperands()),
550           [&](auto pair) -> LogicalResult {
551             auto result = std::get<0>(pair);
552             auto operand = std::get<1>(pair);
553             os << emitter.getOrCreateName(result) << " = ";
554 
555             if (!emitter.hasValueInScope(operand))
556               return yieldOp.emitError("operand value not in scope");
557             os << emitter.getOrCreateName(operand);
558             return success();
559           },
560           [&]() { os << ";\n"; })))
561     return failure();
562 
563   return success();
564 }
565 
printOperation(CppEmitter & emitter,func::ReturnOp returnOp)566 static LogicalResult printOperation(CppEmitter &emitter,
567                                     func::ReturnOp returnOp) {
568   raw_ostream &os = emitter.ostream();
569   os << "return";
570   switch (returnOp.getNumOperands()) {
571   case 0:
572     return success();
573   case 1:
574     os << " " << emitter.getOrCreateName(returnOp.getOperand(0));
575     return success(emitter.hasValueInScope(returnOp.getOperand(0)));
576   default:
577     os << " std::make_tuple(";
578     if (failed(emitter.emitOperandsAndAttributes(*returnOp.getOperation())))
579       return failure();
580     os << ")";
581     return success();
582   }
583 }
584 
printOperation(CppEmitter & emitter,ModuleOp moduleOp)585 static LogicalResult printOperation(CppEmitter &emitter, ModuleOp moduleOp) {
586   CppEmitter::Scope scope(emitter);
587 
588   for (Operation &op : moduleOp) {
589     if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/false)))
590       return failure();
591   }
592   return success();
593 }
594 
printOperation(CppEmitter & emitter,func::FuncOp functionOp)595 static LogicalResult printOperation(CppEmitter &emitter,
596                                     func::FuncOp functionOp) {
597   // We need to declare variables at top if the function has multiple blocks.
598   if (!emitter.shouldDeclareVariablesAtTop() &&
599       functionOp.getBlocks().size() > 1) {
600     return functionOp.emitOpError(
601         "with multiple blocks needs variables declared at top");
602   }
603 
604   CppEmitter::Scope scope(emitter);
605   raw_indented_ostream &os = emitter.ostream();
606   if (failed(emitter.emitTypes(functionOp.getLoc(),
607                                functionOp.getFunctionType().getResults())))
608     return failure();
609   os << " " << functionOp.getName();
610 
611   os << "(";
612   if (failed(interleaveCommaWithError(
613           functionOp.getArguments(), os,
614           [&](BlockArgument arg) -> LogicalResult {
615             if (failed(emitter.emitType(functionOp.getLoc(), arg.getType())))
616               return failure();
617             os << " " << emitter.getOrCreateName(arg);
618             return success();
619           })))
620     return failure();
621   os << ") {\n";
622   os.indent();
623   if (emitter.shouldDeclareVariablesAtTop()) {
624     // Declare all variables that hold op results including those from nested
625     // regions.
626     WalkResult result =
627         functionOp.walk<WalkOrder::PreOrder>([&](Operation *op) -> WalkResult {
628           for (OpResult result : op->getResults()) {
629             if (failed(emitter.emitVariableDeclaration(
630                     result, /*trailingSemicolon=*/true))) {
631               return WalkResult(
632                   op->emitError("unable to declare result variable for op"));
633             }
634           }
635           return WalkResult::advance();
636         });
637     if (result.wasInterrupted())
638       return failure();
639   }
640 
641   Region::BlockListType &blocks = functionOp.getBlocks();
642   // Create label names for basic blocks.
643   for (Block &block : blocks) {
644     emitter.getOrCreateName(block);
645   }
646 
647   // Declare variables for basic block arguments.
648   for (Block &block : llvm::drop_begin(blocks)) {
649     for (BlockArgument &arg : block.getArguments()) {
650       if (emitter.hasValueInScope(arg))
651         return functionOp.emitOpError(" block argument #")
652                << arg.getArgNumber() << " is out of scope";
653       if (failed(
654               emitter.emitType(block.getParentOp()->getLoc(), arg.getType()))) {
655         return failure();
656       }
657       os << " " << emitter.getOrCreateName(arg) << ";\n";
658     }
659   }
660 
661   for (Block &block : blocks) {
662     // Only print a label if the block has predecessors.
663     if (!block.hasNoPredecessors()) {
664       if (failed(emitter.emitLabel(block)))
665         return failure();
666     }
667     for (Operation &op : block.getOperations()) {
668       // When generating code for an scf.if or cf.cond_br op no semicolon needs
669       // to be printed after the closing brace.
670       // When generating code for an scf.for op, printing a trailing semicolon
671       // is handled within the printOperation function.
672       bool trailingSemicolon =
673           !isa<scf::IfOp, scf::ForOp, cf::CondBranchOp>(op);
674 
675       if (failed(emitter.emitOperation(
676               op, /*trailingSemicolon=*/trailingSemicolon)))
677         return failure();
678     }
679   }
680   os.unindent() << "}\n";
681   return success();
682 }
683 
CppEmitter(raw_ostream & os,bool declareVariablesAtTop)684 CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop)
685     : os(os), declareVariablesAtTop(declareVariablesAtTop) {
686   valueInScopeCount.push(0);
687   labelInScopeCount.push(0);
688 }
689 
690 /// Return the existing or a new name for a Value.
getOrCreateName(Value val)691 StringRef CppEmitter::getOrCreateName(Value val) {
692   if (!valueMapper.count(val))
693     valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top()));
694   return *valueMapper.begin(val);
695 }
696 
697 /// Return the existing or a new label for a Block.
getOrCreateName(Block & block)698 StringRef CppEmitter::getOrCreateName(Block &block) {
699   if (!blockMapper.count(&block))
700     blockMapper.insert(&block, formatv("label{0}", ++labelInScopeCount.top()));
701   return *blockMapper.begin(&block);
702 }
703 
shouldMapToUnsigned(IntegerType::SignednessSemantics val)704 bool CppEmitter::shouldMapToUnsigned(IntegerType::SignednessSemantics val) {
705   switch (val) {
706   case IntegerType::Signless:
707     return false;
708   case IntegerType::Signed:
709     return false;
710   case IntegerType::Unsigned:
711     return true;
712   }
713   llvm_unreachable("Unexpected IntegerType::SignednessSemantics");
714 }
715 
hasValueInScope(Value val)716 bool CppEmitter::hasValueInScope(Value val) { return valueMapper.count(val); }
717 
hasBlockLabel(Block & block)718 bool CppEmitter::hasBlockLabel(Block &block) {
719   return blockMapper.count(&block);
720 }
721 
emitAttribute(Location loc,Attribute attr)722 LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
723   auto printInt = [&](const APInt &val, bool isUnsigned) {
724     if (val.getBitWidth() == 1) {
725       if (val.getBoolValue())
726         os << "true";
727       else
728         os << "false";
729     } else {
730       SmallString<128> strValue;
731       val.toString(strValue, 10, !isUnsigned, false);
732       os << strValue;
733     }
734   };
735 
736   auto printFloat = [&](const APFloat &val) {
737     if (val.isFinite()) {
738       SmallString<128> strValue;
739       // Use default values of toString except don't truncate zeros.
740       val.toString(strValue, 0, 0, false);
741       switch (llvm::APFloatBase::SemanticsToEnum(val.getSemantics())) {
742       case llvm::APFloatBase::S_IEEEsingle:
743         os << "(float)";
744         break;
745       case llvm::APFloatBase::S_IEEEdouble:
746         os << "(double)";
747         break;
748       default:
749         break;
750       };
751       os << strValue;
752     } else if (val.isNaN()) {
753       os << "NAN";
754     } else if (val.isInfinity()) {
755       if (val.isNegative())
756         os << "-";
757       os << "INFINITY";
758     }
759   };
760 
761   // Print floating point attributes.
762   if (auto fAttr = attr.dyn_cast<FloatAttr>()) {
763     printFloat(fAttr.getValue());
764     return success();
765   }
766   if (auto dense = attr.dyn_cast<DenseFPElementsAttr>()) {
767     os << '{';
768     interleaveComma(dense, os, [&](const APFloat &val) { printFloat(val); });
769     os << '}';
770     return success();
771   }
772 
773   // Print integer attributes.
774   if (auto iAttr = attr.dyn_cast<IntegerAttr>()) {
775     if (auto iType = iAttr.getType().dyn_cast<IntegerType>()) {
776       printInt(iAttr.getValue(), shouldMapToUnsigned(iType.getSignedness()));
777       return success();
778     }
779     if (auto iType = iAttr.getType().dyn_cast<IndexType>()) {
780       printInt(iAttr.getValue(), false);
781       return success();
782     }
783   }
784   if (auto dense = attr.dyn_cast<DenseIntElementsAttr>()) {
785     if (auto iType = dense.getType()
786                          .cast<TensorType>()
787                          .getElementType()
788                          .dyn_cast<IntegerType>()) {
789       os << '{';
790       interleaveComma(dense, os, [&](const APInt &val) {
791         printInt(val, shouldMapToUnsigned(iType.getSignedness()));
792       });
793       os << '}';
794       return success();
795     }
796     if (auto iType = dense.getType()
797                          .cast<TensorType>()
798                          .getElementType()
799                          .dyn_cast<IndexType>()) {
800       os << '{';
801       interleaveComma(dense, os,
802                       [&](const APInt &val) { printInt(val, false); });
803       os << '}';
804       return success();
805     }
806   }
807 
808   // Print opaque attributes.
809   if (auto oAttr = attr.dyn_cast<emitc::OpaqueAttr>()) {
810     os << oAttr.getValue();
811     return success();
812   }
813 
814   // Print symbolic reference attributes.
815   if (auto sAttr = attr.dyn_cast<SymbolRefAttr>()) {
816     if (sAttr.getNestedReferences().size() > 1)
817       return emitError(loc, "attribute has more than 1 nested reference");
818     os << sAttr.getRootReference().getValue();
819     return success();
820   }
821 
822   // Print type attributes.
823   if (auto type = attr.dyn_cast<TypeAttr>())
824     return emitType(loc, type.getValue());
825 
826   return emitError(loc, "cannot emit attribute of type ") << attr.getType();
827 }
828 
emitOperands(Operation & op)829 LogicalResult CppEmitter::emitOperands(Operation &op) {
830   auto emitOperandName = [&](Value result) -> LogicalResult {
831     if (!hasValueInScope(result))
832       return op.emitOpError() << "operand value not in scope";
833     os << getOrCreateName(result);
834     return success();
835   };
836   return interleaveCommaWithError(op.getOperands(), os, emitOperandName);
837 }
838 
839 LogicalResult
emitOperandsAndAttributes(Operation & op,ArrayRef<StringRef> exclude)840 CppEmitter::emitOperandsAndAttributes(Operation &op,
841                                       ArrayRef<StringRef> exclude) {
842   if (failed(emitOperands(op)))
843     return failure();
844   // Insert comma in between operands and non-filtered attributes if needed.
845   if (op.getNumOperands() > 0) {
846     for (NamedAttribute attr : op.getAttrs()) {
847       if (!llvm::is_contained(exclude, attr.getName().strref())) {
848         os << ", ";
849         break;
850       }
851     }
852   }
853   // Emit attributes.
854   auto emitNamedAttribute = [&](NamedAttribute attr) -> LogicalResult {
855     if (llvm::is_contained(exclude, attr.getName().strref()))
856       return success();
857     os << "/* " << attr.getName().getValue() << " */";
858     if (failed(emitAttribute(op.getLoc(), attr.getValue())))
859       return failure();
860     return success();
861   };
862   return interleaveCommaWithError(op.getAttrs(), os, emitNamedAttribute);
863 }
864 
emitVariableAssignment(OpResult result)865 LogicalResult CppEmitter::emitVariableAssignment(OpResult result) {
866   if (!hasValueInScope(result)) {
867     return result.getDefiningOp()->emitOpError(
868         "result variable for the operation has not been declared");
869   }
870   os << getOrCreateName(result) << " = ";
871   return success();
872 }
873 
emitVariableDeclaration(OpResult result,bool trailingSemicolon)874 LogicalResult CppEmitter::emitVariableDeclaration(OpResult result,
875                                                   bool trailingSemicolon) {
876   if (hasValueInScope(result)) {
877     return result.getDefiningOp()->emitError(
878         "result variable for the operation already declared");
879   }
880   if (failed(emitType(result.getOwner()->getLoc(), result.getType())))
881     return failure();
882   os << " " << getOrCreateName(result);
883   if (trailingSemicolon)
884     os << ";\n";
885   return success();
886 }
887 
emitAssignPrefix(Operation & op)888 LogicalResult CppEmitter::emitAssignPrefix(Operation &op) {
889   switch (op.getNumResults()) {
890   case 0:
891     break;
892   case 1: {
893     OpResult result = op.getResult(0);
894     if (shouldDeclareVariablesAtTop()) {
895       if (failed(emitVariableAssignment(result)))
896         return failure();
897     } else {
898       if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/false)))
899         return failure();
900       os << " = ";
901     }
902     break;
903   }
904   default:
905     if (!shouldDeclareVariablesAtTop()) {
906       for (OpResult result : op.getResults()) {
907         if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/true)))
908           return failure();
909       }
910     }
911     os << "std::tie(";
912     interleaveComma(op.getResults(), os,
913                     [&](Value result) { os << getOrCreateName(result); });
914     os << ") = ";
915   }
916   return success();
917 }
918 
emitLabel(Block & block)919 LogicalResult CppEmitter::emitLabel(Block &block) {
920   if (!hasBlockLabel(block))
921     return block.getParentOp()->emitError("label for block not found");
922   // FIXME: Add feature in `raw_indented_ostream` to ignore indent for block
923   // label instead of using `getOStream`.
924   os.getOStream() << getOrCreateName(block) << ":\n";
925   return success();
926 }
927 
emitOperation(Operation & op,bool trailingSemicolon)928 LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
929   LogicalResult status =
930       llvm::TypeSwitch<Operation *, LogicalResult>(&op)
931           // Builtin ops.
932           .Case<ModuleOp>([&](auto op) { return printOperation(*this, op); })
933           // CF ops.
934           .Case<cf::BranchOp, cf::CondBranchOp>(
935               [&](auto op) { return printOperation(*this, op); })
936           // EmitC ops.
937           .Case<emitc::ApplyOp, emitc::CallOp, emitc::CastOp, emitc::ConstantOp,
938                 emitc::IncludeOp, emitc::VariableOp>(
939               [&](auto op) { return printOperation(*this, op); })
940           // Func ops.
941           .Case<func::CallOp, func::ConstantOp, func::FuncOp, func::ReturnOp>(
942               [&](auto op) { return printOperation(*this, op); })
943           // SCF ops.
944           .Case<scf::ForOp, scf::IfOp, scf::YieldOp>(
945               [&](auto op) { return printOperation(*this, op); })
946           // Arithmetic ops.
947           .Case<arith::ConstantOp>(
948               [&](auto op) { return printOperation(*this, op); })
949           .Default([&](Operation *) {
950             return op.emitOpError("unable to find printer for op");
951           });
952 
953   if (failed(status))
954     return failure();
955   os << (trailingSemicolon ? ";\n" : "\n");
956   return success();
957 }
958 
emitType(Location loc,Type type)959 LogicalResult CppEmitter::emitType(Location loc, Type type) {
960   if (auto iType = type.dyn_cast<IntegerType>()) {
961     switch (iType.getWidth()) {
962     case 1:
963       return (os << "bool"), success();
964     case 8:
965     case 16:
966     case 32:
967     case 64:
968       if (shouldMapToUnsigned(iType.getSignedness()))
969         return (os << "uint" << iType.getWidth() << "_t"), success();
970       else
971         return (os << "int" << iType.getWidth() << "_t"), success();
972     default:
973       return emitError(loc, "cannot emit integer type ") << type;
974     }
975   }
976   if (auto fType = type.dyn_cast<FloatType>()) {
977     switch (fType.getWidth()) {
978     case 32:
979       return (os << "float"), success();
980     case 64:
981       return (os << "double"), success();
982     default:
983       return emitError(loc, "cannot emit float type ") << type;
984     }
985   }
986   if (auto iType = type.dyn_cast<IndexType>())
987     return (os << "size_t"), success();
988   if (auto tType = type.dyn_cast<TensorType>()) {
989     if (!tType.hasRank())
990       return emitError(loc, "cannot emit unranked tensor type");
991     if (!tType.hasStaticShape())
992       return emitError(loc, "cannot emit tensor type with non static shape");
993     os << "Tensor<";
994     if (failed(emitType(loc, tType.getElementType())))
995       return failure();
996     auto shape = tType.getShape();
997     for (auto dimSize : shape) {
998       os << ", ";
999       os << dimSize;
1000     }
1001     os << ">";
1002     return success();
1003   }
1004   if (auto tType = type.dyn_cast<TupleType>())
1005     return emitTupleType(loc, tType.getTypes());
1006   if (auto oType = type.dyn_cast<emitc::OpaqueType>()) {
1007     os << oType.getValue();
1008     return success();
1009   }
1010   if (auto pType = type.dyn_cast<emitc::PointerType>()) {
1011     if (failed(emitType(loc, pType.getPointee())))
1012       return failure();
1013     os << "*";
1014     return success();
1015   }
1016   return emitError(loc, "cannot emit type ") << type;
1017 }
1018 
emitTypes(Location loc,ArrayRef<Type> types)1019 LogicalResult CppEmitter::emitTypes(Location loc, ArrayRef<Type> types) {
1020   switch (types.size()) {
1021   case 0:
1022     os << "void";
1023     return success();
1024   case 1:
1025     return emitType(loc, types.front());
1026   default:
1027     return emitTupleType(loc, types);
1028   }
1029 }
1030 
emitTupleType(Location loc,ArrayRef<Type> types)1031 LogicalResult CppEmitter::emitTupleType(Location loc, ArrayRef<Type> types) {
1032   os << "std::tuple<";
1033   if (failed(interleaveCommaWithError(
1034           types, os, [&](Type type) { return emitType(loc, type); })))
1035     return failure();
1036   os << ">";
1037   return success();
1038 }
1039 
translateToCpp(Operation * op,raw_ostream & os,bool declareVariablesAtTop)1040 LogicalResult emitc::translateToCpp(Operation *op, raw_ostream &os,
1041                                     bool declareVariablesAtTop) {
1042   CppEmitter emitter(os, declareVariablesAtTop);
1043   return emitter.emitOperation(*op, /*trailingSemicolon=*/false);
1044 }
1045