//===- TranslateToCpp.cpp - Translating to C++ calls ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/IndentedOstream.h"
#include "mlir/Target/Cpp/CppEmitter.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/ScopedHashTable.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#include <stack>
#include <utility>

#define DEBUG_TYPE "translate-to-cpp"

using namespace mlir;
using namespace mlir::emitc;
using llvm::formatv;

/// Convenience functions to produce interleaved output with functions returning
/// a LogicalResult. This is different than those in STLExtras as functions used
/// on each element doesn't return a string.
36 template <typename ForwardIterator, typename UnaryFunctor, 37 typename NullaryFunctor> 38 inline LogicalResult 39 interleaveWithError(ForwardIterator begin, ForwardIterator end, 40 UnaryFunctor eachFn, NullaryFunctor betweenFn) { 41 if (begin == end) 42 return success(); 43 if (failed(eachFn(*begin))) 44 return failure(); 45 ++begin; 46 for (; begin != end; ++begin) { 47 betweenFn(); 48 if (failed(eachFn(*begin))) 49 return failure(); 50 } 51 return success(); 52 } 53 54 template <typename Container, typename UnaryFunctor, typename NullaryFunctor> 55 inline LogicalResult interleaveWithError(const Container &c, 56 UnaryFunctor eachFn, 57 NullaryFunctor betweenFn) { 58 return interleaveWithError(c.begin(), c.end(), eachFn, betweenFn); 59 } 60 61 template <typename Container, typename UnaryFunctor> 62 inline LogicalResult interleaveCommaWithError(const Container &c, 63 raw_ostream &os, 64 UnaryFunctor eachFn) { 65 return interleaveWithError(c.begin(), c.end(), eachFn, [&]() { os << ", "; }); 66 } 67 68 namespace { 69 /// Emitter that uses dialect specific emitters to emit C++ code. 70 struct CppEmitter { 71 explicit CppEmitter(raw_ostream &os, bool declareVariablesAtTop); 72 73 /// Emits attribute or returns failure. 74 LogicalResult emitAttribute(Location loc, Attribute attr); 75 76 /// Emits operation 'op' with/without training semicolon or returns failure. 77 LogicalResult emitOperation(Operation &op, bool trailingSemicolon); 78 79 /// Emits type 'type' or returns failure. 80 LogicalResult emitType(Location loc, Type type); 81 82 /// Emits array of types as a std::tuple of the emitted types. 83 /// - emits void for an empty array; 84 /// - emits the type of the only element for arrays of size one; 85 /// - emits a std::tuple otherwise; 86 LogicalResult emitTypes(Location loc, ArrayRef<Type> types); 87 88 /// Emits array of types as a std::tuple of the emitted types independently of 89 /// the array size. 
90 LogicalResult emitTupleType(Location loc, ArrayRef<Type> types); 91 92 /// Emits an assignment for a variable which has been declared previously. 93 LogicalResult emitVariableAssignment(OpResult result); 94 95 /// Emits a variable declaration for a result of an operation. 96 LogicalResult emitVariableDeclaration(OpResult result, 97 bool trailingSemicolon); 98 99 /// Emits the variable declaration and assignment prefix for 'op'. 100 /// - emits separate variable followed by std::tie for multi-valued operation; 101 /// - emits single type followed by variable for single result; 102 /// - emits nothing if no value produced by op; 103 /// Emits final '=' operator where a type is produced. Returns failure if 104 /// any result type could not be converted. 105 LogicalResult emitAssignPrefix(Operation &op); 106 107 /// Emits a label for the block. 108 LogicalResult emitLabel(Block &block); 109 110 /// Emits the operands and atttributes of the operation. All operands are 111 /// emitted first and then all attributes in alphabetical order. 112 LogicalResult emitOperandsAndAttributes(Operation &op, 113 ArrayRef<StringRef> exclude = {}); 114 115 /// Emits the operands of the operation. All operands are emitted in order. 116 LogicalResult emitOperands(Operation &op); 117 118 /// Return the existing or a new name for a Value. 119 StringRef getOrCreateName(Value val); 120 121 /// Return the existing or a new label of a Block. 122 StringRef getOrCreateName(Block &block); 123 124 /// Whether to map an mlir integer to a unsigned integer in C++. 125 bool shouldMapToUnsigned(IntegerType::SignednessSemantics val); 126 127 /// RAII helper function to manage entering/exiting C++ scopes. 
128 struct Scope { 129 Scope(CppEmitter &emitter) 130 : valueMapperScope(emitter.valueMapper), 131 blockMapperScope(emitter.blockMapper), emitter(emitter) { 132 emitter.valueInScopeCount.push(emitter.valueInScopeCount.top()); 133 emitter.labelInScopeCount.push(emitter.labelInScopeCount.top()); 134 } 135 ~Scope() { 136 emitter.valueInScopeCount.pop(); 137 emitter.labelInScopeCount.pop(); 138 } 139 140 private: 141 llvm::ScopedHashTableScope<Value, std::string> valueMapperScope; 142 llvm::ScopedHashTableScope<Block *, std::string> blockMapperScope; 143 CppEmitter &emitter; 144 }; 145 146 /// Returns wether the Value is assigned to a C++ variable in the scope. 147 bool hasValueInScope(Value val); 148 149 // Returns whether a label is assigned to the block. 150 bool hasBlockLabel(Block &block); 151 152 /// Returns the output stream. 153 raw_indented_ostream &ostream() { return os; }; 154 155 /// Returns if all variables for op results and basic block arguments need to 156 /// be declared at the beginning of a function. 157 bool shouldDeclareVariablesAtTop() { return declareVariablesAtTop; }; 158 159 private: 160 using ValueMapper = llvm::ScopedHashTable<Value, std::string>; 161 using BlockMapper = llvm::ScopedHashTable<Block *, std::string>; 162 163 /// Output stream to emit to. 164 raw_indented_ostream os; 165 166 /// Boolean to enforce that all variables for op results and block 167 /// arguments are declared at the beginning of the function. This also 168 /// includes results from ops located in nested regions. 169 bool declareVariablesAtTop; 170 171 /// Map from value to name of C++ variable that contain the name. 172 ValueMapper valueMapper; 173 174 /// Map from block to name of C++ label. 175 BlockMapper blockMapper; 176 177 /// The number of values in the current scope. This is used to declare the 178 /// names of values in a scope. 
179 std::stack<int64_t> valueInScopeCount; 180 std::stack<int64_t> labelInScopeCount; 181 }; 182 } // namespace 183 184 static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation, 185 Attribute value) { 186 OpResult result = operation->getResult(0); 187 188 // Only emit an assignment as the variable was already declared when printing 189 // the FuncOp. 190 if (emitter.shouldDeclareVariablesAtTop()) { 191 // Skip the assignment if the emitc.constant has no value. 192 if (auto oAttr = value.dyn_cast<emitc::OpaqueAttr>()) { 193 if (oAttr.getValue().empty()) 194 return success(); 195 } 196 197 if (failed(emitter.emitVariableAssignment(result))) 198 return failure(); 199 return emitter.emitAttribute(operation->getLoc(), value); 200 } 201 202 // Emit a variable declaration for an emitc.constant op without value. 203 if (auto oAttr = value.dyn_cast<emitc::OpaqueAttr>()) { 204 if (oAttr.getValue().empty()) 205 // The semicolon gets printed by the emitOperation function. 206 return emitter.emitVariableDeclaration(result, 207 /*trailingSemicolon=*/false); 208 } 209 210 // Emit a variable declaration. 
211 if (failed(emitter.emitAssignPrefix(*operation))) 212 return failure(); 213 return emitter.emitAttribute(operation->getLoc(), value); 214 } 215 216 static LogicalResult printOperation(CppEmitter &emitter, 217 emitc::ConstantOp constantOp) { 218 Operation *operation = constantOp.getOperation(); 219 Attribute value = constantOp.value(); 220 221 return printConstantOp(emitter, operation, value); 222 } 223 224 static LogicalResult printOperation(CppEmitter &emitter, 225 emitc::VariableOp variableOp) { 226 Operation *operation = variableOp.getOperation(); 227 Attribute value = variableOp.value(); 228 229 return printConstantOp(emitter, operation, value); 230 } 231 232 static LogicalResult printOperation(CppEmitter &emitter, 233 arith::ConstantOp constantOp) { 234 Operation *operation = constantOp.getOperation(); 235 Attribute value = constantOp.getValue(); 236 237 return printConstantOp(emitter, operation, value); 238 } 239 240 static LogicalResult printOperation(CppEmitter &emitter, 241 func::ConstantOp constantOp) { 242 Operation *operation = constantOp.getOperation(); 243 Attribute value = constantOp.getValueAttr(); 244 245 return printConstantOp(emitter, operation, value); 246 } 247 248 static LogicalResult printOperation(CppEmitter &emitter, 249 cf::BranchOp branchOp) { 250 raw_ostream &os = emitter.ostream(); 251 Block &successor = *branchOp.getSuccessor(); 252 253 for (auto pair : 254 llvm::zip(branchOp.getOperands(), successor.getArguments())) { 255 Value &operand = std::get<0>(pair); 256 BlockArgument &argument = std::get<1>(pair); 257 os << emitter.getOrCreateName(argument) << " = " 258 << emitter.getOrCreateName(operand) << ";\n"; 259 } 260 261 os << "goto "; 262 if (!(emitter.hasBlockLabel(successor))) 263 return branchOp.emitOpError("unable to find label for successor block"); 264 os << emitter.getOrCreateName(successor); 265 return success(); 266 } 267 268 static LogicalResult printOperation(CppEmitter &emitter, 269 cf::CondBranchOp condBranchOp) { 270 
raw_indented_ostream &os = emitter.ostream(); 271 Block &trueSuccessor = *condBranchOp.getTrueDest(); 272 Block &falseSuccessor = *condBranchOp.getFalseDest(); 273 274 os << "if (" << emitter.getOrCreateName(condBranchOp.getCondition()) 275 << ") {\n"; 276 277 os.indent(); 278 279 // If condition is true. 280 for (auto pair : llvm::zip(condBranchOp.getTrueOperands(), 281 trueSuccessor.getArguments())) { 282 Value &operand = std::get<0>(pair); 283 BlockArgument &argument = std::get<1>(pair); 284 os << emitter.getOrCreateName(argument) << " = " 285 << emitter.getOrCreateName(operand) << ";\n"; 286 } 287 288 os << "goto "; 289 if (!(emitter.hasBlockLabel(trueSuccessor))) { 290 return condBranchOp.emitOpError("unable to find label for successor block"); 291 } 292 os << emitter.getOrCreateName(trueSuccessor) << ";\n"; 293 os.unindent() << "} else {\n"; 294 os.indent(); 295 // If condition is false. 296 for (auto pair : llvm::zip(condBranchOp.getFalseOperands(), 297 falseSuccessor.getArguments())) { 298 Value &operand = std::get<0>(pair); 299 BlockArgument &argument = std::get<1>(pair); 300 os << emitter.getOrCreateName(argument) << " = " 301 << emitter.getOrCreateName(operand) << ";\n"; 302 } 303 304 os << "goto "; 305 if (!(emitter.hasBlockLabel(falseSuccessor))) { 306 return condBranchOp.emitOpError() 307 << "unable to find label for successor block"; 308 } 309 os << emitter.getOrCreateName(falseSuccessor) << ";\n"; 310 os.unindent() << "}"; 311 return success(); 312 } 313 314 static LogicalResult printOperation(CppEmitter &emitter, func::CallOp callOp) { 315 if (failed(emitter.emitAssignPrefix(*callOp.getOperation()))) 316 return failure(); 317 318 raw_ostream &os = emitter.ostream(); 319 os << callOp.getCallee() << "("; 320 if (failed(emitter.emitOperands(*callOp.getOperation()))) 321 return failure(); 322 os << ")"; 323 return success(); 324 } 325 326 static LogicalResult printOperation(CppEmitter &emitter, emitc::CallOp callOp) { 327 raw_ostream &os = 
emitter.ostream(); 328 Operation &op = *callOp.getOperation(); 329 330 if (failed(emitter.emitAssignPrefix(op))) 331 return failure(); 332 os << callOp.callee(); 333 334 auto emitArgs = [&](Attribute attr) -> LogicalResult { 335 if (auto t = attr.dyn_cast<IntegerAttr>()) { 336 // Index attributes are treated specially as operand index. 337 if (t.getType().isIndex()) { 338 int64_t idx = t.getInt(); 339 if ((idx < 0) || (idx >= op.getNumOperands())) 340 return op.emitOpError("invalid operand index"); 341 if (!emitter.hasValueInScope(op.getOperand(idx))) 342 return op.emitOpError("operand ") 343 << idx << "'s value not defined in scope"; 344 os << emitter.getOrCreateName(op.getOperand(idx)); 345 return success(); 346 } 347 } 348 if (failed(emitter.emitAttribute(op.getLoc(), attr))) 349 return failure(); 350 351 return success(); 352 }; 353 354 if (callOp.template_args()) { 355 os << "<"; 356 if (failed(interleaveCommaWithError(*callOp.template_args(), os, emitArgs))) 357 return failure(); 358 os << ">"; 359 } 360 361 os << "("; 362 363 LogicalResult emittedArgs = 364 callOp.args() ? 
interleaveCommaWithError(*callOp.args(), os, emitArgs) 365 : emitter.emitOperands(op); 366 if (failed(emittedArgs)) 367 return failure(); 368 os << ")"; 369 return success(); 370 } 371 372 static LogicalResult printOperation(CppEmitter &emitter, 373 emitc::ApplyOp applyOp) { 374 raw_ostream &os = emitter.ostream(); 375 Operation &op = *applyOp.getOperation(); 376 377 if (failed(emitter.emitAssignPrefix(op))) 378 return failure(); 379 os << applyOp.applicableOperator(); 380 os << emitter.getOrCreateName(applyOp.getOperand()); 381 382 return success(); 383 } 384 385 static LogicalResult printOperation(CppEmitter &emitter, 386 emitc::IncludeOp includeOp) { 387 raw_ostream &os = emitter.ostream(); 388 389 os << "#include "; 390 if (includeOp.is_standard_include()) 391 os << "<" << includeOp.include() << ">"; 392 else 393 os << "\"" << includeOp.include() << "\""; 394 395 return success(); 396 } 397 398 static LogicalResult printOperation(CppEmitter &emitter, scf::ForOp forOp) { 399 400 raw_indented_ostream &os = emitter.ostream(); 401 402 OperandRange operands = forOp.getIterOperands(); 403 Block::BlockArgListType iterArgs = forOp.getRegionIterArgs(); 404 Operation::result_range results = forOp.getResults(); 405 406 if (!emitter.shouldDeclareVariablesAtTop()) { 407 for (OpResult result : results) { 408 if (failed(emitter.emitVariableDeclaration(result, 409 /*trailingSemicolon=*/true))) 410 return failure(); 411 } 412 } 413 414 for (auto pair : llvm::zip(iterArgs, operands)) { 415 if (failed(emitter.emitType(forOp.getLoc(), std::get<0>(pair).getType()))) 416 return failure(); 417 os << " " << emitter.getOrCreateName(std::get<0>(pair)) << " = "; 418 os << emitter.getOrCreateName(std::get<1>(pair)) << ";"; 419 os << "\n"; 420 } 421 422 os << "for ("; 423 if (failed( 424 emitter.emitType(forOp.getLoc(), forOp.getInductionVar().getType()))) 425 return failure(); 426 os << " "; 427 os << emitter.getOrCreateName(forOp.getInductionVar()); 428 os << " = "; 429 os << 
emitter.getOrCreateName(forOp.getLowerBound()); 430 os << "; "; 431 os << emitter.getOrCreateName(forOp.getInductionVar()); 432 os << " < "; 433 os << emitter.getOrCreateName(forOp.getUpperBound()); 434 os << "; "; 435 os << emitter.getOrCreateName(forOp.getInductionVar()); 436 os << " += "; 437 os << emitter.getOrCreateName(forOp.getStep()); 438 os << ") {\n"; 439 os.indent(); 440 441 Region &forRegion = forOp.getRegion(); 442 auto regionOps = forRegion.getOps(); 443 444 // We skip the trailing yield op because this updates the result variables 445 // of the for op in the generated code. Instead we update the iterArgs at 446 // the end of a loop iteration and set the result variables after the for 447 // loop. 448 for (auto it = regionOps.begin(); std::next(it) != regionOps.end(); ++it) { 449 if (failed(emitter.emitOperation(*it, /*trailingSemicolon=*/true))) 450 return failure(); 451 } 452 453 Operation *yieldOp = forRegion.getBlocks().front().getTerminator(); 454 // Copy yield operands into iterArgs at the end of a loop iteration. 455 for (auto pair : llvm::zip(iterArgs, yieldOp->getOperands())) { 456 BlockArgument iterArg = std::get<0>(pair); 457 Value operand = std::get<1>(pair); 458 os << emitter.getOrCreateName(iterArg) << " = " 459 << emitter.getOrCreateName(operand) << ";\n"; 460 } 461 462 os.unindent() << "}"; 463 464 // Copy iterArgs into results after the for loop. 
465 for (auto pair : llvm::zip(results, iterArgs)) { 466 OpResult result = std::get<0>(pair); 467 BlockArgument iterArg = std::get<1>(pair); 468 os << "\n" 469 << emitter.getOrCreateName(result) << " = " 470 << emitter.getOrCreateName(iterArg) << ";"; 471 } 472 473 return success(); 474 } 475 476 static LogicalResult printOperation(CppEmitter &emitter, scf::IfOp ifOp) { 477 raw_indented_ostream &os = emitter.ostream(); 478 479 if (!emitter.shouldDeclareVariablesAtTop()) { 480 for (OpResult result : ifOp.getResults()) { 481 if (failed(emitter.emitVariableDeclaration(result, 482 /*trailingSemicolon=*/true))) 483 return failure(); 484 } 485 } 486 487 os << "if ("; 488 if (failed(emitter.emitOperands(*ifOp.getOperation()))) 489 return failure(); 490 os << ") {\n"; 491 os.indent(); 492 493 Region &thenRegion = ifOp.getThenRegion(); 494 for (Operation &op : thenRegion.getOps()) { 495 // Note: This prints a superfluous semicolon if the terminating yield op has 496 // zero results. 497 if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/true))) 498 return failure(); 499 } 500 501 os.unindent() << "}"; 502 503 Region &elseRegion = ifOp.getElseRegion(); 504 if (!elseRegion.empty()) { 505 os << " else {\n"; 506 os.indent(); 507 508 for (Operation &op : elseRegion.getOps()) { 509 // Note: This prints a superfluous semicolon if the terminating yield op 510 // has zero results. 
511 if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/true))) 512 return failure(); 513 } 514 515 os.unindent() << "}"; 516 } 517 518 return success(); 519 } 520 521 static LogicalResult printOperation(CppEmitter &emitter, scf::YieldOp yieldOp) { 522 raw_ostream &os = emitter.ostream(); 523 Operation &parentOp = *yieldOp.getOperation()->getParentOp(); 524 525 if (yieldOp.getNumOperands() != parentOp.getNumResults()) { 526 return yieldOp.emitError("number of operands does not to match the number " 527 "of the parent op's results"); 528 } 529 530 if (failed(interleaveWithError( 531 llvm::zip(parentOp.getResults(), yieldOp.getOperands()), 532 [&](auto pair) -> LogicalResult { 533 auto result = std::get<0>(pair); 534 auto operand = std::get<1>(pair); 535 os << emitter.getOrCreateName(result) << " = "; 536 537 if (!emitter.hasValueInScope(operand)) 538 return yieldOp.emitError("operand value not in scope"); 539 os << emitter.getOrCreateName(operand); 540 return success(); 541 }, 542 [&]() { os << ";\n"; }))) 543 return failure(); 544 545 return success(); 546 } 547 548 static LogicalResult printOperation(CppEmitter &emitter, 549 func::ReturnOp returnOp) { 550 raw_ostream &os = emitter.ostream(); 551 os << "return"; 552 switch (returnOp.getNumOperands()) { 553 case 0: 554 return success(); 555 case 1: 556 os << " " << emitter.getOrCreateName(returnOp.getOperand(0)); 557 return success(emitter.hasValueInScope(returnOp.getOperand(0))); 558 default: 559 os << " std::make_tuple("; 560 if (failed(emitter.emitOperandsAndAttributes(*returnOp.getOperation()))) 561 return failure(); 562 os << ")"; 563 return success(); 564 } 565 } 566 567 static LogicalResult printOperation(CppEmitter &emitter, ModuleOp moduleOp) { 568 CppEmitter::Scope scope(emitter); 569 570 for (Operation &op : moduleOp) { 571 if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/false))) 572 return failure(); 573 } 574 return success(); 575 } 576 577 static LogicalResult 
printOperation(CppEmitter &emitter, 578 func::FuncOp functionOp) { 579 // We need to declare variables at top if the function has multiple blocks. 580 if (!emitter.shouldDeclareVariablesAtTop() && 581 functionOp.getBlocks().size() > 1) { 582 return functionOp.emitOpError( 583 "with multiple blocks needs variables declared at top"); 584 } 585 586 CppEmitter::Scope scope(emitter); 587 raw_indented_ostream &os = emitter.ostream(); 588 if (failed(emitter.emitTypes(functionOp.getLoc(), 589 functionOp.getFunctionType().getResults()))) 590 return failure(); 591 os << " " << functionOp.getName(); 592 593 os << "("; 594 if (failed(interleaveCommaWithError( 595 functionOp.getArguments(), os, 596 [&](BlockArgument arg) -> LogicalResult { 597 if (failed(emitter.emitType(functionOp.getLoc(), arg.getType()))) 598 return failure(); 599 os << " " << emitter.getOrCreateName(arg); 600 return success(); 601 }))) 602 return failure(); 603 os << ") {\n"; 604 os.indent(); 605 if (emitter.shouldDeclareVariablesAtTop()) { 606 // Declare all variables that hold op results including those from nested 607 // regions. 608 WalkResult result = 609 functionOp.walk<WalkOrder::PreOrder>([&](Operation *op) -> WalkResult { 610 for (OpResult result : op->getResults()) { 611 if (failed(emitter.emitVariableDeclaration( 612 result, /*trailingSemicolon=*/true))) { 613 return WalkResult( 614 op->emitError("unable to declare result variable for op")); 615 } 616 } 617 return WalkResult::advance(); 618 }); 619 if (result.wasInterrupted()) 620 return failure(); 621 } 622 623 Region::BlockListType &blocks = functionOp.getBlocks(); 624 // Create label names for basic blocks. 625 for (Block &block : blocks) { 626 emitter.getOrCreateName(block); 627 } 628 629 // Declare variables for basic block arguments. 
630 for (auto it = std::next(blocks.begin()); it != blocks.end(); ++it) { 631 Block &block = *it; 632 for (BlockArgument &arg : block.getArguments()) { 633 if (emitter.hasValueInScope(arg)) 634 return functionOp.emitOpError(" block argument #") 635 << arg.getArgNumber() << " is out of scope"; 636 if (failed( 637 emitter.emitType(block.getParentOp()->getLoc(), arg.getType()))) { 638 return failure(); 639 } 640 os << " " << emitter.getOrCreateName(arg) << ";\n"; 641 } 642 } 643 644 for (Block &block : blocks) { 645 // Only print a label if the block has predecessors. 646 if (!block.hasNoPredecessors()) { 647 if (failed(emitter.emitLabel(block))) 648 return failure(); 649 } 650 for (Operation &op : block.getOperations()) { 651 // When generating code for an scf.if or cf.cond_br op no semicolon needs 652 // to be printed after the closing brace. 653 // When generating code for an scf.for op, printing a trailing semicolon 654 // is handled within the printOperation function. 655 bool trailingSemicolon = 656 !isa<scf::IfOp, scf::ForOp, cf::CondBranchOp>(op); 657 658 if (failed(emitter.emitOperation( 659 op, /*trailingSemicolon=*/trailingSemicolon))) 660 return failure(); 661 } 662 } 663 os.unindent() << "}\n"; 664 return success(); 665 } 666 667 CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop) 668 : os(os), declareVariablesAtTop(declareVariablesAtTop) { 669 valueInScopeCount.push(0); 670 labelInScopeCount.push(0); 671 } 672 673 /// Return the existing or a new name for a Value. 674 StringRef CppEmitter::getOrCreateName(Value val) { 675 if (!valueMapper.count(val)) 676 valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top())); 677 return *valueMapper.begin(val); 678 } 679 680 /// Return the existing or a new label for a Block. 
681 StringRef CppEmitter::getOrCreateName(Block &block) { 682 if (!blockMapper.count(&block)) 683 blockMapper.insert(&block, formatv("label{0}", ++labelInScopeCount.top())); 684 return *blockMapper.begin(&block); 685 } 686 687 bool CppEmitter::shouldMapToUnsigned(IntegerType::SignednessSemantics val) { 688 switch (val) { 689 case IntegerType::Signless: 690 return false; 691 case IntegerType::Signed: 692 return false; 693 case IntegerType::Unsigned: 694 return true; 695 } 696 llvm_unreachable("Unexpected IntegerType::SignednessSemantics"); 697 } 698 699 bool CppEmitter::hasValueInScope(Value val) { return valueMapper.count(val); } 700 701 bool CppEmitter::hasBlockLabel(Block &block) { 702 return blockMapper.count(&block); 703 } 704 705 LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) { 706 auto printInt = [&](const APInt &val, bool isUnsigned) { 707 if (val.getBitWidth() == 1) { 708 if (val.getBoolValue()) 709 os << "true"; 710 else 711 os << "false"; 712 } else { 713 SmallString<128> strValue; 714 val.toString(strValue, 10, !isUnsigned, false); 715 os << strValue; 716 } 717 }; 718 719 auto printFloat = [&](const APFloat &val) { 720 if (val.isFinite()) { 721 SmallString<128> strValue; 722 // Use default values of toString except don't truncate zeros. 723 val.toString(strValue, 0, 0, false); 724 switch (llvm::APFloatBase::SemanticsToEnum(val.getSemantics())) { 725 case llvm::APFloatBase::S_IEEEsingle: 726 os << "(float)"; 727 break; 728 case llvm::APFloatBase::S_IEEEdouble: 729 os << "(double)"; 730 break; 731 default: 732 break; 733 }; 734 os << strValue; 735 } else if (val.isNaN()) { 736 os << "NAN"; 737 } else if (val.isInfinity()) { 738 if (val.isNegative()) 739 os << "-"; 740 os << "INFINITY"; 741 } 742 }; 743 744 // Print floating point attributes. 
745 if (auto fAttr = attr.dyn_cast<FloatAttr>()) { 746 printFloat(fAttr.getValue()); 747 return success(); 748 } 749 if (auto dense = attr.dyn_cast<DenseFPElementsAttr>()) { 750 os << '{'; 751 interleaveComma(dense, os, [&](const APFloat &val) { printFloat(val); }); 752 os << '}'; 753 return success(); 754 } 755 756 // Print integer attributes. 757 if (auto iAttr = attr.dyn_cast<IntegerAttr>()) { 758 if (auto iType = iAttr.getType().dyn_cast<IntegerType>()) { 759 printInt(iAttr.getValue(), shouldMapToUnsigned(iType.getSignedness())); 760 return success(); 761 } 762 if (auto iType = iAttr.getType().dyn_cast<IndexType>()) { 763 printInt(iAttr.getValue(), false); 764 return success(); 765 } 766 } 767 if (auto dense = attr.dyn_cast<DenseIntElementsAttr>()) { 768 if (auto iType = dense.getType() 769 .cast<TensorType>() 770 .getElementType() 771 .dyn_cast<IntegerType>()) { 772 os << '{'; 773 interleaveComma(dense, os, [&](const APInt &val) { 774 printInt(val, shouldMapToUnsigned(iType.getSignedness())); 775 }); 776 os << '}'; 777 return success(); 778 } 779 if (auto iType = dense.getType() 780 .cast<TensorType>() 781 .getElementType() 782 .dyn_cast<IndexType>()) { 783 os << '{'; 784 interleaveComma(dense, os, 785 [&](const APInt &val) { printInt(val, false); }); 786 os << '}'; 787 return success(); 788 } 789 } 790 791 // Print opaque attributes. 792 if (auto oAttr = attr.dyn_cast<emitc::OpaqueAttr>()) { 793 os << oAttr.getValue(); 794 return success(); 795 } 796 797 // Print symbolic reference attributes. 798 if (auto sAttr = attr.dyn_cast<SymbolRefAttr>()) { 799 if (sAttr.getNestedReferences().size() > 1) 800 return emitError(loc, "attribute has more than 1 nested reference"); 801 os << sAttr.getRootReference().getValue(); 802 return success(); 803 } 804 805 // Print type attributes. 
806 if (auto type = attr.dyn_cast<TypeAttr>()) 807 return emitType(loc, type.getValue()); 808 809 return emitError(loc, "cannot emit attribute of type ") << attr.getType(); 810 } 811 812 LogicalResult CppEmitter::emitOperands(Operation &op) { 813 auto emitOperandName = [&](Value result) -> LogicalResult { 814 if (!hasValueInScope(result)) 815 return op.emitOpError() << "operand value not in scope"; 816 os << getOrCreateName(result); 817 return success(); 818 }; 819 return interleaveCommaWithError(op.getOperands(), os, emitOperandName); 820 } 821 822 LogicalResult 823 CppEmitter::emitOperandsAndAttributes(Operation &op, 824 ArrayRef<StringRef> exclude) { 825 if (failed(emitOperands(op))) 826 return failure(); 827 // Insert comma in between operands and non-filtered attributes if needed. 828 if (op.getNumOperands() > 0) { 829 for (NamedAttribute attr : op.getAttrs()) { 830 if (!llvm::is_contained(exclude, attr.getName().strref())) { 831 os << ", "; 832 break; 833 } 834 } 835 } 836 // Emit attributes. 
837 auto emitNamedAttribute = [&](NamedAttribute attr) -> LogicalResult { 838 if (llvm::is_contained(exclude, attr.getName().strref())) 839 return success(); 840 os << "/* " << attr.getName().getValue() << " */"; 841 if (failed(emitAttribute(op.getLoc(), attr.getValue()))) 842 return failure(); 843 return success(); 844 }; 845 return interleaveCommaWithError(op.getAttrs(), os, emitNamedAttribute); 846 } 847 848 LogicalResult CppEmitter::emitVariableAssignment(OpResult result) { 849 if (!hasValueInScope(result)) { 850 return result.getDefiningOp()->emitOpError( 851 "result variable for the operation has not been declared"); 852 } 853 os << getOrCreateName(result) << " = "; 854 return success(); 855 } 856 857 LogicalResult CppEmitter::emitVariableDeclaration(OpResult result, 858 bool trailingSemicolon) { 859 if (hasValueInScope(result)) { 860 return result.getDefiningOp()->emitError( 861 "result variable for the operation already declared"); 862 } 863 if (failed(emitType(result.getOwner()->getLoc(), result.getType()))) 864 return failure(); 865 os << " " << getOrCreateName(result); 866 if (trailingSemicolon) 867 os << ";\n"; 868 return success(); 869 } 870 871 LogicalResult CppEmitter::emitAssignPrefix(Operation &op) { 872 switch (op.getNumResults()) { 873 case 0: 874 break; 875 case 1: { 876 OpResult result = op.getResult(0); 877 if (shouldDeclareVariablesAtTop()) { 878 if (failed(emitVariableAssignment(result))) 879 return failure(); 880 } else { 881 if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/false))) 882 return failure(); 883 os << " = "; 884 } 885 break; 886 } 887 default: 888 if (!shouldDeclareVariablesAtTop()) { 889 for (OpResult result : op.getResults()) { 890 if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/true))) 891 return failure(); 892 } 893 } 894 os << "std::tie("; 895 interleaveComma(op.getResults(), os, 896 [&](Value result) { os << getOrCreateName(result); }); 897 os << ") = "; 898 } 899 return success(); 900 } 
901 902 LogicalResult CppEmitter::emitLabel(Block &block) { 903 if (!hasBlockLabel(block)) 904 return block.getParentOp()->emitError("label for block not found"); 905 // FIXME: Add feature in `raw_indented_ostream` to ignore indent for block 906 // label instead of using `getOStream`. 907 os.getOStream() << getOrCreateName(block) << ":\n"; 908 return success(); 909 } 910 911 LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { 912 LogicalResult status = 913 llvm::TypeSwitch<Operation *, LogicalResult>(&op) 914 // Builtin ops. 915 .Case<ModuleOp>([&](auto op) { return printOperation(*this, op); }) 916 // CF ops. 917 .Case<cf::BranchOp, cf::CondBranchOp>( 918 [&](auto op) { return printOperation(*this, op); }) 919 // EmitC ops. 920 .Case<emitc::ApplyOp, emitc::CallOp, emitc::ConstantOp, 921 emitc::IncludeOp, emitc::VariableOp>( 922 [&](auto op) { return printOperation(*this, op); }) 923 // Func ops. 924 .Case<func::CallOp, func::ConstantOp, func::FuncOp, func::ReturnOp>( 925 [&](auto op) { return printOperation(*this, op); }) 926 // SCF ops. 927 .Case<scf::ForOp, scf::IfOp, scf::YieldOp>( 928 [&](auto op) { return printOperation(*this, op); }) 929 // Arithmetic ops. 930 .Case<arith::ConstantOp>( 931 [&](auto op) { return printOperation(*this, op); }) 932 .Default([&](Operation *) { 933 return op.emitOpError("unable to find printer for op"); 934 }); 935 936 if (failed(status)) 937 return failure(); 938 os << (trailingSemicolon ? 
";\n" : "\n"); 939 return success(); 940 } 941 942 LogicalResult CppEmitter::emitType(Location loc, Type type) { 943 if (auto iType = type.dyn_cast<IntegerType>()) { 944 switch (iType.getWidth()) { 945 case 1: 946 return (os << "bool"), success(); 947 case 8: 948 case 16: 949 case 32: 950 case 64: 951 if (shouldMapToUnsigned(iType.getSignedness())) 952 return (os << "uint" << iType.getWidth() << "_t"), success(); 953 else 954 return (os << "int" << iType.getWidth() << "_t"), success(); 955 default: 956 return emitError(loc, "cannot emit integer type ") << type; 957 } 958 } 959 if (auto fType = type.dyn_cast<FloatType>()) { 960 switch (fType.getWidth()) { 961 case 32: 962 return (os << "float"), success(); 963 case 64: 964 return (os << "double"), success(); 965 default: 966 return emitError(loc, "cannot emit float type ") << type; 967 } 968 } 969 if (auto iType = type.dyn_cast<IndexType>()) 970 return (os << "size_t"), success(); 971 if (auto tType = type.dyn_cast<TensorType>()) { 972 if (!tType.hasRank()) 973 return emitError(loc, "cannot emit unranked tensor type"); 974 if (!tType.hasStaticShape()) 975 return emitError(loc, "cannot emit tensor type with non static shape"); 976 os << "Tensor<"; 977 if (failed(emitType(loc, tType.getElementType()))) 978 return failure(); 979 auto shape = tType.getShape(); 980 for (auto dimSize : shape) { 981 os << ", "; 982 os << dimSize; 983 } 984 os << ">"; 985 return success(); 986 } 987 if (auto tType = type.dyn_cast<TupleType>()) 988 return emitTupleType(loc, tType.getTypes()); 989 if (auto oType = type.dyn_cast<emitc::OpaqueType>()) { 990 os << oType.getValue(); 991 return success(); 992 } 993 if (auto pType = type.dyn_cast<emitc::PointerType>()) { 994 if (failed(emitType(loc, pType.getPointee()))) 995 return failure(); 996 os << "*"; 997 return success(); 998 } 999 return emitError(loc, "cannot emit type ") << type; 1000 } 1001 1002 LogicalResult CppEmitter::emitTypes(Location loc, ArrayRef<Type> types) { 1003 switch 
(types.size()) { 1004 case 0: 1005 os << "void"; 1006 return success(); 1007 case 1: 1008 return emitType(loc, types.front()); 1009 default: 1010 return emitTupleType(loc, types); 1011 } 1012 } 1013 1014 LogicalResult CppEmitter::emitTupleType(Location loc, ArrayRef<Type> types) { 1015 os << "std::tuple<"; 1016 if (failed(interleaveCommaWithError( 1017 types, os, [&](Type type) { return emitType(loc, type); }))) 1018 return failure(); 1019 os << ">"; 1020 return success(); 1021 } 1022 1023 LogicalResult emitc::translateToCpp(Operation *op, raw_ostream &os, 1024 bool declareVariablesAtTop) { 1025 CppEmitter emitter(os, declareVariablesAtTop); 1026 return emitter.emitOperation(*op, /*trailingSemicolon=*/false); 1027 } 1028