1 //===- TranslateToCpp.cpp - Translating to C++ calls ----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 10 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" 11 #include "mlir/Dialect/EmitC/IR/EmitC.h" 12 #include "mlir/Dialect/Func/IR/FuncOps.h" 13 #include "mlir/Dialect/SCF/IR/SCF.h" 14 #include "mlir/IR/BuiltinOps.h" 15 #include "mlir/IR/BuiltinTypes.h" 16 #include "mlir/IR/Dialect.h" 17 #include "mlir/IR/Operation.h" 18 #include "mlir/Support/IndentedOstream.h" 19 #include "mlir/Target/Cpp/CppEmitter.h" 20 #include "llvm/ADT/DenseMap.h" 21 #include "llvm/ADT/StringExtras.h" 22 #include "llvm/ADT/StringMap.h" 23 #include "llvm/ADT/TypeSwitch.h" 24 #include "llvm/Support/Debug.h" 25 #include "llvm/Support/FormatVariadic.h" 26 #include <utility> 27 28 #define DEBUG_TYPE "translate-to-cpp" 29 30 using namespace mlir; 31 using namespace mlir::emitc; 32 using llvm::formatv; 33 34 /// Convenience functions to produce interleaved output with functions returning 35 /// a LogicalResult. This is different than those in STLExtras as functions used 36 /// on each element doesn't return a string. 37 template <typename ForwardIterator, typename UnaryFunctor, 38 typename NullaryFunctor> 39 inline LogicalResult 40 interleaveWithError(ForwardIterator begin, ForwardIterator end, 41 UnaryFunctor eachFn, NullaryFunctor betweenFn) { 42 if (begin == end) 43 return success(); 44 if (failed(eachFn(*begin))) 45 return failure(); 46 ++begin; 47 for (; begin != end; ++begin) { 48 betweenFn(); 49 if (failed(eachFn(*begin))) 50 return failure(); 51 } 52 return success(); 53 } 54 55 template <typename Container, typename UnaryFunctor, typename NullaryFunctor> 56 inline LogicalResult interleaveWithError(const Container &c, 57 UnaryFunctor eachFn, 58 NullaryFunctor betweenFn) { 59 return interleaveWithError(c.begin(), c.end(), eachFn, betweenFn); 60 } 61 62 template <typename Container, typename UnaryFunctor> 63 inline LogicalResult interleaveCommaWithError(const Container &c, 64 raw_ostream &os, 65 UnaryFunctor eachFn) { 66 return interleaveWithError(c.begin(), c.end(), eachFn, [&]() { os << ", "; }); 67 } 68 69 namespace { 70 /// Emitter that uses dialect specific emitters to emit C++ code. 71 struct CppEmitter { 72 explicit CppEmitter(raw_ostream &os, bool declareVariablesAtTop); 73 74 /// Emits attribute or returns failure. 75 LogicalResult emitAttribute(Location loc, Attribute attr); 76 77 /// Emits operation 'op' with/without training semicolon or returns failure. 78 LogicalResult emitOperation(Operation &op, bool trailingSemicolon); 79 80 /// Emits type 'type' or returns failure. 81 LogicalResult emitType(Location loc, Type type); 82 83 /// Emits array of types as a std::tuple of the emitted types. 84 /// - emits void for an empty array; 85 /// - emits the type of the only element for arrays of size one; 86 /// - emits a std::tuple otherwise; 87 LogicalResult emitTypes(Location loc, ArrayRef<Type> types); 88 89 /// Emits array of types as a std::tuple of the emitted types independently of 90 /// the array size. 91 LogicalResult emitTupleType(Location loc, ArrayRef<Type> types); 92 93 /// Emits an assignment for a variable which has been declared previously. 94 LogicalResult emitVariableAssignment(OpResult result); 95 96 /// Emits a variable declaration for a result of an operation. 97 LogicalResult emitVariableDeclaration(OpResult result, 98 bool trailingSemicolon); 99 100 /// Emits the variable declaration and assignment prefix for 'op'. 101 /// - emits separate variable followed by std::tie for multi-valued operation; 102 /// - emits single type followed by variable for single result; 103 /// - emits nothing if no value produced by op; 104 /// Emits final '=' operator where a type is produced. Returns failure if 105 /// any result type could not be converted. 106 LogicalResult emitAssignPrefix(Operation &op); 107 108 /// Emits a label for the block. 109 LogicalResult emitLabel(Block &block); 110 111 /// Emits the operands and atttributes of the operation. All operands are 112 /// emitted first and then all attributes in alphabetical order. 113 LogicalResult emitOperandsAndAttributes(Operation &op, 114 ArrayRef<StringRef> exclude = {}); 115 116 /// Emits the operands of the operation. All operands are emitted in order. 117 LogicalResult emitOperands(Operation &op); 118 119 /// Return the existing or a new name for a Value. 120 StringRef getOrCreateName(Value val); 121 122 /// Return the existing or a new label of a Block. 123 StringRef getOrCreateName(Block &block); 124 125 /// Whether to map an mlir integer to a unsigned integer in C++. 126 bool shouldMapToUnsigned(IntegerType::SignednessSemantics val); 127 128 /// RAII helper function to manage entering/exiting C++ scopes. 129 struct Scope { 130 Scope(CppEmitter &emitter) 131 : valueMapperScope(emitter.valueMapper), 132 blockMapperScope(emitter.blockMapper), emitter(emitter) { 133 emitter.valueInScopeCount.push(emitter.valueInScopeCount.top()); 134 emitter.labelInScopeCount.push(emitter.labelInScopeCount.top()); 135 } 136 ~Scope() { 137 emitter.valueInScopeCount.pop(); 138 emitter.labelInScopeCount.pop(); 139 } 140 141 private: 142 llvm::ScopedHashTableScope<Value, std::string> valueMapperScope; 143 llvm::ScopedHashTableScope<Block *, std::string> blockMapperScope; 144 CppEmitter &emitter; 145 }; 146 147 /// Returns wether the Value is assigned to a C++ variable in the scope. 148 bool hasValueInScope(Value val); 149 150 // Returns whether a label is assigned to the block. 151 bool hasBlockLabel(Block &block); 152 153 /// Returns the output stream. 154 raw_indented_ostream &ostream() { return os; }; 155 156 /// Returns if all variables for op results and basic block arguments need to 157 /// be declared at the beginning of a function. 158 bool shouldDeclareVariablesAtTop() { return declareVariablesAtTop; }; 159 160 private: 161 using ValueMapper = llvm::ScopedHashTable<Value, std::string>; 162 using BlockMapper = llvm::ScopedHashTable<Block *, std::string>; 163 164 /// Output stream to emit to. 165 raw_indented_ostream os; 166 167 /// Boolean to enforce that all variables for op results and block 168 /// arguments are declared at the beginning of the function. This also 169 /// includes results from ops located in nested regions. 170 bool declareVariablesAtTop; 171 172 /// Map from value to name of C++ variable that contain the name. 173 ValueMapper valueMapper; 174 175 /// Map from block to name of C++ label. 176 BlockMapper blockMapper; 177 178 /// The number of values in the current scope. This is used to declare the 179 /// names of values in a scope. 180 std::stack<int64_t> valueInScopeCount; 181 std::stack<int64_t> labelInScopeCount; 182 }; 183 } // namespace 184 185 static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation, 186 Attribute value) { 187 OpResult result = operation->getResult(0); 188 189 // Only emit an assignment as the variable was already declared when printing 190 // the FuncOp. 191 if (emitter.shouldDeclareVariablesAtTop()) { 192 // Skip the assignment if the emitc.constant has no value. 193 if (auto oAttr = value.dyn_cast<emitc::OpaqueAttr>()) { 194 if (oAttr.getValue().empty()) 195 return success(); 196 } 197 198 if (failed(emitter.emitVariableAssignment(result))) 199 return failure(); 200 return emitter.emitAttribute(operation->getLoc(), value); 201 } 202 203 // Emit a variable declaration for an emitc.constant op without value. 204 if (auto oAttr = value.dyn_cast<emitc::OpaqueAttr>()) { 205 if (oAttr.getValue().empty()) 206 // The semicolon gets printed by the emitOperation function. 207 return emitter.emitVariableDeclaration(result, 208 /*trailingSemicolon=*/false); 209 } 210 211 // Emit a variable declaration. 212 if (failed(emitter.emitAssignPrefix(*operation))) 213 return failure(); 214 return emitter.emitAttribute(operation->getLoc(), value); 215 } 216 217 static LogicalResult printOperation(CppEmitter &emitter, 218 emitc::ConstantOp constantOp) { 219 Operation *operation = constantOp.getOperation(); 220 Attribute value = constantOp.getValue(); 221 222 return printConstantOp(emitter, operation, value); 223 } 224 225 static LogicalResult printOperation(CppEmitter &emitter, 226 emitc::VariableOp variableOp) { 227 Operation *operation = variableOp.getOperation(); 228 Attribute value = variableOp.getValue(); 229 230 return printConstantOp(emitter, operation, value); 231 } 232 233 static LogicalResult printOperation(CppEmitter &emitter, 234 arith::ConstantOp constantOp) { 235 Operation *operation = constantOp.getOperation(); 236 Attribute value = constantOp.getValue(); 237 238 return printConstantOp(emitter, operation, value); 239 } 240 241 static LogicalResult printOperation(CppEmitter &emitter, 242 func::ConstantOp constantOp) { 243 Operation *operation = constantOp.getOperation(); 244 Attribute value = constantOp.getValueAttr(); 245 246 return printConstantOp(emitter, operation, value); 247 } 248 249 static LogicalResult printOperation(CppEmitter &emitter, 250 cf::BranchOp branchOp) { 251 raw_ostream &os = emitter.ostream(); 252 Block &successor = *branchOp.getSuccessor(); 253 254 for (auto pair : 255 llvm::zip(branchOp.getOperands(), successor.getArguments())) { 256 Value &operand = std::get<0>(pair); 257 BlockArgument &argument = std::get<1>(pair); 258 os << emitter.getOrCreateName(argument) << " = " 259 << emitter.getOrCreateName(operand) << ";\n"; 260 } 261 262 os << "goto "; 263 if (!(emitter.hasBlockLabel(successor))) 264 return branchOp.emitOpError("unable to find label for successor block"); 265 os << emitter.getOrCreateName(successor); 266 return success(); 267 } 268 269 static LogicalResult printOperation(CppEmitter &emitter, 270 cf::CondBranchOp condBranchOp) { 271 raw_indented_ostream &os = emitter.ostream(); 272 Block &trueSuccessor = *condBranchOp.getTrueDest(); 273 Block &falseSuccessor = *condBranchOp.getFalseDest(); 274 275 os << "if (" << emitter.getOrCreateName(condBranchOp.getCondition()) 276 << ") {\n"; 277 278 os.indent(); 279 280 // If condition is true. 281 for (auto pair : llvm::zip(condBranchOp.getTrueOperands(), 282 trueSuccessor.getArguments())) { 283 Value &operand = std::get<0>(pair); 284 BlockArgument &argument = std::get<1>(pair); 285 os << emitter.getOrCreateName(argument) << " = " 286 << emitter.getOrCreateName(operand) << ";\n"; 287 } 288 289 os << "goto "; 290 if (!(emitter.hasBlockLabel(trueSuccessor))) { 291 return condBranchOp.emitOpError("unable to find label for successor block"); 292 } 293 os << emitter.getOrCreateName(trueSuccessor) << ";\n"; 294 os.unindent() << "} else {\n"; 295 os.indent(); 296 // If condition is false. 297 for (auto pair : llvm::zip(condBranchOp.getFalseOperands(), 298 falseSuccessor.getArguments())) { 299 Value &operand = std::get<0>(pair); 300 BlockArgument &argument = std::get<1>(pair); 301 os << emitter.getOrCreateName(argument) << " = " 302 << emitter.getOrCreateName(operand) << ";\n"; 303 } 304 305 os << "goto "; 306 if (!(emitter.hasBlockLabel(falseSuccessor))) { 307 return condBranchOp.emitOpError() 308 << "unable to find label for successor block"; 309 } 310 os << emitter.getOrCreateName(falseSuccessor) << ";\n"; 311 os.unindent() << "}"; 312 return success(); 313 } 314 315 static LogicalResult printOperation(CppEmitter &emitter, func::CallOp callOp) { 316 if (failed(emitter.emitAssignPrefix(*callOp.getOperation()))) 317 return failure(); 318 319 raw_ostream &os = emitter.ostream(); 320 os << callOp.getCallee() << "("; 321 if (failed(emitter.emitOperands(*callOp.getOperation()))) 322 return failure(); 323 os << ")"; 324 return success(); 325 } 326 327 static LogicalResult printOperation(CppEmitter &emitter, emitc::CallOp callOp) { 328 raw_ostream &os = emitter.ostream(); 329 Operation &op = *callOp.getOperation(); 330 331 if (failed(emitter.emitAssignPrefix(op))) 332 return failure(); 333 os << callOp.getCallee(); 334 335 auto emitArgs = [&](Attribute attr) -> LogicalResult { 336 if (auto t = attr.dyn_cast<IntegerAttr>()) { 337 // Index attributes are treated specially as operand index. 338 if (t.getType().isIndex()) { 339 int64_t idx = t.getInt(); 340 if ((idx < 0) || (idx >= op.getNumOperands())) 341 return op.emitOpError("invalid operand index"); 342 if (!emitter.hasValueInScope(op.getOperand(idx))) 343 return op.emitOpError("operand ") 344 << idx << "'s value not defined in scope"; 345 os << emitter.getOrCreateName(op.getOperand(idx)); 346 return success(); 347 } 348 } 349 if (failed(emitter.emitAttribute(op.getLoc(), attr))) 350 return failure(); 351 352 return success(); 353 }; 354 355 if (callOp.getTemplateArgs()) { 356 os << "<"; 357 if (failed( 358 interleaveCommaWithError(*callOp.getTemplateArgs(), os, emitArgs))) 359 return failure(); 360 os << ">"; 361 } 362 363 os << "("; 364 365 LogicalResult emittedArgs = 366 callOp.getArgs() 367 ? interleaveCommaWithError(*callOp.getArgs(), os, emitArgs) 368 : emitter.emitOperands(op); 369 if (failed(emittedArgs)) 370 return failure(); 371 os << ")"; 372 return success(); 373 } 374 375 static LogicalResult printOperation(CppEmitter &emitter, 376 emitc::ApplyOp applyOp) { 377 raw_ostream &os = emitter.ostream(); 378 Operation &op = *applyOp.getOperation(); 379 380 if (failed(emitter.emitAssignPrefix(op))) 381 return failure(); 382 os << applyOp.getApplicableOperator(); 383 os << emitter.getOrCreateName(applyOp.getOperand()); 384 385 return success(); 386 } 387 388 static LogicalResult printOperation(CppEmitter &emitter, emitc::CastOp castOp) { 389 raw_ostream &os = emitter.ostream(); 390 Operation &op = *castOp.getOperation(); 391 392 if (failed(emitter.emitAssignPrefix(op))) 393 return failure(); 394 os << "("; 395 if (failed(emitter.emitType(op.getLoc(), op.getResult(0).getType()))) 396 return failure(); 397 os << ") "; 398 os << emitter.getOrCreateName(castOp.getOperand()); 399 400 return success(); 401 } 402 403 static LogicalResult printOperation(CppEmitter &emitter, 404 emitc::IncludeOp includeOp) { 405 raw_ostream &os = emitter.ostream(); 406 407 os << "#include "; 408 if (includeOp.getIsStandardInclude()) 409 os << "<" << includeOp.getInclude() << ">"; 410 else 411 os << "\"" << includeOp.getInclude() << "\""; 412 413 return success(); 414 } 415 416 static LogicalResult printOperation(CppEmitter &emitter, scf::ForOp forOp) { 417 418 raw_indented_ostream &os = emitter.ostream(); 419 420 OperandRange operands = forOp.getIterOperands(); 421 Block::BlockArgListType iterArgs = forOp.getRegionIterArgs(); 422 Operation::result_range results = forOp.getResults(); 423 424 if (!emitter.shouldDeclareVariablesAtTop()) { 425 for (OpResult result : results) { 426 if (failed(emitter.emitVariableDeclaration(result, 427 /*trailingSemicolon=*/true))) 428 return failure(); 429 } 430 } 431 432 for (auto pair : llvm::zip(iterArgs, operands)) { 433 if (failed(emitter.emitType(forOp.getLoc(), std::get<0>(pair).getType()))) 434 return failure(); 435 os << " " << emitter.getOrCreateName(std::get<0>(pair)) << " = "; 436 os << emitter.getOrCreateName(std::get<1>(pair)) << ";"; 437 os << "\n"; 438 } 439 440 os << "for ("; 441 if (failed( 442 emitter.emitType(forOp.getLoc(), forOp.getInductionVar().getType()))) 443 return failure(); 444 os << " "; 445 os << emitter.getOrCreateName(forOp.getInductionVar()); 446 os << " = "; 447 os << emitter.getOrCreateName(forOp.getLowerBound()); 448 os << "; "; 449 os << emitter.getOrCreateName(forOp.getInductionVar()); 450 os << " < "; 451 os << emitter.getOrCreateName(forOp.getUpperBound()); 452 os << "; "; 453 os << emitter.getOrCreateName(forOp.getInductionVar()); 454 os << " += "; 455 os << emitter.getOrCreateName(forOp.getStep()); 456 os << ") {\n"; 457 os.indent(); 458 459 Region &forRegion = forOp.getRegion(); 460 auto regionOps = forRegion.getOps(); 461 462 // We skip the trailing yield op because this updates the result variables 463 // of the for op in the generated code. Instead we update the iterArgs at 464 // the end of a loop iteration and set the result variables after the for 465 // loop. 466 for (auto it = regionOps.begin(); std::next(it) != regionOps.end(); ++it) { 467 if (failed(emitter.emitOperation(*it, /*trailingSemicolon=*/true))) 468 return failure(); 469 } 470 471 Operation *yieldOp = forRegion.getBlocks().front().getTerminator(); 472 // Copy yield operands into iterArgs at the end of a loop iteration. 473 for (auto pair : llvm::zip(iterArgs, yieldOp->getOperands())) { 474 BlockArgument iterArg = std::get<0>(pair); 475 Value operand = std::get<1>(pair); 476 os << emitter.getOrCreateName(iterArg) << " = " 477 << emitter.getOrCreateName(operand) << ";\n"; 478 } 479 480 os.unindent() << "}"; 481 482 // Copy iterArgs into results after the for loop. 483 for (auto pair : llvm::zip(results, iterArgs)) { 484 OpResult result = std::get<0>(pair); 485 BlockArgument iterArg = std::get<1>(pair); 486 os << "\n" 487 << emitter.getOrCreateName(result) << " = " 488 << emitter.getOrCreateName(iterArg) << ";"; 489 } 490 491 return success(); 492 } 493 494 static LogicalResult printOperation(CppEmitter &emitter, scf::IfOp ifOp) { 495 raw_indented_ostream &os = emitter.ostream(); 496 497 if (!emitter.shouldDeclareVariablesAtTop()) { 498 for (OpResult result : ifOp.getResults()) { 499 if (failed(emitter.emitVariableDeclaration(result, 500 /*trailingSemicolon=*/true))) 501 return failure(); 502 } 503 } 504 505 os << "if ("; 506 if (failed(emitter.emitOperands(*ifOp.getOperation()))) 507 return failure(); 508 os << ") {\n"; 509 os.indent(); 510 511 Region &thenRegion = ifOp.getThenRegion(); 512 for (Operation &op : thenRegion.getOps()) { 513 // Note: This prints a superfluous semicolon if the terminating yield op has 514 // zero results. 515 if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/true))) 516 return failure(); 517 } 518 519 os.unindent() << "}"; 520 521 Region &elseRegion = ifOp.getElseRegion(); 522 if (!elseRegion.empty()) { 523 os << " else {\n"; 524 os.indent(); 525 526 for (Operation &op : elseRegion.getOps()) { 527 // Note: This prints a superfluous semicolon if the terminating yield op 528 // has zero results. 529 if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/true))) 530 return failure(); 531 } 532 533 os.unindent() << "}"; 534 } 535 536 return success(); 537 } 538 539 static LogicalResult printOperation(CppEmitter &emitter, scf::YieldOp yieldOp) { 540 raw_ostream &os = emitter.ostream(); 541 Operation &parentOp = *yieldOp.getOperation()->getParentOp(); 542 543 if (yieldOp.getNumOperands() != parentOp.getNumResults()) { 544 return yieldOp.emitError("number of operands does not to match the number " 545 "of the parent op's results"); 546 } 547 548 if (failed(interleaveWithError( 549 llvm::zip(parentOp.getResults(), yieldOp.getOperands()), 550 [&](auto pair) -> LogicalResult { 551 auto result = std::get<0>(pair); 552 auto operand = std::get<1>(pair); 553 os << emitter.getOrCreateName(result) << " = "; 554 555 if (!emitter.hasValueInScope(operand)) 556 return yieldOp.emitError("operand value not in scope"); 557 os << emitter.getOrCreateName(operand); 558 return success(); 559 }, 560 [&]() { os << ";\n"; }))) 561 return failure(); 562 563 return success(); 564 } 565 566 static LogicalResult printOperation(CppEmitter &emitter, 567 func::ReturnOp returnOp) { 568 raw_ostream &os = emitter.ostream(); 569 os << "return"; 570 switch (returnOp.getNumOperands()) { 571 case 0: 572 return success(); 573 case 1: 574 os << " " << emitter.getOrCreateName(returnOp.getOperand(0)); 575 return success(emitter.hasValueInScope(returnOp.getOperand(0))); 576 default: 577 os << " std::make_tuple("; 578 if (failed(emitter.emitOperandsAndAttributes(*returnOp.getOperation()))) 579 return failure(); 580 os << ")"; 581 return success(); 582 } 583 } 584 585 static LogicalResult printOperation(CppEmitter &emitter, ModuleOp moduleOp) { 586 CppEmitter::Scope scope(emitter); 587 588 for (Operation &op : moduleOp) { 589 if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/false))) 590 return failure(); 591 } 592 return success(); 593 } 594 595 static LogicalResult printOperation(CppEmitter &emitter, 596 func::FuncOp functionOp) { 597 // We need to declare variables at top if the function has multiple blocks. 598 if (!emitter.shouldDeclareVariablesAtTop() && 599 functionOp.getBlocks().size() > 1) { 600 return functionOp.emitOpError( 601 "with multiple blocks needs variables declared at top"); 602 } 603 604 CppEmitter::Scope scope(emitter); 605 raw_indented_ostream &os = emitter.ostream(); 606 if (failed(emitter.emitTypes(functionOp.getLoc(), 607 functionOp.getFunctionType().getResults()))) 608 return failure(); 609 os << " " << functionOp.getName(); 610 611 os << "("; 612 if (failed(interleaveCommaWithError( 613 functionOp.getArguments(), os, 614 [&](BlockArgument arg) -> LogicalResult { 615 if (failed(emitter.emitType(functionOp.getLoc(), arg.getType()))) 616 return failure(); 617 os << " " << emitter.getOrCreateName(arg); 618 return success(); 619 }))) 620 return failure(); 621 os << ") {\n"; 622 os.indent(); 623 if (emitter.shouldDeclareVariablesAtTop()) { 624 // Declare all variables that hold op results including those from nested 625 // regions. 626 WalkResult result = 627 functionOp.walk<WalkOrder::PreOrder>([&](Operation *op) -> WalkResult { 628 for (OpResult result : op->getResults()) { 629 if (failed(emitter.emitVariableDeclaration( 630 result, /*trailingSemicolon=*/true))) { 631 return WalkResult( 632 op->emitError("unable to declare result variable for op")); 633 } 634 } 635 return WalkResult::advance(); 636 }); 637 if (result.wasInterrupted()) 638 return failure(); 639 } 640 641 Region::BlockListType &blocks = functionOp.getBlocks(); 642 // Create label names for basic blocks. 643 for (Block &block : blocks) { 644 emitter.getOrCreateName(block); 645 } 646 647 // Declare variables for basic block arguments. 648 for (auto it = std::next(blocks.begin()); it != blocks.end(); ++it) { 649 Block &block = *it; 650 for (BlockArgument &arg : block.getArguments()) { 651 if (emitter.hasValueInScope(arg)) 652 return functionOp.emitOpError(" block argument #") 653 << arg.getArgNumber() << " is out of scope"; 654 if (failed( 655 emitter.emitType(block.getParentOp()->getLoc(), arg.getType()))) { 656 return failure(); 657 } 658 os << " " << emitter.getOrCreateName(arg) << ";\n"; 659 } 660 } 661 662 for (Block &block : blocks) { 663 // Only print a label if the block has predecessors. 664 if (!block.hasNoPredecessors()) { 665 if (failed(emitter.emitLabel(block))) 666 return failure(); 667 } 668 for (Operation &op : block.getOperations()) { 669 // When generating code for an scf.if or cf.cond_br op no semicolon needs 670 // to be printed after the closing brace. 671 // When generating code for an scf.for op, printing a trailing semicolon 672 // is handled within the printOperation function. 673 bool trailingSemicolon = 674 !isa<scf::IfOp, scf::ForOp, cf::CondBranchOp>(op); 675 676 if (failed(emitter.emitOperation( 677 op, /*trailingSemicolon=*/trailingSemicolon))) 678 return failure(); 679 } 680 } 681 os.unindent() << "}\n"; 682 return success(); 683 } 684 685 CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop) 686 : os(os), declareVariablesAtTop(declareVariablesAtTop) { 687 valueInScopeCount.push(0); 688 labelInScopeCount.push(0); 689 } 690 691 /// Return the existing or a new name for a Value. 692 StringRef CppEmitter::getOrCreateName(Value val) { 693 if (!valueMapper.count(val)) 694 valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top())); 695 return *valueMapper.begin(val); 696 } 697 698 /// Return the existing or a new label for a Block. 699 StringRef CppEmitter::getOrCreateName(Block &block) { 700 if (!blockMapper.count(&block)) 701 blockMapper.insert(&block, formatv("label{0}", ++labelInScopeCount.top())); 702 return *blockMapper.begin(&block); 703 } 704 705 bool CppEmitter::shouldMapToUnsigned(IntegerType::SignednessSemantics val) { 706 switch (val) { 707 case IntegerType::Signless: 708 return false; 709 case IntegerType::Signed: 710 return false; 711 case IntegerType::Unsigned: 712 return true; 713 } 714 llvm_unreachable("Unexpected IntegerType::SignednessSemantics"); 715 } 716 717 bool CppEmitter::hasValueInScope(Value val) { return valueMapper.count(val); } 718 719 bool CppEmitter::hasBlockLabel(Block &block) { 720 return blockMapper.count(&block); 721 } 722 723 LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) { 724 auto printInt = [&](const APInt &val, bool isUnsigned) { 725 if (val.getBitWidth() == 1) { 726 if (val.getBoolValue()) 727 os << "true"; 728 else 729 os << "false"; 730 } else { 731 SmallString<128> strValue; 732 val.toString(strValue, 10, !isUnsigned, false); 733 os << strValue; 734 } 735 }; 736 737 auto printFloat = [&](const APFloat &val) { 738 if (val.isFinite()) { 739 SmallString<128> strValue; 740 // Use default values of toString except don't truncate zeros. 741 val.toString(strValue, 0, 0, false); 742 switch (llvm::APFloatBase::SemanticsToEnum(val.getSemantics())) { 743 case llvm::APFloatBase::S_IEEEsingle: 744 os << "(float)"; 745 break; 746 case llvm::APFloatBase::S_IEEEdouble: 747 os << "(double)"; 748 break; 749 default: 750 break; 751 }; 752 os << strValue; 753 } else if (val.isNaN()) { 754 os << "NAN"; 755 } else if (val.isInfinity()) { 756 if (val.isNegative()) 757 os << "-"; 758 os << "INFINITY"; 759 } 760 }; 761 762 // Print floating point attributes. 763 if (auto fAttr = attr.dyn_cast<FloatAttr>()) { 764 printFloat(fAttr.getValue()); 765 return success(); 766 } 767 if (auto dense = attr.dyn_cast<DenseFPElementsAttr>()) { 768 os << '{'; 769 interleaveComma(dense, os, [&](const APFloat &val) { printFloat(val); }); 770 os << '}'; 771 return success(); 772 } 773 774 // Print integer attributes. 775 if (auto iAttr = attr.dyn_cast<IntegerAttr>()) { 776 if (auto iType = iAttr.getType().dyn_cast<IntegerType>()) { 777 printInt(iAttr.getValue(), shouldMapToUnsigned(iType.getSignedness())); 778 return success(); 779 } 780 if (auto iType = iAttr.getType().dyn_cast<IndexType>()) { 781 printInt(iAttr.getValue(), false); 782 return success(); 783 } 784 } 785 if (auto dense = attr.dyn_cast<DenseIntElementsAttr>()) { 786 if (auto iType = dense.getType() 787 .cast<TensorType>() 788 .getElementType() 789 .dyn_cast<IntegerType>()) { 790 os << '{'; 791 interleaveComma(dense, os, [&](const APInt &val) { 792 printInt(val, shouldMapToUnsigned(iType.getSignedness())); 793 }); 794 os << '}'; 795 return success(); 796 } 797 if (auto iType = dense.getType() 798 .cast<TensorType>() 799 .getElementType() 800 .dyn_cast<IndexType>()) { 801 os << '{'; 802 interleaveComma(dense, os, 803 [&](const APInt &val) { printInt(val, false); }); 804 os << '}'; 805 return success(); 806 } 807 } 808 809 // Print opaque attributes. 810 if (auto oAttr = attr.dyn_cast<emitc::OpaqueAttr>()) { 811 os << oAttr.getValue(); 812 return success(); 813 } 814 815 // Print symbolic reference attributes. 816 if (auto sAttr = attr.dyn_cast<SymbolRefAttr>()) { 817 if (sAttr.getNestedReferences().size() > 1) 818 return emitError(loc, "attribute has more than 1 nested reference"); 819 os << sAttr.getRootReference().getValue(); 820 return success(); 821 } 822 823 // Print type attributes. 824 if (auto type = attr.dyn_cast<TypeAttr>()) 825 return emitType(loc, type.getValue()); 826 827 return emitError(loc, "cannot emit attribute of type ") << attr.getType(); 828 } 829 830 LogicalResult CppEmitter::emitOperands(Operation &op) { 831 auto emitOperandName = [&](Value result) -> LogicalResult { 832 if (!hasValueInScope(result)) 833 return op.emitOpError() << "operand value not in scope"; 834 os << getOrCreateName(result); 835 return success(); 836 }; 837 return interleaveCommaWithError(op.getOperands(), os, emitOperandName); 838 } 839 840 LogicalResult 841 CppEmitter::emitOperandsAndAttributes(Operation &op, 842 ArrayRef<StringRef> exclude) { 843 if (failed(emitOperands(op))) 844 return failure(); 845 // Insert comma in between operands and non-filtered attributes if needed. 846 if (op.getNumOperands() > 0) { 847 for (NamedAttribute attr : op.getAttrs()) { 848 if (!llvm::is_contained(exclude, attr.getName().strref())) { 849 os << ", "; 850 break; 851 } 852 } 853 } 854 // Emit attributes. 855 auto emitNamedAttribute = [&](NamedAttribute attr) -> LogicalResult { 856 if (llvm::is_contained(exclude, attr.getName().strref())) 857 return success(); 858 os << "/* " << attr.getName().getValue() << " */"; 859 if (failed(emitAttribute(op.getLoc(), attr.getValue()))) 860 return failure(); 861 return success(); 862 }; 863 return interleaveCommaWithError(op.getAttrs(), os, emitNamedAttribute); 864 } 865 866 LogicalResult CppEmitter::emitVariableAssignment(OpResult result) { 867 if (!hasValueInScope(result)) { 868 return result.getDefiningOp()->emitOpError( 869 "result variable for the operation has not been declared"); 870 } 871 os << getOrCreateName(result) << " = "; 872 return success(); 873 } 874 875 LogicalResult CppEmitter::emitVariableDeclaration(OpResult result, 876 bool trailingSemicolon) { 877 if (hasValueInScope(result)) { 878 return result.getDefiningOp()->emitError( 879 "result variable for the operation already declared"); 880 } 881 if (failed(emitType(result.getOwner()->getLoc(), result.getType()))) 882 return failure(); 883 os << " " << getOrCreateName(result); 884 if (trailingSemicolon) 885 os << ";\n"; 886 return success(); 887 } 888 889 LogicalResult CppEmitter::emitAssignPrefix(Operation &op) { 890 switch (op.getNumResults()) { 891 case 0: 892 break; 893 case 1: { 894 OpResult result = op.getResult(0); 895 if (shouldDeclareVariablesAtTop()) { 896 if (failed(emitVariableAssignment(result))) 897 return failure(); 898 } else { 899 if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/false))) 900 return failure(); 901 os << " = "; 902 } 903 break; 904 } 905 default: 906 if (!shouldDeclareVariablesAtTop()) { 907 for (OpResult result : op.getResults()) { 908 if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/true))) 909 return failure(); 910 } 911 } 912 os << "std::tie("; 913 interleaveComma(op.getResults(), os, 914 [&](Value result) { os << getOrCreateName(result); }); 915 os << ") = "; 916 } 917 return success(); 918 } 919 920 LogicalResult CppEmitter::emitLabel(Block &block) { 921 if (!hasBlockLabel(block)) 922 return block.getParentOp()->emitError("label for block not found"); 923 // FIXME: Add feature in `raw_indented_ostream` to ignore indent for block 924 // label instead of using `getOStream`. 925 os.getOStream() << getOrCreateName(block) << ":\n"; 926 return success(); 927 } 928 929 LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { 930 LogicalResult status = 931 llvm::TypeSwitch<Operation *, LogicalResult>(&op) 932 // Builtin ops. 933 .Case<ModuleOp>([&](auto op) { return printOperation(*this, op); }) 934 // CF ops. 935 .Case<cf::BranchOp, cf::CondBranchOp>( 936 [&](auto op) { return printOperation(*this, op); }) 937 // EmitC ops. 938 .Case<emitc::ApplyOp, emitc::CallOp, emitc::CastOp, emitc::ConstantOp, 939 emitc::IncludeOp, emitc::VariableOp>( 940 [&](auto op) { return printOperation(*this, op); }) 941 // Func ops. 942 .Case<func::CallOp, func::ConstantOp, func::FuncOp, func::ReturnOp>( 943 [&](auto op) { return printOperation(*this, op); }) 944 // SCF ops. 945 .Case<scf::ForOp, scf::IfOp, scf::YieldOp>( 946 [&](auto op) { return printOperation(*this, op); }) 947 // Arithmetic ops. 948 .Case<arith::ConstantOp>( 949 [&](auto op) { return printOperation(*this, op); }) 950 .Default([&](Operation *) { 951 return op.emitOpError("unable to find printer for op"); 952 }); 953 954 if (failed(status)) 955 return failure(); 956 os << (trailingSemicolon ? ";\n" : "\n"); 957 return success(); 958 } 959 960 LogicalResult CppEmitter::emitType(Location loc, Type type) { 961 if (auto iType = type.dyn_cast<IntegerType>()) { 962 switch (iType.getWidth()) { 963 case 1: 964 return (os << "bool"), success(); 965 case 8: 966 case 16: 967 case 32: 968 case 64: 969 if (shouldMapToUnsigned(iType.getSignedness())) 970 return (os << "uint" << iType.getWidth() << "_t"), success(); 971 else 972 return (os << "int" << iType.getWidth() << "_t"), success(); 973 default: 974 return emitError(loc, "cannot emit integer type ") << type; 975 } 976 } 977 if (auto fType = type.dyn_cast<FloatType>()) { 978 switch (fType.getWidth()) { 979 case 32: 980 return (os << "float"), success(); 981 case 64: 982 return (os << "double"), success(); 983 default: 984 return emitError(loc, "cannot emit float type ") << type; 985 } 986 } 987 if (auto iType = type.dyn_cast<IndexType>()) 988 return (os << "size_t"), success(); 989 if (auto tType = type.dyn_cast<TensorType>()) { 990 if (!tType.hasRank()) 991 return emitError(loc, "cannot emit unranked tensor type"); 992 if (!tType.hasStaticShape()) 993 return emitError(loc, "cannot emit tensor type with non static shape"); 994 os << "Tensor<"; 995 if (failed(emitType(loc, tType.getElementType()))) 996 return failure(); 997 auto shape = tType.getShape(); 998 for (auto dimSize : shape) { 999 os << ", "; 1000 os << dimSize; 1001 } 1002 os << ">"; 1003 return success(); 1004 } 1005 if (auto tType = type.dyn_cast<TupleType>()) 1006 return emitTupleType(loc, tType.getTypes()); 1007 if (auto oType = type.dyn_cast<emitc::OpaqueType>()) { 1008 os << oType.getValue(); 1009 return success(); 1010 } 1011 if (auto pType = type.dyn_cast<emitc::PointerType>()) { 1012 if (failed(emitType(loc, pType.getPointee()))) 1013 return failure(); 1014 os << "*"; 1015 return success(); 1016 } 1017 return emitError(loc, "cannot emit type ") << type; 1018 } 1019 1020 LogicalResult CppEmitter::emitTypes(Location loc, ArrayRef<Type> types) { 1021 switch (types.size()) { 1022 case 0: 1023 os << "void"; 1024 return success(); 1025 case 1: 1026 return emitType(loc, types.front()); 1027 default: 1028 return emitTupleType(loc, types); 1029 } 1030 } 1031 1032 LogicalResult CppEmitter::emitTupleType(Location loc, ArrayRef<Type> types) { 1033 os << "std::tuple<"; 1034 if (failed(interleaveCommaWithError( 1035 types, os, [&](Type type) { return emitType(loc, type); }))) 1036 return failure(); 1037 os << ">"; 1038 return success(); 1039 } 1040 1041 LogicalResult emitc::translateToCpp(Operation *op, raw_ostream &os, 1042 bool declareVariablesAtTop) { 1043 CppEmitter emitter(os, declareVariablesAtTop); 1044 return emitter.emitOperation(*op, /*trailingSemicolon=*/false); 1045 } 1046