1 //===- TranslateToCpp.cpp - Translating to C++ calls ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/IndentedOstream.h"
#include "mlir/Target/Cpp/CppEmitter.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/ScopedHashTable.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#include <stack>
#include <utility>
27
28 #define DEBUG_TYPE "translate-to-cpp"
29
30 using namespace mlir;
31 using namespace mlir::emitc;
32 using llvm::formatv;
33
34 /// Convenience functions to produce interleaved output with functions returning
35 /// a LogicalResult. This is different than those in STLExtras as functions used
36 /// on each element doesn't return a string.
37 template <typename ForwardIterator, typename UnaryFunctor,
38 typename NullaryFunctor>
39 inline LogicalResult
interleaveWithError(ForwardIterator begin,ForwardIterator end,UnaryFunctor eachFn,NullaryFunctor betweenFn)40 interleaveWithError(ForwardIterator begin, ForwardIterator end,
41 UnaryFunctor eachFn, NullaryFunctor betweenFn) {
42 if (begin == end)
43 return success();
44 if (failed(eachFn(*begin)))
45 return failure();
46 ++begin;
47 for (; begin != end; ++begin) {
48 betweenFn();
49 if (failed(eachFn(*begin)))
50 return failure();
51 }
52 return success();
53 }
54
55 template <typename Container, typename UnaryFunctor, typename NullaryFunctor>
interleaveWithError(const Container & c,UnaryFunctor eachFn,NullaryFunctor betweenFn)56 inline LogicalResult interleaveWithError(const Container &c,
57 UnaryFunctor eachFn,
58 NullaryFunctor betweenFn) {
59 return interleaveWithError(c.begin(), c.end(), eachFn, betweenFn);
60 }
61
62 template <typename Container, typename UnaryFunctor>
interleaveCommaWithError(const Container & c,raw_ostream & os,UnaryFunctor eachFn)63 inline LogicalResult interleaveCommaWithError(const Container &c,
64 raw_ostream &os,
65 UnaryFunctor eachFn) {
66 return interleaveWithError(c.begin(), c.end(), eachFn, [&]() { os << ", "; });
67 }
68
69 namespace {
70 /// Emitter that uses dialect specific emitters to emit C++ code.
71 struct CppEmitter {
72 explicit CppEmitter(raw_ostream &os, bool declareVariablesAtTop);
73
74 /// Emits attribute or returns failure.
75 LogicalResult emitAttribute(Location loc, Attribute attr);
76
77 /// Emits operation 'op' with/without training semicolon or returns failure.
78 LogicalResult emitOperation(Operation &op, bool trailingSemicolon);
79
80 /// Emits type 'type' or returns failure.
81 LogicalResult emitType(Location loc, Type type);
82
83 /// Emits array of types as a std::tuple of the emitted types.
84 /// - emits void for an empty array;
85 /// - emits the type of the only element for arrays of size one;
86 /// - emits a std::tuple otherwise;
87 LogicalResult emitTypes(Location loc, ArrayRef<Type> types);
88
89 /// Emits array of types as a std::tuple of the emitted types independently of
90 /// the array size.
91 LogicalResult emitTupleType(Location loc, ArrayRef<Type> types);
92
93 /// Emits an assignment for a variable which has been declared previously.
94 LogicalResult emitVariableAssignment(OpResult result);
95
96 /// Emits a variable declaration for a result of an operation.
97 LogicalResult emitVariableDeclaration(OpResult result,
98 bool trailingSemicolon);
99
100 /// Emits the variable declaration and assignment prefix for 'op'.
101 /// - emits separate variable followed by std::tie for multi-valued operation;
102 /// - emits single type followed by variable for single result;
103 /// - emits nothing if no value produced by op;
104 /// Emits final '=' operator where a type is produced. Returns failure if
105 /// any result type could not be converted.
106 LogicalResult emitAssignPrefix(Operation &op);
107
108 /// Emits a label for the block.
109 LogicalResult emitLabel(Block &block);
110
111 /// Emits the operands and atttributes of the operation. All operands are
112 /// emitted first and then all attributes in alphabetical order.
113 LogicalResult emitOperandsAndAttributes(Operation &op,
114 ArrayRef<StringRef> exclude = {});
115
116 /// Emits the operands of the operation. All operands are emitted in order.
117 LogicalResult emitOperands(Operation &op);
118
119 /// Return the existing or a new name for a Value.
120 StringRef getOrCreateName(Value val);
121
122 /// Return the existing or a new label of a Block.
123 StringRef getOrCreateName(Block &block);
124
125 /// Whether to map an mlir integer to a unsigned integer in C++.
126 bool shouldMapToUnsigned(IntegerType::SignednessSemantics val);
127
128 /// RAII helper function to manage entering/exiting C++ scopes.
129 struct Scope {
Scope__anon47891eb50211::CppEmitter::Scope130 Scope(CppEmitter &emitter)
131 : valueMapperScope(emitter.valueMapper),
132 blockMapperScope(emitter.blockMapper), emitter(emitter) {
133 emitter.valueInScopeCount.push(emitter.valueInScopeCount.top());
134 emitter.labelInScopeCount.push(emitter.labelInScopeCount.top());
135 }
~Scope__anon47891eb50211::CppEmitter::Scope136 ~Scope() {
137 emitter.valueInScopeCount.pop();
138 emitter.labelInScopeCount.pop();
139 }
140
141 private:
142 llvm::ScopedHashTableScope<Value, std::string> valueMapperScope;
143 llvm::ScopedHashTableScope<Block *, std::string> blockMapperScope;
144 CppEmitter &emitter;
145 };
146
147 /// Returns wether the Value is assigned to a C++ variable in the scope.
148 bool hasValueInScope(Value val);
149
150 // Returns whether a label is assigned to the block.
151 bool hasBlockLabel(Block &block);
152
153 /// Returns the output stream.
ostream__anon47891eb50211::CppEmitter154 raw_indented_ostream &ostream() { return os; };
155
156 /// Returns if all variables for op results and basic block arguments need to
157 /// be declared at the beginning of a function.
shouldDeclareVariablesAtTop__anon47891eb50211::CppEmitter158 bool shouldDeclareVariablesAtTop() { return declareVariablesAtTop; };
159
160 private:
161 using ValueMapper = llvm::ScopedHashTable<Value, std::string>;
162 using BlockMapper = llvm::ScopedHashTable<Block *, std::string>;
163
164 /// Output stream to emit to.
165 raw_indented_ostream os;
166
167 /// Boolean to enforce that all variables for op results and block
168 /// arguments are declared at the beginning of the function. This also
169 /// includes results from ops located in nested regions.
170 bool declareVariablesAtTop;
171
172 /// Map from value to name of C++ variable that contain the name.
173 ValueMapper valueMapper;
174
175 /// Map from block to name of C++ label.
176 BlockMapper blockMapper;
177
178 /// The number of values in the current scope. This is used to declare the
179 /// names of values in a scope.
180 std::stack<int64_t> valueInScopeCount;
181 std::stack<int64_t> labelInScopeCount;
182 };
183 } // namespace
184
printConstantOp(CppEmitter & emitter,Operation * operation,Attribute value)185 static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation,
186 Attribute value) {
187 OpResult result = operation->getResult(0);
188
189 // Only emit an assignment as the variable was already declared when printing
190 // the FuncOp.
191 if (emitter.shouldDeclareVariablesAtTop()) {
192 // Skip the assignment if the emitc.constant has no value.
193 if (auto oAttr = value.dyn_cast<emitc::OpaqueAttr>()) {
194 if (oAttr.getValue().empty())
195 return success();
196 }
197
198 if (failed(emitter.emitVariableAssignment(result)))
199 return failure();
200 return emitter.emitAttribute(operation->getLoc(), value);
201 }
202
203 // Emit a variable declaration for an emitc.constant op without value.
204 if (auto oAttr = value.dyn_cast<emitc::OpaqueAttr>()) {
205 if (oAttr.getValue().empty())
206 // The semicolon gets printed by the emitOperation function.
207 return emitter.emitVariableDeclaration(result,
208 /*trailingSemicolon=*/false);
209 }
210
211 // Emit a variable declaration.
212 if (failed(emitter.emitAssignPrefix(*operation)))
213 return failure();
214 return emitter.emitAttribute(operation->getLoc(), value);
215 }
216
printOperation(CppEmitter & emitter,emitc::ConstantOp constantOp)217 static LogicalResult printOperation(CppEmitter &emitter,
218 emitc::ConstantOp constantOp) {
219 Operation *operation = constantOp.getOperation();
220 Attribute value = constantOp.getValue();
221
222 return printConstantOp(emitter, operation, value);
223 }
224
printOperation(CppEmitter & emitter,emitc::VariableOp variableOp)225 static LogicalResult printOperation(CppEmitter &emitter,
226 emitc::VariableOp variableOp) {
227 Operation *operation = variableOp.getOperation();
228 Attribute value = variableOp.getValue();
229
230 return printConstantOp(emitter, operation, value);
231 }
232
printOperation(CppEmitter & emitter,arith::ConstantOp constantOp)233 static LogicalResult printOperation(CppEmitter &emitter,
234 arith::ConstantOp constantOp) {
235 Operation *operation = constantOp.getOperation();
236 Attribute value = constantOp.getValue();
237
238 return printConstantOp(emitter, operation, value);
239 }
240
printOperation(CppEmitter & emitter,func::ConstantOp constantOp)241 static LogicalResult printOperation(CppEmitter &emitter,
242 func::ConstantOp constantOp) {
243 Operation *operation = constantOp.getOperation();
244 Attribute value = constantOp.getValueAttr();
245
246 return printConstantOp(emitter, operation, value);
247 }
248
printOperation(CppEmitter & emitter,cf::BranchOp branchOp)249 static LogicalResult printOperation(CppEmitter &emitter,
250 cf::BranchOp branchOp) {
251 raw_ostream &os = emitter.ostream();
252 Block &successor = *branchOp.getSuccessor();
253
254 for (auto pair :
255 llvm::zip(branchOp.getOperands(), successor.getArguments())) {
256 Value &operand = std::get<0>(pair);
257 BlockArgument &argument = std::get<1>(pair);
258 os << emitter.getOrCreateName(argument) << " = "
259 << emitter.getOrCreateName(operand) << ";\n";
260 }
261
262 os << "goto ";
263 if (!(emitter.hasBlockLabel(successor)))
264 return branchOp.emitOpError("unable to find label for successor block");
265 os << emitter.getOrCreateName(successor);
266 return success();
267 }
268
printOperation(CppEmitter & emitter,cf::CondBranchOp condBranchOp)269 static LogicalResult printOperation(CppEmitter &emitter,
270 cf::CondBranchOp condBranchOp) {
271 raw_indented_ostream &os = emitter.ostream();
272 Block &trueSuccessor = *condBranchOp.getTrueDest();
273 Block &falseSuccessor = *condBranchOp.getFalseDest();
274
275 os << "if (" << emitter.getOrCreateName(condBranchOp.getCondition())
276 << ") {\n";
277
278 os.indent();
279
280 // If condition is true.
281 for (auto pair : llvm::zip(condBranchOp.getTrueOperands(),
282 trueSuccessor.getArguments())) {
283 Value &operand = std::get<0>(pair);
284 BlockArgument &argument = std::get<1>(pair);
285 os << emitter.getOrCreateName(argument) << " = "
286 << emitter.getOrCreateName(operand) << ";\n";
287 }
288
289 os << "goto ";
290 if (!(emitter.hasBlockLabel(trueSuccessor))) {
291 return condBranchOp.emitOpError("unable to find label for successor block");
292 }
293 os << emitter.getOrCreateName(trueSuccessor) << ";\n";
294 os.unindent() << "} else {\n";
295 os.indent();
296 // If condition is false.
297 for (auto pair : llvm::zip(condBranchOp.getFalseOperands(),
298 falseSuccessor.getArguments())) {
299 Value &operand = std::get<0>(pair);
300 BlockArgument &argument = std::get<1>(pair);
301 os << emitter.getOrCreateName(argument) << " = "
302 << emitter.getOrCreateName(operand) << ";\n";
303 }
304
305 os << "goto ";
306 if (!(emitter.hasBlockLabel(falseSuccessor))) {
307 return condBranchOp.emitOpError()
308 << "unable to find label for successor block";
309 }
310 os << emitter.getOrCreateName(falseSuccessor) << ";\n";
311 os.unindent() << "}";
312 return success();
313 }
314
printOperation(CppEmitter & emitter,func::CallOp callOp)315 static LogicalResult printOperation(CppEmitter &emitter, func::CallOp callOp) {
316 if (failed(emitter.emitAssignPrefix(*callOp.getOperation())))
317 return failure();
318
319 raw_ostream &os = emitter.ostream();
320 os << callOp.getCallee() << "(";
321 if (failed(emitter.emitOperands(*callOp.getOperation())))
322 return failure();
323 os << ")";
324 return success();
325 }
326
printOperation(CppEmitter & emitter,emitc::CallOp callOp)327 static LogicalResult printOperation(CppEmitter &emitter, emitc::CallOp callOp) {
328 raw_ostream &os = emitter.ostream();
329 Operation &op = *callOp.getOperation();
330
331 if (failed(emitter.emitAssignPrefix(op)))
332 return failure();
333 os << callOp.getCallee();
334
335 auto emitArgs = [&](Attribute attr) -> LogicalResult {
336 if (auto t = attr.dyn_cast<IntegerAttr>()) {
337 // Index attributes are treated specially as operand index.
338 if (t.getType().isIndex()) {
339 int64_t idx = t.getInt();
340 if ((idx < 0) || (idx >= op.getNumOperands()))
341 return op.emitOpError("invalid operand index");
342 if (!emitter.hasValueInScope(op.getOperand(idx)))
343 return op.emitOpError("operand ")
344 << idx << "'s value not defined in scope";
345 os << emitter.getOrCreateName(op.getOperand(idx));
346 return success();
347 }
348 }
349 if (failed(emitter.emitAttribute(op.getLoc(), attr)))
350 return failure();
351
352 return success();
353 };
354
355 if (callOp.getTemplateArgs()) {
356 os << "<";
357 if (failed(
358 interleaveCommaWithError(*callOp.getTemplateArgs(), os, emitArgs)))
359 return failure();
360 os << ">";
361 }
362
363 os << "(";
364
365 LogicalResult emittedArgs =
366 callOp.getArgs()
367 ? interleaveCommaWithError(*callOp.getArgs(), os, emitArgs)
368 : emitter.emitOperands(op);
369 if (failed(emittedArgs))
370 return failure();
371 os << ")";
372 return success();
373 }
374
printOperation(CppEmitter & emitter,emitc::ApplyOp applyOp)375 static LogicalResult printOperation(CppEmitter &emitter,
376 emitc::ApplyOp applyOp) {
377 raw_ostream &os = emitter.ostream();
378 Operation &op = *applyOp.getOperation();
379
380 if (failed(emitter.emitAssignPrefix(op)))
381 return failure();
382 os << applyOp.getApplicableOperator();
383 os << emitter.getOrCreateName(applyOp.getOperand());
384
385 return success();
386 }
387
printOperation(CppEmitter & emitter,emitc::CastOp castOp)388 static LogicalResult printOperation(CppEmitter &emitter, emitc::CastOp castOp) {
389 raw_ostream &os = emitter.ostream();
390 Operation &op = *castOp.getOperation();
391
392 if (failed(emitter.emitAssignPrefix(op)))
393 return failure();
394 os << "(";
395 if (failed(emitter.emitType(op.getLoc(), op.getResult(0).getType())))
396 return failure();
397 os << ") ";
398 os << emitter.getOrCreateName(castOp.getOperand());
399
400 return success();
401 }
402
printOperation(CppEmitter & emitter,emitc::IncludeOp includeOp)403 static LogicalResult printOperation(CppEmitter &emitter,
404 emitc::IncludeOp includeOp) {
405 raw_ostream &os = emitter.ostream();
406
407 os << "#include ";
408 if (includeOp.getIsStandardInclude())
409 os << "<" << includeOp.getInclude() << ">";
410 else
411 os << "\"" << includeOp.getInclude() << "\"";
412
413 return success();
414 }
415
printOperation(CppEmitter & emitter,scf::ForOp forOp)416 static LogicalResult printOperation(CppEmitter &emitter, scf::ForOp forOp) {
417
418 raw_indented_ostream &os = emitter.ostream();
419
420 OperandRange operands = forOp.getIterOperands();
421 Block::BlockArgListType iterArgs = forOp.getRegionIterArgs();
422 Operation::result_range results = forOp.getResults();
423
424 if (!emitter.shouldDeclareVariablesAtTop()) {
425 for (OpResult result : results) {
426 if (failed(emitter.emitVariableDeclaration(result,
427 /*trailingSemicolon=*/true)))
428 return failure();
429 }
430 }
431
432 for (auto pair : llvm::zip(iterArgs, operands)) {
433 if (failed(emitter.emitType(forOp.getLoc(), std::get<0>(pair).getType())))
434 return failure();
435 os << " " << emitter.getOrCreateName(std::get<0>(pair)) << " = ";
436 os << emitter.getOrCreateName(std::get<1>(pair)) << ";";
437 os << "\n";
438 }
439
440 os << "for (";
441 if (failed(
442 emitter.emitType(forOp.getLoc(), forOp.getInductionVar().getType())))
443 return failure();
444 os << " ";
445 os << emitter.getOrCreateName(forOp.getInductionVar());
446 os << " = ";
447 os << emitter.getOrCreateName(forOp.getLowerBound());
448 os << "; ";
449 os << emitter.getOrCreateName(forOp.getInductionVar());
450 os << " < ";
451 os << emitter.getOrCreateName(forOp.getUpperBound());
452 os << "; ";
453 os << emitter.getOrCreateName(forOp.getInductionVar());
454 os << " += ";
455 os << emitter.getOrCreateName(forOp.getStep());
456 os << ") {\n";
457 os.indent();
458
459 Region &forRegion = forOp.getRegion();
460 auto regionOps = forRegion.getOps();
461
462 // We skip the trailing yield op because this updates the result variables
463 // of the for op in the generated code. Instead we update the iterArgs at
464 // the end of a loop iteration and set the result variables after the for
465 // loop.
466 for (auto it = regionOps.begin(); std::next(it) != regionOps.end(); ++it) {
467 if (failed(emitter.emitOperation(*it, /*trailingSemicolon=*/true)))
468 return failure();
469 }
470
471 Operation *yieldOp = forRegion.getBlocks().front().getTerminator();
472 // Copy yield operands into iterArgs at the end of a loop iteration.
473 for (auto pair : llvm::zip(iterArgs, yieldOp->getOperands())) {
474 BlockArgument iterArg = std::get<0>(pair);
475 Value operand = std::get<1>(pair);
476 os << emitter.getOrCreateName(iterArg) << " = "
477 << emitter.getOrCreateName(operand) << ";\n";
478 }
479
480 os.unindent() << "}";
481
482 // Copy iterArgs into results after the for loop.
483 for (auto pair : llvm::zip(results, iterArgs)) {
484 OpResult result = std::get<0>(pair);
485 BlockArgument iterArg = std::get<1>(pair);
486 os << "\n"
487 << emitter.getOrCreateName(result) << " = "
488 << emitter.getOrCreateName(iterArg) << ";";
489 }
490
491 return success();
492 }
493
printOperation(CppEmitter & emitter,scf::IfOp ifOp)494 static LogicalResult printOperation(CppEmitter &emitter, scf::IfOp ifOp) {
495 raw_indented_ostream &os = emitter.ostream();
496
497 if (!emitter.shouldDeclareVariablesAtTop()) {
498 for (OpResult result : ifOp.getResults()) {
499 if (failed(emitter.emitVariableDeclaration(result,
500 /*trailingSemicolon=*/true)))
501 return failure();
502 }
503 }
504
505 os << "if (";
506 if (failed(emitter.emitOperands(*ifOp.getOperation())))
507 return failure();
508 os << ") {\n";
509 os.indent();
510
511 Region &thenRegion = ifOp.getThenRegion();
512 for (Operation &op : thenRegion.getOps()) {
513 // Note: This prints a superfluous semicolon if the terminating yield op has
514 // zero results.
515 if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/true)))
516 return failure();
517 }
518
519 os.unindent() << "}";
520
521 Region &elseRegion = ifOp.getElseRegion();
522 if (!elseRegion.empty()) {
523 os << " else {\n";
524 os.indent();
525
526 for (Operation &op : elseRegion.getOps()) {
527 // Note: This prints a superfluous semicolon if the terminating yield op
528 // has zero results.
529 if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/true)))
530 return failure();
531 }
532
533 os.unindent() << "}";
534 }
535
536 return success();
537 }
538
printOperation(CppEmitter & emitter,scf::YieldOp yieldOp)539 static LogicalResult printOperation(CppEmitter &emitter, scf::YieldOp yieldOp) {
540 raw_ostream &os = emitter.ostream();
541 Operation &parentOp = *yieldOp.getOperation()->getParentOp();
542
543 if (yieldOp.getNumOperands() != parentOp.getNumResults()) {
544 return yieldOp.emitError("number of operands does not to match the number "
545 "of the parent op's results");
546 }
547
548 if (failed(interleaveWithError(
549 llvm::zip(parentOp.getResults(), yieldOp.getOperands()),
550 [&](auto pair) -> LogicalResult {
551 auto result = std::get<0>(pair);
552 auto operand = std::get<1>(pair);
553 os << emitter.getOrCreateName(result) << " = ";
554
555 if (!emitter.hasValueInScope(operand))
556 return yieldOp.emitError("operand value not in scope");
557 os << emitter.getOrCreateName(operand);
558 return success();
559 },
560 [&]() { os << ";\n"; })))
561 return failure();
562
563 return success();
564 }
565
printOperation(CppEmitter & emitter,func::ReturnOp returnOp)566 static LogicalResult printOperation(CppEmitter &emitter,
567 func::ReturnOp returnOp) {
568 raw_ostream &os = emitter.ostream();
569 os << "return";
570 switch (returnOp.getNumOperands()) {
571 case 0:
572 return success();
573 case 1:
574 os << " " << emitter.getOrCreateName(returnOp.getOperand(0));
575 return success(emitter.hasValueInScope(returnOp.getOperand(0)));
576 default:
577 os << " std::make_tuple(";
578 if (failed(emitter.emitOperandsAndAttributes(*returnOp.getOperation())))
579 return failure();
580 os << ")";
581 return success();
582 }
583 }
584
printOperation(CppEmitter & emitter,ModuleOp moduleOp)585 static LogicalResult printOperation(CppEmitter &emitter, ModuleOp moduleOp) {
586 CppEmitter::Scope scope(emitter);
587
588 for (Operation &op : moduleOp) {
589 if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/false)))
590 return failure();
591 }
592 return success();
593 }
594
/// Prints a func.func as a C++ function definition
/// `<result types> <name>(<arg type> <arg name>, ...) { ... }`.
///
/// The emission order inside the body is:
/// 1. with declare-at-top enabled, variables for the results of every op in
///    the function (including nested regions);
/// 2. variables for the arguments of every non-entry block;
/// 3. the blocks themselves, each preceded by its label if it has
///    predecessors.
/// Functions with more than one block require declare-at-top mode.
static LogicalResult printOperation(CppEmitter &emitter,
                                    func::FuncOp functionOp) {
  // We need to declare variables at top if the function has multiple blocks.
  if (!emitter.shouldDeclareVariablesAtTop() &&
      functionOp.getBlocks().size() > 1) {
    return functionOp.emitOpError(
        "with multiple blocks needs variables declared at top");
  }

  // Names created for this function's values and blocks end with the scope.
  CppEmitter::Scope scope(emitter);
  raw_indented_ostream &os = emitter.ostream();
  // Signature: result type(s), name, then the typed argument list.
  if (failed(emitter.emitTypes(functionOp.getLoc(),
                               functionOp.getFunctionType().getResults())))
    return failure();
  os << " " << functionOp.getName();

  os << "(";
  if (failed(interleaveCommaWithError(
          functionOp.getArguments(), os,
          [&](BlockArgument arg) -> LogicalResult {
            if (failed(emitter.emitType(functionOp.getLoc(), arg.getType())))
              return failure();
            os << " " << emitter.getOrCreateName(arg);
            return success();
          })))
    return failure();
  os << ") {\n";
  os.indent();
  if (emitter.shouldDeclareVariablesAtTop()) {
    // Declare all variables that hold op results including those from nested
    // regions.
    WalkResult result =
        functionOp.walk<WalkOrder::PreOrder>([&](Operation *op) -> WalkResult {
          for (OpResult result : op->getResults()) {
            if (failed(emitter.emitVariableDeclaration(
                    result, /*trailingSemicolon=*/true))) {
              // The error diagnostic converts to a failed result, which
              // interrupts the walk.
              return WalkResult(
                  op->emitError("unable to declare result variable for op"));
            }
          }
          return WalkResult::advance();
        });
    if (result.wasInterrupted())
      return failure();
  }

  Region::BlockListType &blocks = functionOp.getBlocks();
  // Create label names for basic blocks. This happens up front so that
  // branches to blocks printed later can already refer to their labels.
  for (Block &block : blocks) {
    emitter.getOrCreateName(block);
  }

  // Declare variables for basic block arguments. The entry block's arguments
  // are the function parameters and were named in the signature above.
  for (Block &block : llvm::drop_begin(blocks)) {
    for (BlockArgument &arg : block.getArguments()) {
      // NOTE(review): this fires when the argument already HAS a name, yet
      // the message says "is out of scope" - confirm the intended wording.
      if (emitter.hasValueInScope(arg))
        return functionOp.emitOpError(" block argument #")
               << arg.getArgNumber() << " is out of scope";
      if (failed(
              emitter.emitType(block.getParentOp()->getLoc(), arg.getType()))) {
        return failure();
      }
      os << " " << emitter.getOrCreateName(arg) << ";\n";
    }
  }

  for (Block &block : blocks) {
    // Only print a label if the block has predecessors.
    if (!block.hasNoPredecessors()) {
      if (failed(emitter.emitLabel(block)))
        return failure();
    }
    for (Operation &op : block.getOperations()) {
      // When generating code for an scf.if or cf.cond_br op no semicolon needs
      // to be printed after the closing brace.
      // When generating code for an scf.for op, printing a trailing semicolon
      // is handled within the printOperation function.
      bool trailingSemicolon =
          !isa<scf::IfOp, scf::ForOp, cf::CondBranchOp>(op);

      if (failed(emitter.emitOperation(
              op, /*trailingSemicolon=*/trailingSemicolon)))
        return failure();
    }
  }
  os.unindent() << "}\n";
  return success();
}
683
CppEmitter(raw_ostream & os,bool declareVariablesAtTop)684 CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop)
685 : os(os), declareVariablesAtTop(declareVariablesAtTop) {
686 valueInScopeCount.push(0);
687 labelInScopeCount.push(0);
688 }
689
690 /// Return the existing or a new name for a Value.
getOrCreateName(Value val)691 StringRef CppEmitter::getOrCreateName(Value val) {
692 if (!valueMapper.count(val))
693 valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top()));
694 return *valueMapper.begin(val);
695 }
696
697 /// Return the existing or a new label for a Block.
getOrCreateName(Block & block)698 StringRef CppEmitter::getOrCreateName(Block &block) {
699 if (!blockMapper.count(&block))
700 blockMapper.insert(&block, formatv("label{0}", ++labelInScopeCount.top()));
701 return *blockMapper.begin(&block);
702 }
703
shouldMapToUnsigned(IntegerType::SignednessSemantics val)704 bool CppEmitter::shouldMapToUnsigned(IntegerType::SignednessSemantics val) {
705 switch (val) {
706 case IntegerType::Signless:
707 return false;
708 case IntegerType::Signed:
709 return false;
710 case IntegerType::Unsigned:
711 return true;
712 }
713 llvm_unreachable("Unexpected IntegerType::SignednessSemantics");
714 }
715
hasValueInScope(Value val)716 bool CppEmitter::hasValueInScope(Value val) { return valueMapper.count(val); }
717
hasBlockLabel(Block & block)718 bool CppEmitter::hasBlockLabel(Block &block) {
719 return blockMapper.count(&block);
720 }
721
emitAttribute(Location loc,Attribute attr)722 LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
723 auto printInt = [&](const APInt &val, bool isUnsigned) {
724 if (val.getBitWidth() == 1) {
725 if (val.getBoolValue())
726 os << "true";
727 else
728 os << "false";
729 } else {
730 SmallString<128> strValue;
731 val.toString(strValue, 10, !isUnsigned, false);
732 os << strValue;
733 }
734 };
735
736 auto printFloat = [&](const APFloat &val) {
737 if (val.isFinite()) {
738 SmallString<128> strValue;
739 // Use default values of toString except don't truncate zeros.
740 val.toString(strValue, 0, 0, false);
741 switch (llvm::APFloatBase::SemanticsToEnum(val.getSemantics())) {
742 case llvm::APFloatBase::S_IEEEsingle:
743 os << "(float)";
744 break;
745 case llvm::APFloatBase::S_IEEEdouble:
746 os << "(double)";
747 break;
748 default:
749 break;
750 };
751 os << strValue;
752 } else if (val.isNaN()) {
753 os << "NAN";
754 } else if (val.isInfinity()) {
755 if (val.isNegative())
756 os << "-";
757 os << "INFINITY";
758 }
759 };
760
761 // Print floating point attributes.
762 if (auto fAttr = attr.dyn_cast<FloatAttr>()) {
763 printFloat(fAttr.getValue());
764 return success();
765 }
766 if (auto dense = attr.dyn_cast<DenseFPElementsAttr>()) {
767 os << '{';
768 interleaveComma(dense, os, [&](const APFloat &val) { printFloat(val); });
769 os << '}';
770 return success();
771 }
772
773 // Print integer attributes.
774 if (auto iAttr = attr.dyn_cast<IntegerAttr>()) {
775 if (auto iType = iAttr.getType().dyn_cast<IntegerType>()) {
776 printInt(iAttr.getValue(), shouldMapToUnsigned(iType.getSignedness()));
777 return success();
778 }
779 if (auto iType = iAttr.getType().dyn_cast<IndexType>()) {
780 printInt(iAttr.getValue(), false);
781 return success();
782 }
783 }
784 if (auto dense = attr.dyn_cast<DenseIntElementsAttr>()) {
785 if (auto iType = dense.getType()
786 .cast<TensorType>()
787 .getElementType()
788 .dyn_cast<IntegerType>()) {
789 os << '{';
790 interleaveComma(dense, os, [&](const APInt &val) {
791 printInt(val, shouldMapToUnsigned(iType.getSignedness()));
792 });
793 os << '}';
794 return success();
795 }
796 if (auto iType = dense.getType()
797 .cast<TensorType>()
798 .getElementType()
799 .dyn_cast<IndexType>()) {
800 os << '{';
801 interleaveComma(dense, os,
802 [&](const APInt &val) { printInt(val, false); });
803 os << '}';
804 return success();
805 }
806 }
807
808 // Print opaque attributes.
809 if (auto oAttr = attr.dyn_cast<emitc::OpaqueAttr>()) {
810 os << oAttr.getValue();
811 return success();
812 }
813
814 // Print symbolic reference attributes.
815 if (auto sAttr = attr.dyn_cast<SymbolRefAttr>()) {
816 if (sAttr.getNestedReferences().size() > 1)
817 return emitError(loc, "attribute has more than 1 nested reference");
818 os << sAttr.getRootReference().getValue();
819 return success();
820 }
821
822 // Print type attributes.
823 if (auto type = attr.dyn_cast<TypeAttr>())
824 return emitType(loc, type.getValue());
825
826 return emitError(loc, "cannot emit attribute of type ") << attr.getType();
827 }
828
emitOperands(Operation & op)829 LogicalResult CppEmitter::emitOperands(Operation &op) {
830 auto emitOperandName = [&](Value result) -> LogicalResult {
831 if (!hasValueInScope(result))
832 return op.emitOpError() << "operand value not in scope";
833 os << getOrCreateName(result);
834 return success();
835 };
836 return interleaveCommaWithError(op.getOperands(), os, emitOperandName);
837 }
838
839 LogicalResult
emitOperandsAndAttributes(Operation & op,ArrayRef<StringRef> exclude)840 CppEmitter::emitOperandsAndAttributes(Operation &op,
841 ArrayRef<StringRef> exclude) {
842 if (failed(emitOperands(op)))
843 return failure();
844 // Insert comma in between operands and non-filtered attributes if needed.
845 if (op.getNumOperands() > 0) {
846 for (NamedAttribute attr : op.getAttrs()) {
847 if (!llvm::is_contained(exclude, attr.getName().strref())) {
848 os << ", ";
849 break;
850 }
851 }
852 }
853 // Emit attributes.
854 auto emitNamedAttribute = [&](NamedAttribute attr) -> LogicalResult {
855 if (llvm::is_contained(exclude, attr.getName().strref()))
856 return success();
857 os << "/* " << attr.getName().getValue() << " */";
858 if (failed(emitAttribute(op.getLoc(), attr.getValue())))
859 return failure();
860 return success();
861 };
862 return interleaveCommaWithError(op.getAttrs(), os, emitNamedAttribute);
863 }
864
emitVariableAssignment(OpResult result)865 LogicalResult CppEmitter::emitVariableAssignment(OpResult result) {
866 if (!hasValueInScope(result)) {
867 return result.getDefiningOp()->emitOpError(
868 "result variable for the operation has not been declared");
869 }
870 os << getOrCreateName(result) << " = ";
871 return success();
872 }
873
emitVariableDeclaration(OpResult result,bool trailingSemicolon)874 LogicalResult CppEmitter::emitVariableDeclaration(OpResult result,
875 bool trailingSemicolon) {
876 if (hasValueInScope(result)) {
877 return result.getDefiningOp()->emitError(
878 "result variable for the operation already declared");
879 }
880 if (failed(emitType(result.getOwner()->getLoc(), result.getType())))
881 return failure();
882 os << " " << getOrCreateName(result);
883 if (trailingSemicolon)
884 os << ";\n";
885 return success();
886 }
887
emitAssignPrefix(Operation & op)888 LogicalResult CppEmitter::emitAssignPrefix(Operation &op) {
889 switch (op.getNumResults()) {
890 case 0:
891 break;
892 case 1: {
893 OpResult result = op.getResult(0);
894 if (shouldDeclareVariablesAtTop()) {
895 if (failed(emitVariableAssignment(result)))
896 return failure();
897 } else {
898 if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/false)))
899 return failure();
900 os << " = ";
901 }
902 break;
903 }
904 default:
905 if (!shouldDeclareVariablesAtTop()) {
906 for (OpResult result : op.getResults()) {
907 if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/true)))
908 return failure();
909 }
910 }
911 os << "std::tie(";
912 interleaveComma(op.getResults(), os,
913 [&](Value result) { os << getOrCreateName(result); });
914 os << ") = ";
915 }
916 return success();
917 }
918
emitLabel(Block & block)919 LogicalResult CppEmitter::emitLabel(Block &block) {
920 if (!hasBlockLabel(block))
921 return block.getParentOp()->emitError("label for block not found");
922 // FIXME: Add feature in `raw_indented_ostream` to ignore indent for block
923 // label instead of using `getOStream`.
924 os.getOStream() << getOrCreateName(block) << ":\n";
925 return success();
926 }
927
emitOperation(Operation & op,bool trailingSemicolon)928 LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
929 LogicalResult status =
930 llvm::TypeSwitch<Operation *, LogicalResult>(&op)
931 // Builtin ops.
932 .Case<ModuleOp>([&](auto op) { return printOperation(*this, op); })
933 // CF ops.
934 .Case<cf::BranchOp, cf::CondBranchOp>(
935 [&](auto op) { return printOperation(*this, op); })
936 // EmitC ops.
937 .Case<emitc::ApplyOp, emitc::CallOp, emitc::CastOp, emitc::ConstantOp,
938 emitc::IncludeOp, emitc::VariableOp>(
939 [&](auto op) { return printOperation(*this, op); })
940 // Func ops.
941 .Case<func::CallOp, func::ConstantOp, func::FuncOp, func::ReturnOp>(
942 [&](auto op) { return printOperation(*this, op); })
943 // SCF ops.
944 .Case<scf::ForOp, scf::IfOp, scf::YieldOp>(
945 [&](auto op) { return printOperation(*this, op); })
946 // Arithmetic ops.
947 .Case<arith::ConstantOp>(
948 [&](auto op) { return printOperation(*this, op); })
949 .Default([&](Operation *) {
950 return op.emitOpError("unable to find printer for op");
951 });
952
953 if (failed(status))
954 return failure();
955 os << (trailingSemicolon ? ";\n" : "\n");
956 return success();
957 }
958
emitType(Location loc,Type type)959 LogicalResult CppEmitter::emitType(Location loc, Type type) {
960 if (auto iType = type.dyn_cast<IntegerType>()) {
961 switch (iType.getWidth()) {
962 case 1:
963 return (os << "bool"), success();
964 case 8:
965 case 16:
966 case 32:
967 case 64:
968 if (shouldMapToUnsigned(iType.getSignedness()))
969 return (os << "uint" << iType.getWidth() << "_t"), success();
970 else
971 return (os << "int" << iType.getWidth() << "_t"), success();
972 default:
973 return emitError(loc, "cannot emit integer type ") << type;
974 }
975 }
976 if (auto fType = type.dyn_cast<FloatType>()) {
977 switch (fType.getWidth()) {
978 case 32:
979 return (os << "float"), success();
980 case 64:
981 return (os << "double"), success();
982 default:
983 return emitError(loc, "cannot emit float type ") << type;
984 }
985 }
986 if (auto iType = type.dyn_cast<IndexType>())
987 return (os << "size_t"), success();
988 if (auto tType = type.dyn_cast<TensorType>()) {
989 if (!tType.hasRank())
990 return emitError(loc, "cannot emit unranked tensor type");
991 if (!tType.hasStaticShape())
992 return emitError(loc, "cannot emit tensor type with non static shape");
993 os << "Tensor<";
994 if (failed(emitType(loc, tType.getElementType())))
995 return failure();
996 auto shape = tType.getShape();
997 for (auto dimSize : shape) {
998 os << ", ";
999 os << dimSize;
1000 }
1001 os << ">";
1002 return success();
1003 }
1004 if (auto tType = type.dyn_cast<TupleType>())
1005 return emitTupleType(loc, tType.getTypes());
1006 if (auto oType = type.dyn_cast<emitc::OpaqueType>()) {
1007 os << oType.getValue();
1008 return success();
1009 }
1010 if (auto pType = type.dyn_cast<emitc::PointerType>()) {
1011 if (failed(emitType(loc, pType.getPointee())))
1012 return failure();
1013 os << "*";
1014 return success();
1015 }
1016 return emitError(loc, "cannot emit type ") << type;
1017 }
1018
emitTypes(Location loc,ArrayRef<Type> types)1019 LogicalResult CppEmitter::emitTypes(Location loc, ArrayRef<Type> types) {
1020 switch (types.size()) {
1021 case 0:
1022 os << "void";
1023 return success();
1024 case 1:
1025 return emitType(loc, types.front());
1026 default:
1027 return emitTupleType(loc, types);
1028 }
1029 }
1030
emitTupleType(Location loc,ArrayRef<Type> types)1031 LogicalResult CppEmitter::emitTupleType(Location loc, ArrayRef<Type> types) {
1032 os << "std::tuple<";
1033 if (failed(interleaveCommaWithError(
1034 types, os, [&](Type type) { return emitType(loc, type); })))
1035 return failure();
1036 os << ">";
1037 return success();
1038 }
1039
translateToCpp(Operation * op,raw_ostream & os,bool declareVariablesAtTop)1040 LogicalResult emitc::translateToCpp(Operation *op, raw_ostream &os,
1041 bool declareVariablesAtTop) {
1042 CppEmitter emitter(os, declareVariablesAtTop);
1043 return emitter.emitOperation(*op, /*trailingSemicolon=*/false);
1044 }
1045