//===- TranslateToCpp.cpp - Translating to C++ calls ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/IndentedOstream.h"
#include "mlir/Target/Cpp/CppEmitter.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/ScopedHashTable.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#include <stack>
#include <utility>

#define DEBUG_TYPE "translate-to-cpp"

using namespace mlir;
using namespace mlir::emitc;
using llvm::formatv;

/// Convenience functions to produce interleaved output with functions returning
/// a LogicalResult. This is different than those in STLExtras as functions used
/// on each element doesn't return a string.
36 template <typename ForwardIterator, typename UnaryFunctor, 37 typename NullaryFunctor> 38 inline LogicalResult 39 interleaveWithError(ForwardIterator begin, ForwardIterator end, 40 UnaryFunctor eachFn, NullaryFunctor betweenFn) { 41 if (begin == end) 42 return success(); 43 if (failed(eachFn(*begin))) 44 return failure(); 45 ++begin; 46 for (; begin != end; ++begin) { 47 betweenFn(); 48 if (failed(eachFn(*begin))) 49 return failure(); 50 } 51 return success(); 52 } 53 54 template <typename Container, typename UnaryFunctor, typename NullaryFunctor> 55 inline LogicalResult interleaveWithError(const Container &c, 56 UnaryFunctor eachFn, 57 NullaryFunctor betweenFn) { 58 return interleaveWithError(c.begin(), c.end(), eachFn, betweenFn); 59 } 60 61 template <typename Container, typename UnaryFunctor> 62 inline LogicalResult interleaveCommaWithError(const Container &c, 63 raw_ostream &os, 64 UnaryFunctor eachFn) { 65 return interleaveWithError(c.begin(), c.end(), eachFn, [&]() { os << ", "; }); 66 } 67 68 namespace { 69 /// Emitter that uses dialect specific emitters to emit C++ code. 70 struct CppEmitter { 71 explicit CppEmitter(raw_ostream &os, bool declareVariablesAtTop); 72 73 /// Emits attribute or returns failure. 74 LogicalResult emitAttribute(Location loc, Attribute attr); 75 76 /// Emits operation 'op' with/without training semicolon or returns failure. 77 LogicalResult emitOperation(Operation &op, bool trailingSemicolon); 78 79 /// Emits type 'type' or returns failure. 80 LogicalResult emitType(Location loc, Type type); 81 82 /// Emits array of types as a std::tuple of the emitted types. 83 /// - emits void for an empty array; 84 /// - emits the type of the only element for arrays of size one; 85 /// - emits a std::tuple otherwise; 86 LogicalResult emitTypes(Location loc, ArrayRef<Type> types); 87 88 /// Emits array of types as a std::tuple of the emitted types independently of 89 /// the array size. 
90 LogicalResult emitTupleType(Location loc, ArrayRef<Type> types); 91 92 /// Emits an assignment for a variable which has been declared previously. 93 LogicalResult emitVariableAssignment(OpResult result); 94 95 /// Emits a variable declaration for a result of an operation. 96 LogicalResult emitVariableDeclaration(OpResult result, 97 bool trailingSemicolon); 98 99 /// Emits the variable declaration and assignment prefix for 'op'. 100 /// - emits separate variable followed by std::tie for multi-valued operation; 101 /// - emits single type followed by variable for single result; 102 /// - emits nothing if no value produced by op; 103 /// Emits final '=' operator where a type is produced. Returns failure if 104 /// any result type could not be converted. 105 LogicalResult emitAssignPrefix(Operation &op); 106 107 /// Emits a label for the block. 108 LogicalResult emitLabel(Block &block); 109 110 /// Emits the operands and atttributes of the operation. All operands are 111 /// emitted first and then all attributes in alphabetical order. 112 LogicalResult emitOperandsAndAttributes(Operation &op, 113 ArrayRef<StringRef> exclude = {}); 114 115 /// Emits the operands of the operation. All operands are emitted in order. 116 LogicalResult emitOperands(Operation &op); 117 118 /// Return the existing or a new name for a Value. 119 StringRef getOrCreateName(Value val); 120 121 /// Return the existing or a new label of a Block. 122 StringRef getOrCreateName(Block &block); 123 124 /// Whether to map an mlir integer to a unsigned integer in C++. 125 bool shouldMapToUnsigned(IntegerType::SignednessSemantics val); 126 127 /// RAII helper function to manage entering/exiting C++ scopes. 
128 struct Scope { 129 Scope(CppEmitter &emitter) 130 : valueMapperScope(emitter.valueMapper), 131 blockMapperScope(emitter.blockMapper), emitter(emitter) { 132 emitter.valueInScopeCount.push(emitter.valueInScopeCount.top()); 133 emitter.labelInScopeCount.push(emitter.labelInScopeCount.top()); 134 } 135 ~Scope() { 136 emitter.valueInScopeCount.pop(); 137 emitter.labelInScopeCount.pop(); 138 } 139 140 private: 141 llvm::ScopedHashTableScope<Value, std::string> valueMapperScope; 142 llvm::ScopedHashTableScope<Block *, std::string> blockMapperScope; 143 CppEmitter &emitter; 144 }; 145 146 /// Returns wether the Value is assigned to a C++ variable in the scope. 147 bool hasValueInScope(Value val); 148 149 // Returns whether a label is assigned to the block. 150 bool hasBlockLabel(Block &block); 151 152 /// Returns the output stream. 153 raw_indented_ostream &ostream() { return os; }; 154 155 /// Returns if all variables for op results and basic block arguments need to 156 /// be declared at the beginning of a function. 157 bool shouldDeclareVariablesAtTop() { return declareVariablesAtTop; }; 158 159 private: 160 using ValueMapper = llvm::ScopedHashTable<Value, std::string>; 161 using BlockMapper = llvm::ScopedHashTable<Block *, std::string>; 162 163 /// Output stream to emit to. 164 raw_indented_ostream os; 165 166 /// Boolean to enforce that all variables for op results and block 167 /// arguments are declared at the beginning of the function. This also 168 /// includes results from ops located in nested regions. 169 bool declareVariablesAtTop; 170 171 /// Map from value to name of C++ variable that contain the name. 172 ValueMapper valueMapper; 173 174 /// Map from block to name of C++ label. 175 BlockMapper blockMapper; 176 177 /// The number of values in the current scope. This is used to declare the 178 /// names of values in a scope. 
179 std::stack<int64_t> valueInScopeCount; 180 std::stack<int64_t> labelInScopeCount; 181 }; 182 } // namespace 183 184 static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation, 185 Attribute value) { 186 OpResult result = operation->getResult(0); 187 188 // Only emit an assignment as the variable was already declared when printing 189 // the FuncOp. 190 if (emitter.shouldDeclareVariablesAtTop()) { 191 // Skip the assignment if the emitc.constant has no value. 192 if (auto oAttr = value.dyn_cast<emitc::OpaqueAttr>()) { 193 if (oAttr.getValue().empty()) 194 return success(); 195 } 196 197 if (failed(emitter.emitVariableAssignment(result))) 198 return failure(); 199 return emitter.emitAttribute(operation->getLoc(), value); 200 } 201 202 // Emit a variable declaration for an emitc.constant op without value. 203 if (auto oAttr = value.dyn_cast<emitc::OpaqueAttr>()) { 204 if (oAttr.getValue().empty()) 205 // The semicolon gets printed by the emitOperation function. 206 return emitter.emitVariableDeclaration(result, 207 /*trailingSemicolon=*/false); 208 } 209 210 // Emit a variable declaration. 
211 if (failed(emitter.emitAssignPrefix(*operation))) 212 return failure(); 213 return emitter.emitAttribute(operation->getLoc(), value); 214 } 215 216 static LogicalResult printOperation(CppEmitter &emitter, 217 emitc::ConstantOp constantOp) { 218 Operation *operation = constantOp.getOperation(); 219 Attribute value = constantOp.value(); 220 221 return printConstantOp(emitter, operation, value); 222 } 223 224 static LogicalResult printOperation(CppEmitter &emitter, 225 arith::ConstantOp constantOp) { 226 Operation *operation = constantOp.getOperation(); 227 Attribute value = constantOp.getValue(); 228 229 return printConstantOp(emitter, operation, value); 230 } 231 232 static LogicalResult printOperation(CppEmitter &emitter, 233 mlir::ConstantOp constantOp) { 234 Operation *operation = constantOp.getOperation(); 235 Attribute value = constantOp.getValueAttr(); 236 237 return printConstantOp(emitter, operation, value); 238 } 239 240 static LogicalResult printOperation(CppEmitter &emitter, 241 cf::BranchOp branchOp) { 242 raw_ostream &os = emitter.ostream(); 243 Block &successor = *branchOp.getSuccessor(); 244 245 for (auto pair : 246 llvm::zip(branchOp.getOperands(), successor.getArguments())) { 247 Value &operand = std::get<0>(pair); 248 BlockArgument &argument = std::get<1>(pair); 249 os << emitter.getOrCreateName(argument) << " = " 250 << emitter.getOrCreateName(operand) << ";\n"; 251 } 252 253 os << "goto "; 254 if (!(emitter.hasBlockLabel(successor))) 255 return branchOp.emitOpError("unable to find label for successor block"); 256 os << emitter.getOrCreateName(successor); 257 return success(); 258 } 259 260 static LogicalResult printOperation(CppEmitter &emitter, 261 cf::CondBranchOp condBranchOp) { 262 raw_indented_ostream &os = emitter.ostream(); 263 Block &trueSuccessor = *condBranchOp.getTrueDest(); 264 Block &falseSuccessor = *condBranchOp.getFalseDest(); 265 266 os << "if (" << emitter.getOrCreateName(condBranchOp.getCondition()) 267 << ") {\n"; 268 269 
os.indent(); 270 271 // If condition is true. 272 for (auto pair : llvm::zip(condBranchOp.getTrueOperands(), 273 trueSuccessor.getArguments())) { 274 Value &operand = std::get<0>(pair); 275 BlockArgument &argument = std::get<1>(pair); 276 os << emitter.getOrCreateName(argument) << " = " 277 << emitter.getOrCreateName(operand) << ";\n"; 278 } 279 280 os << "goto "; 281 if (!(emitter.hasBlockLabel(trueSuccessor))) { 282 return condBranchOp.emitOpError("unable to find label for successor block"); 283 } 284 os << emitter.getOrCreateName(trueSuccessor) << ";\n"; 285 os.unindent() << "} else {\n"; 286 os.indent(); 287 // If condition is false. 288 for (auto pair : llvm::zip(condBranchOp.getFalseOperands(), 289 falseSuccessor.getArguments())) { 290 Value &operand = std::get<0>(pair); 291 BlockArgument &argument = std::get<1>(pair); 292 os << emitter.getOrCreateName(argument) << " = " 293 << emitter.getOrCreateName(operand) << ";\n"; 294 } 295 296 os << "goto "; 297 if (!(emitter.hasBlockLabel(falseSuccessor))) { 298 return condBranchOp.emitOpError() 299 << "unable to find label for successor block"; 300 } 301 os << emitter.getOrCreateName(falseSuccessor) << ";\n"; 302 os.unindent() << "}"; 303 return success(); 304 } 305 306 static LogicalResult printOperation(CppEmitter &emitter, mlir::CallOp callOp) { 307 if (failed(emitter.emitAssignPrefix(*callOp.getOperation()))) 308 return failure(); 309 310 raw_ostream &os = emitter.ostream(); 311 os << callOp.getCallee() << "("; 312 if (failed(emitter.emitOperands(*callOp.getOperation()))) 313 return failure(); 314 os << ")"; 315 return success(); 316 } 317 318 static LogicalResult printOperation(CppEmitter &emitter, emitc::CallOp callOp) { 319 raw_ostream &os = emitter.ostream(); 320 Operation &op = *callOp.getOperation(); 321 322 if (failed(emitter.emitAssignPrefix(op))) 323 return failure(); 324 os << callOp.callee(); 325 326 auto emitArgs = [&](Attribute attr) -> LogicalResult { 327 if (auto t = attr.dyn_cast<IntegerAttr>()) { 
328 // Index attributes are treated specially as operand index. 329 if (t.getType().isIndex()) { 330 int64_t idx = t.getInt(); 331 if ((idx < 0) || (idx >= op.getNumOperands())) 332 return op.emitOpError("invalid operand index"); 333 if (!emitter.hasValueInScope(op.getOperand(idx))) 334 return op.emitOpError("operand ") 335 << idx << "'s value not defined in scope"; 336 os << emitter.getOrCreateName(op.getOperand(idx)); 337 return success(); 338 } 339 } 340 if (failed(emitter.emitAttribute(op.getLoc(), attr))) 341 return failure(); 342 343 return success(); 344 }; 345 346 if (callOp.template_args()) { 347 os << "<"; 348 if (failed(interleaveCommaWithError(*callOp.template_args(), os, emitArgs))) 349 return failure(); 350 os << ">"; 351 } 352 353 os << "("; 354 355 LogicalResult emittedArgs = 356 callOp.args() ? interleaveCommaWithError(*callOp.args(), os, emitArgs) 357 : emitter.emitOperands(op); 358 if (failed(emittedArgs)) 359 return failure(); 360 os << ")"; 361 return success(); 362 } 363 364 static LogicalResult printOperation(CppEmitter &emitter, 365 emitc::ApplyOp applyOp) { 366 raw_ostream &os = emitter.ostream(); 367 Operation &op = *applyOp.getOperation(); 368 369 if (failed(emitter.emitAssignPrefix(op))) 370 return failure(); 371 os << applyOp.applicableOperator(); 372 os << emitter.getOrCreateName(applyOp.getOperand()); 373 374 return success(); 375 } 376 377 static LogicalResult printOperation(CppEmitter &emitter, 378 emitc::IncludeOp includeOp) { 379 raw_ostream &os = emitter.ostream(); 380 381 os << "#include "; 382 if (includeOp.is_standard_include()) 383 os << "<" << includeOp.include() << ">"; 384 else 385 os << "\"" << includeOp.include() << "\""; 386 387 return success(); 388 } 389 390 static LogicalResult printOperation(CppEmitter &emitter, scf::ForOp forOp) { 391 392 raw_indented_ostream &os = emitter.ostream(); 393 394 OperandRange operands = forOp.getIterOperands(); 395 Block::BlockArgListType iterArgs = forOp.getRegionIterArgs(); 396 
Operation::result_range results = forOp.getResults(); 397 398 if (!emitter.shouldDeclareVariablesAtTop()) { 399 for (OpResult result : results) { 400 if (failed(emitter.emitVariableDeclaration(result, 401 /*trailingSemicolon=*/true))) 402 return failure(); 403 } 404 } 405 406 for (auto pair : llvm::zip(iterArgs, operands)) { 407 if (failed(emitter.emitType(forOp.getLoc(), std::get<0>(pair).getType()))) 408 return failure(); 409 os << " " << emitter.getOrCreateName(std::get<0>(pair)) << " = "; 410 os << emitter.getOrCreateName(std::get<1>(pair)) << ";"; 411 os << "\n"; 412 } 413 414 os << "for ("; 415 if (failed( 416 emitter.emitType(forOp.getLoc(), forOp.getInductionVar().getType()))) 417 return failure(); 418 os << " "; 419 os << emitter.getOrCreateName(forOp.getInductionVar()); 420 os << " = "; 421 os << emitter.getOrCreateName(forOp.getLowerBound()); 422 os << "; "; 423 os << emitter.getOrCreateName(forOp.getInductionVar()); 424 os << " < "; 425 os << emitter.getOrCreateName(forOp.getUpperBound()); 426 os << "; "; 427 os << emitter.getOrCreateName(forOp.getInductionVar()); 428 os << " += "; 429 os << emitter.getOrCreateName(forOp.getStep()); 430 os << ") {\n"; 431 os.indent(); 432 433 Region &forRegion = forOp.getRegion(); 434 auto regionOps = forRegion.getOps(); 435 436 // We skip the trailing yield op because this updates the result variables 437 // of the for op in the generated code. Instead we update the iterArgs at 438 // the end of a loop iteration and set the result variables after the for 439 // loop. 440 for (auto it = regionOps.begin(); std::next(it) != regionOps.end(); ++it) { 441 if (failed(emitter.emitOperation(*it, /*trailingSemicolon=*/true))) 442 return failure(); 443 } 444 445 Operation *yieldOp = forRegion.getBlocks().front().getTerminator(); 446 // Copy yield operands into iterArgs at the end of a loop iteration. 
447 for (auto pair : llvm::zip(iterArgs, yieldOp->getOperands())) { 448 BlockArgument iterArg = std::get<0>(pair); 449 Value operand = std::get<1>(pair); 450 os << emitter.getOrCreateName(iterArg) << " = " 451 << emitter.getOrCreateName(operand) << ";\n"; 452 } 453 454 os.unindent() << "}"; 455 456 // Copy iterArgs into results after the for loop. 457 for (auto pair : llvm::zip(results, iterArgs)) { 458 OpResult result = std::get<0>(pair); 459 BlockArgument iterArg = std::get<1>(pair); 460 os << "\n" 461 << emitter.getOrCreateName(result) << " = " 462 << emitter.getOrCreateName(iterArg) << ";"; 463 } 464 465 return success(); 466 } 467 468 static LogicalResult printOperation(CppEmitter &emitter, scf::IfOp ifOp) { 469 raw_indented_ostream &os = emitter.ostream(); 470 471 if (!emitter.shouldDeclareVariablesAtTop()) { 472 for (OpResult result : ifOp.getResults()) { 473 if (failed(emitter.emitVariableDeclaration(result, 474 /*trailingSemicolon=*/true))) 475 return failure(); 476 } 477 } 478 479 os << "if ("; 480 if (failed(emitter.emitOperands(*ifOp.getOperation()))) 481 return failure(); 482 os << ") {\n"; 483 os.indent(); 484 485 Region &thenRegion = ifOp.getThenRegion(); 486 for (Operation &op : thenRegion.getOps()) { 487 // Note: This prints a superfluous semicolon if the terminating yield op has 488 // zero results. 489 if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/true))) 490 return failure(); 491 } 492 493 os.unindent() << "}"; 494 495 Region &elseRegion = ifOp.getElseRegion(); 496 if (!elseRegion.empty()) { 497 os << " else {\n"; 498 os.indent(); 499 500 for (Operation &op : elseRegion.getOps()) { 501 // Note: This prints a superfluous semicolon if the terminating yield op 502 // has zero results. 
503 if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/true))) 504 return failure(); 505 } 506 507 os.unindent() << "}"; 508 } 509 510 return success(); 511 } 512 513 static LogicalResult printOperation(CppEmitter &emitter, scf::YieldOp yieldOp) { 514 raw_ostream &os = emitter.ostream(); 515 Operation &parentOp = *yieldOp.getOperation()->getParentOp(); 516 517 if (yieldOp.getNumOperands() != parentOp.getNumResults()) { 518 return yieldOp.emitError("number of operands does not to match the number " 519 "of the parent op's results"); 520 } 521 522 if (failed(interleaveWithError( 523 llvm::zip(parentOp.getResults(), yieldOp.getOperands()), 524 [&](auto pair) -> LogicalResult { 525 auto result = std::get<0>(pair); 526 auto operand = std::get<1>(pair); 527 os << emitter.getOrCreateName(result) << " = "; 528 529 if (!emitter.hasValueInScope(operand)) 530 return yieldOp.emitError("operand value not in scope"); 531 os << emitter.getOrCreateName(operand); 532 return success(); 533 }, 534 [&]() { os << ";\n"; }))) 535 return failure(); 536 537 return success(); 538 } 539 540 static LogicalResult printOperation(CppEmitter &emitter, ReturnOp returnOp) { 541 raw_ostream &os = emitter.ostream(); 542 os << "return"; 543 switch (returnOp.getNumOperands()) { 544 case 0: 545 return success(); 546 case 1: 547 os << " " << emitter.getOrCreateName(returnOp.getOperand(0)); 548 return success(emitter.hasValueInScope(returnOp.getOperand(0))); 549 default: 550 os << " std::make_tuple("; 551 if (failed(emitter.emitOperandsAndAttributes(*returnOp.getOperation()))) 552 return failure(); 553 os << ")"; 554 return success(); 555 } 556 } 557 558 static LogicalResult printOperation(CppEmitter &emitter, ModuleOp moduleOp) { 559 CppEmitter::Scope scope(emitter); 560 561 for (Operation &op : moduleOp) { 562 if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/false))) 563 return failure(); 564 } 565 return success(); 566 } 567 568 static LogicalResult printOperation(CppEmitter 
&emitter, FuncOp functionOp) { 569 // We need to declare variables at top if the function has multiple blocks. 570 if (!emitter.shouldDeclareVariablesAtTop() && 571 functionOp.getBlocks().size() > 1) { 572 return functionOp.emitOpError( 573 "with multiple blocks needs variables declared at top"); 574 } 575 576 CppEmitter::Scope scope(emitter); 577 raw_indented_ostream &os = emitter.ostream(); 578 if (failed(emitter.emitTypes(functionOp.getLoc(), 579 functionOp.getType().getResults()))) 580 return failure(); 581 os << " " << functionOp.getName(); 582 583 os << "("; 584 if (failed(interleaveCommaWithError( 585 functionOp.getArguments(), os, 586 [&](BlockArgument arg) -> LogicalResult { 587 if (failed(emitter.emitType(functionOp.getLoc(), arg.getType()))) 588 return failure(); 589 os << " " << emitter.getOrCreateName(arg); 590 return success(); 591 }))) 592 return failure(); 593 os << ") {\n"; 594 os.indent(); 595 if (emitter.shouldDeclareVariablesAtTop()) { 596 // Declare all variables that hold op results including those from nested 597 // regions. 598 WalkResult result = 599 functionOp.walk<WalkOrder::PreOrder>([&](Operation *op) -> WalkResult { 600 for (OpResult result : op->getResults()) { 601 if (failed(emitter.emitVariableDeclaration( 602 result, /*trailingSemicolon=*/true))) { 603 return WalkResult( 604 op->emitError("unable to declare result variable for op")); 605 } 606 } 607 return WalkResult::advance(); 608 }); 609 if (result.wasInterrupted()) 610 return failure(); 611 } 612 613 Region::BlockListType &blocks = functionOp.getBlocks(); 614 // Create label names for basic blocks. 615 for (Block &block : blocks) { 616 emitter.getOrCreateName(block); 617 } 618 619 // Declare variables for basic block arguments. 
620 for (auto it = std::next(blocks.begin()); it != blocks.end(); ++it) { 621 Block &block = *it; 622 for (BlockArgument &arg : block.getArguments()) { 623 if (emitter.hasValueInScope(arg)) 624 return functionOp.emitOpError(" block argument #") 625 << arg.getArgNumber() << " is out of scope"; 626 if (failed( 627 emitter.emitType(block.getParentOp()->getLoc(), arg.getType()))) { 628 return failure(); 629 } 630 os << " " << emitter.getOrCreateName(arg) << ";\n"; 631 } 632 } 633 634 for (Block &block : blocks) { 635 // Only print a label if the block has predecessors. 636 if (!block.hasNoPredecessors()) { 637 if (failed(emitter.emitLabel(block))) 638 return failure(); 639 } 640 for (Operation &op : block.getOperations()) { 641 // When generating code for an scf.if or cf.cond_br op no semicolon needs 642 // to be printed after the closing brace. 643 // When generating code for an scf.for op, printing a trailing semicolon 644 // is handled within the printOperation function. 645 bool trailingSemicolon = 646 !isa<scf::IfOp, scf::ForOp, cf::CondBranchOp>(op); 647 648 if (failed(emitter.emitOperation( 649 op, /*trailingSemicolon=*/trailingSemicolon))) 650 return failure(); 651 } 652 } 653 os.unindent() << "}\n"; 654 return success(); 655 } 656 657 CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop) 658 : os(os), declareVariablesAtTop(declareVariablesAtTop) { 659 valueInScopeCount.push(0); 660 labelInScopeCount.push(0); 661 } 662 663 /// Return the existing or a new name for a Value. 664 StringRef CppEmitter::getOrCreateName(Value val) { 665 if (!valueMapper.count(val)) 666 valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top())); 667 return *valueMapper.begin(val); 668 } 669 670 /// Return the existing or a new label for a Block. 
671 StringRef CppEmitter::getOrCreateName(Block &block) { 672 if (!blockMapper.count(&block)) 673 blockMapper.insert(&block, formatv("label{0}", ++labelInScopeCount.top())); 674 return *blockMapper.begin(&block); 675 } 676 677 bool CppEmitter::shouldMapToUnsigned(IntegerType::SignednessSemantics val) { 678 switch (val) { 679 case IntegerType::Signless: 680 return false; 681 case IntegerType::Signed: 682 return false; 683 case IntegerType::Unsigned: 684 return true; 685 } 686 llvm_unreachable("Unexpected IntegerType::SignednessSemantics"); 687 } 688 689 bool CppEmitter::hasValueInScope(Value val) { return valueMapper.count(val); } 690 691 bool CppEmitter::hasBlockLabel(Block &block) { 692 return blockMapper.count(&block); 693 } 694 695 LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) { 696 auto printInt = [&](const APInt &val, bool isUnsigned) { 697 if (val.getBitWidth() == 1) { 698 if (val.getBoolValue()) 699 os << "true"; 700 else 701 os << "false"; 702 } else { 703 SmallString<128> strValue; 704 val.toString(strValue, 10, !isUnsigned, false); 705 os << strValue; 706 } 707 }; 708 709 auto printFloat = [&](const APFloat &val) { 710 if (val.isFinite()) { 711 SmallString<128> strValue; 712 // Use default values of toString except don't truncate zeros. 713 val.toString(strValue, 0, 0, false); 714 switch (llvm::APFloatBase::SemanticsToEnum(val.getSemantics())) { 715 case llvm::APFloatBase::S_IEEEsingle: 716 os << "(float)"; 717 break; 718 case llvm::APFloatBase::S_IEEEdouble: 719 os << "(double)"; 720 break; 721 default: 722 break; 723 }; 724 os << strValue; 725 } else if (val.isNaN()) { 726 os << "NAN"; 727 } else if (val.isInfinity()) { 728 if (val.isNegative()) 729 os << "-"; 730 os << "INFINITY"; 731 } 732 }; 733 734 // Print floating point attributes. 
735 if (auto fAttr = attr.dyn_cast<FloatAttr>()) { 736 printFloat(fAttr.getValue()); 737 return success(); 738 } 739 if (auto dense = attr.dyn_cast<DenseFPElementsAttr>()) { 740 os << '{'; 741 interleaveComma(dense, os, [&](const APFloat &val) { printFloat(val); }); 742 os << '}'; 743 return success(); 744 } 745 746 // Print integer attributes. 747 if (auto iAttr = attr.dyn_cast<IntegerAttr>()) { 748 if (auto iType = iAttr.getType().dyn_cast<IntegerType>()) { 749 printInt(iAttr.getValue(), shouldMapToUnsigned(iType.getSignedness())); 750 return success(); 751 } 752 if (auto iType = iAttr.getType().dyn_cast<IndexType>()) { 753 printInt(iAttr.getValue(), false); 754 return success(); 755 } 756 } 757 if (auto dense = attr.dyn_cast<DenseIntElementsAttr>()) { 758 if (auto iType = dense.getType() 759 .cast<TensorType>() 760 .getElementType() 761 .dyn_cast<IntegerType>()) { 762 os << '{'; 763 interleaveComma(dense, os, [&](const APInt &val) { 764 printInt(val, shouldMapToUnsigned(iType.getSignedness())); 765 }); 766 os << '}'; 767 return success(); 768 } 769 if (auto iType = dense.getType() 770 .cast<TensorType>() 771 .getElementType() 772 .dyn_cast<IndexType>()) { 773 os << '{'; 774 interleaveComma(dense, os, 775 [&](const APInt &val) { printInt(val, false); }); 776 os << '}'; 777 return success(); 778 } 779 } 780 781 // Print opaque attributes. 782 if (auto oAttr = attr.dyn_cast<emitc::OpaqueAttr>()) { 783 os << oAttr.getValue(); 784 return success(); 785 } 786 787 // Print symbolic reference attributes. 788 if (auto sAttr = attr.dyn_cast<SymbolRefAttr>()) { 789 if (sAttr.getNestedReferences().size() > 1) 790 return emitError(loc, "attribute has more than 1 nested reference"); 791 os << sAttr.getRootReference().getValue(); 792 return success(); 793 } 794 795 // Print type attributes. 
796 if (auto type = attr.dyn_cast<TypeAttr>()) 797 return emitType(loc, type.getValue()); 798 799 return emitError(loc, "cannot emit attribute of type ") << attr.getType(); 800 } 801 802 LogicalResult CppEmitter::emitOperands(Operation &op) { 803 auto emitOperandName = [&](Value result) -> LogicalResult { 804 if (!hasValueInScope(result)) 805 return op.emitOpError() << "operand value not in scope"; 806 os << getOrCreateName(result); 807 return success(); 808 }; 809 return interleaveCommaWithError(op.getOperands(), os, emitOperandName); 810 } 811 812 LogicalResult 813 CppEmitter::emitOperandsAndAttributes(Operation &op, 814 ArrayRef<StringRef> exclude) { 815 if (failed(emitOperands(op))) 816 return failure(); 817 // Insert comma in between operands and non-filtered attributes if needed. 818 if (op.getNumOperands() > 0) { 819 for (NamedAttribute attr : op.getAttrs()) { 820 if (!llvm::is_contained(exclude, attr.getName().strref())) { 821 os << ", "; 822 break; 823 } 824 } 825 } 826 // Emit attributes. 
827 auto emitNamedAttribute = [&](NamedAttribute attr) -> LogicalResult { 828 if (llvm::is_contained(exclude, attr.getName().strref())) 829 return success(); 830 os << "/* " << attr.getName().getValue() << " */"; 831 if (failed(emitAttribute(op.getLoc(), attr.getValue()))) 832 return failure(); 833 return success(); 834 }; 835 return interleaveCommaWithError(op.getAttrs(), os, emitNamedAttribute); 836 } 837 838 LogicalResult CppEmitter::emitVariableAssignment(OpResult result) { 839 if (!hasValueInScope(result)) { 840 return result.getDefiningOp()->emitOpError( 841 "result variable for the operation has not been declared"); 842 } 843 os << getOrCreateName(result) << " = "; 844 return success(); 845 } 846 847 LogicalResult CppEmitter::emitVariableDeclaration(OpResult result, 848 bool trailingSemicolon) { 849 if (hasValueInScope(result)) { 850 return result.getDefiningOp()->emitError( 851 "result variable for the operation already declared"); 852 } 853 if (failed(emitType(result.getOwner()->getLoc(), result.getType()))) 854 return failure(); 855 os << " " << getOrCreateName(result); 856 if (trailingSemicolon) 857 os << ";\n"; 858 return success(); 859 } 860 861 LogicalResult CppEmitter::emitAssignPrefix(Operation &op) { 862 switch (op.getNumResults()) { 863 case 0: 864 break; 865 case 1: { 866 OpResult result = op.getResult(0); 867 if (shouldDeclareVariablesAtTop()) { 868 if (failed(emitVariableAssignment(result))) 869 return failure(); 870 } else { 871 if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/false))) 872 return failure(); 873 os << " = "; 874 } 875 break; 876 } 877 default: 878 if (!shouldDeclareVariablesAtTop()) { 879 for (OpResult result : op.getResults()) { 880 if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/true))) 881 return failure(); 882 } 883 } 884 os << "std::tie("; 885 interleaveComma(op.getResults(), os, 886 [&](Value result) { os << getOrCreateName(result); }); 887 os << ") = "; 888 } 889 return success(); 890 } 
891 892 LogicalResult CppEmitter::emitLabel(Block &block) { 893 if (!hasBlockLabel(block)) 894 return block.getParentOp()->emitError("label for block not found"); 895 // FIXME: Add feature in `raw_indented_ostream` to ignore indent for block 896 // label instead of using `getOStream`. 897 os.getOStream() << getOrCreateName(block) << ":\n"; 898 return success(); 899 } 900 901 LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { 902 LogicalResult status = 903 llvm::TypeSwitch<Operation *, LogicalResult>(&op) 904 // EmitC ops. 905 .Case<emitc::ApplyOp, emitc::CallOp, emitc::ConstantOp, 906 emitc::IncludeOp>( 907 [&](auto op) { return printOperation(*this, op); }) 908 // SCF ops. 909 .Case<scf::ForOp, scf::IfOp, scf::YieldOp>( 910 [&](auto op) { return printOperation(*this, op); }) 911 // Standard ops. 912 .Case<cf::BranchOp, mlir::CallOp, cf::CondBranchOp, mlir::ConstantOp, 913 FuncOp, ModuleOp, ReturnOp>( 914 [&](auto op) { return printOperation(*this, op); }) 915 // Arithmetic ops. 916 .Case<arith::ConstantOp>( 917 [&](auto op) { return printOperation(*this, op); }) 918 .Default([&](Operation *) { 919 return op.emitOpError("unable to find printer for op"); 920 }); 921 922 if (failed(status)) 923 return failure(); 924 os << (trailingSemicolon ? 
";\n" : "\n"); 925 return success(); 926 } 927 928 LogicalResult CppEmitter::emitType(Location loc, Type type) { 929 if (auto iType = type.dyn_cast<IntegerType>()) { 930 switch (iType.getWidth()) { 931 case 1: 932 return (os << "bool"), success(); 933 case 8: 934 case 16: 935 case 32: 936 case 64: 937 if (shouldMapToUnsigned(iType.getSignedness())) 938 return (os << "uint" << iType.getWidth() << "_t"), success(); 939 else 940 return (os << "int" << iType.getWidth() << "_t"), success(); 941 default: 942 return emitError(loc, "cannot emit integer type ") << type; 943 } 944 } 945 if (auto fType = type.dyn_cast<FloatType>()) { 946 switch (fType.getWidth()) { 947 case 32: 948 return (os << "float"), success(); 949 case 64: 950 return (os << "double"), success(); 951 default: 952 return emitError(loc, "cannot emit float type ") << type; 953 } 954 } 955 if (auto iType = type.dyn_cast<IndexType>()) 956 return (os << "size_t"), success(); 957 if (auto tType = type.dyn_cast<TensorType>()) { 958 if (!tType.hasRank()) 959 return emitError(loc, "cannot emit unranked tensor type"); 960 if (!tType.hasStaticShape()) 961 return emitError(loc, "cannot emit tensor type with non static shape"); 962 os << "Tensor<"; 963 if (failed(emitType(loc, tType.getElementType()))) 964 return failure(); 965 auto shape = tType.getShape(); 966 for (auto dimSize : shape) { 967 os << ", "; 968 os << dimSize; 969 } 970 os << ">"; 971 return success(); 972 } 973 if (auto tType = type.dyn_cast<TupleType>()) 974 return emitTupleType(loc, tType.getTypes()); 975 if (auto oType = type.dyn_cast<emitc::OpaqueType>()) { 976 os << oType.getValue(); 977 return success(); 978 } 979 return emitError(loc, "cannot emit type ") << type; 980 } 981 982 LogicalResult CppEmitter::emitTypes(Location loc, ArrayRef<Type> types) { 983 switch (types.size()) { 984 case 0: 985 os << "void"; 986 return success(); 987 case 1: 988 return emitType(loc, types.front()); 989 default: 990 return emitTupleType(loc, types); 991 } 992 } 
993 994 LogicalResult CppEmitter::emitTupleType(Location loc, ArrayRef<Type> types) { 995 os << "std::tuple<"; 996 if (failed(interleaveCommaWithError( 997 types, os, [&](Type type) { return emitType(loc, type); }))) 998 return failure(); 999 os << ">"; 1000 return success(); 1001 } 1002 1003 LogicalResult emitc::translateToCpp(Operation *op, raw_ostream &os, 1004 bool declareVariablesAtTop) { 1005 CppEmitter emitter(os, declareVariablesAtTop); 1006 return emitter.emitOperation(*op, /*trailingSemicolon=*/false); 1007 } 1008