1 //===- TranslateToCpp.cpp - Translating to C++ calls ----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <utility> 10 11 #include "mlir/Dialect/EmitC/IR/EmitC.h" 12 #include "mlir/Dialect/SCF/SCF.h" 13 #include "mlir/Dialect/StandardOps/IR/Ops.h" 14 #include "mlir/IR/BuiltinOps.h" 15 #include "mlir/IR/BuiltinTypes.h" 16 #include "mlir/IR/Dialect.h" 17 #include "mlir/IR/Operation.h" 18 #include "mlir/Support/IndentedOstream.h" 19 #include "mlir/Target/Cpp/CppEmitter.h" 20 #include "llvm/ADT/DenseMap.h" 21 #include "llvm/ADT/StringExtras.h" 22 #include "llvm/ADT/StringMap.h" 23 #include "llvm/ADT/TypeSwitch.h" 24 #include "llvm/Support/Debug.h" 25 #include "llvm/Support/FormatVariadic.h" 26 27 #define DEBUG_TYPE "translate-to-cpp" 28 29 using namespace mlir; 30 using namespace mlir::emitc; 31 using llvm::formatv; 32 33 /// Convenience functions to produce interleaved output with functions returning 34 /// a LogicalResult. This is different than those in STLExtras as functions used 35 /// on each element doesn't return a string. 36 template <typename ForwardIterator, typename UnaryFunctor, 37 typename NullaryFunctor> 38 inline LogicalResult 39 interleaveWithError(ForwardIterator begin, ForwardIterator end, 40 UnaryFunctor eachFn, NullaryFunctor betweenFn) { 41 if (begin == end) 42 return success(); 43 if (failed(eachFn(*begin))) 44 return failure(); 45 ++begin; 46 for (; begin != end; ++begin) { 47 betweenFn(); 48 if (failed(eachFn(*begin))) 49 return failure(); 50 } 51 return success(); 52 } 53 54 template <typename Container, typename UnaryFunctor, typename NullaryFunctor> 55 inline LogicalResult interleaveWithError(const Container &c, 56 UnaryFunctor eachFn, 57 NullaryFunctor betweenFn) { 58 return interleaveWithError(c.begin(), c.end(), eachFn, betweenFn); 59 } 60 61 template <typename Container, typename UnaryFunctor> 62 inline LogicalResult interleaveCommaWithError(const Container &c, 63 raw_ostream &os, 64 UnaryFunctor eachFn) { 65 return interleaveWithError(c.begin(), c.end(), eachFn, [&]() { os << ", "; }); 66 } 67 68 namespace { 69 /// Emitter that uses dialect specific emitters to emit C++ code. 70 struct CppEmitter { 71 explicit CppEmitter(raw_ostream &os, bool declareVariablesAtTop); 72 73 /// Emits attribute or returns failure. 74 LogicalResult emitAttribute(Location loc, Attribute attr); 75 76 /// Emits operation 'op' with/without training semicolon or returns failure. 77 LogicalResult emitOperation(Operation &op, bool trailingSemicolon); 78 79 /// Emits type 'type' or returns failure. 80 LogicalResult emitType(Location loc, Type type); 81 82 /// Emits array of types as a std::tuple of the emitted types. 83 /// - emits void for an empty array; 84 /// - emits the type of the only element for arrays of size one; 85 /// - emits a std::tuple otherwise; 86 LogicalResult emitTypes(Location loc, ArrayRef<Type> types); 87 88 /// Emits array of types as a std::tuple of the emitted types independently of 89 /// the array size. 90 LogicalResult emitTupleType(Location loc, ArrayRef<Type> types); 91 92 /// Emits an assignment for a variable which has been declared previously. 93 LogicalResult emitVariableAssignment(OpResult result); 94 95 /// Emits a variable declaration for a result of an operation. 96 LogicalResult emitVariableDeclaration(OpResult result, 97 bool trailingSemicolon); 98 99 /// Emits the variable declaration and assignment prefix for 'op'. 100 /// - emits separate variable followed by std::tie for multi-valued operation; 101 /// - emits single type followed by variable for single result; 102 /// - emits nothing if no value produced by op; 103 /// Emits final '=' operator where a type is produced. Returns failure if 104 /// any result type could not be converted. 105 LogicalResult emitAssignPrefix(Operation &op); 106 107 /// Emits a label for the block. 108 LogicalResult emitLabel(Block &block); 109 110 /// Emits the operands and atttributes of the operation. All operands are 111 /// emitted first and then all attributes in alphabetical order. 112 LogicalResult emitOperandsAndAttributes(Operation &op, 113 ArrayRef<StringRef> exclude = {}); 114 115 /// Emits the operands of the operation. All operands are emitted in order. 116 LogicalResult emitOperands(Operation &op); 117 118 /// Return the existing or a new name for a Value. 119 StringRef getOrCreateName(Value val); 120 121 /// Return the existing or a new label of a Block. 122 StringRef getOrCreateName(Block &block); 123 124 /// Whether to map an mlir integer to a unsigned integer in C++. 125 bool shouldMapToUnsigned(IntegerType::SignednessSemantics val); 126 127 /// RAII helper function to manage entering/exiting C++ scopes. 128 struct Scope { 129 Scope(CppEmitter &emitter) 130 : valueMapperScope(emitter.valueMapper), 131 blockMapperScope(emitter.blockMapper), emitter(emitter) { 132 emitter.valueInScopeCount.push(emitter.valueInScopeCount.top()); 133 emitter.labelInScopeCount.push(emitter.labelInScopeCount.top()); 134 } 135 ~Scope() { 136 emitter.valueInScopeCount.pop(); 137 emitter.labelInScopeCount.pop(); 138 } 139 140 private: 141 llvm::ScopedHashTableScope<Value, std::string> valueMapperScope; 142 llvm::ScopedHashTableScope<Block *, std::string> blockMapperScope; 143 CppEmitter &emitter; 144 }; 145 146 /// Returns wether the Value is assigned to a C++ variable in the scope. 147 bool hasValueInScope(Value val); 148 149 // Returns whether a label is assigned to the block. 150 bool hasBlockLabel(Block &block); 151 152 /// Returns the output stream. 153 raw_indented_ostream &ostream() { return os; }; 154 155 /// Returns if all variables for op results and basic block arguments need to 156 /// be declared at the beginning of a function. 157 bool shouldDeclareVariablesAtTop() { return declareVariablesAtTop; }; 158 159 private: 160 using ValueMapper = llvm::ScopedHashTable<Value, std::string>; 161 using BlockMapper = llvm::ScopedHashTable<Block *, std::string>; 162 163 /// Output stream to emit to. 164 raw_indented_ostream os; 165 166 /// Boolean to enforce that all variables for op results and block 167 /// arguments are declared at the beginning of the function. This also 168 /// includes results from ops located in nested regions. 169 bool declareVariablesAtTop; 170 171 /// Map from value to name of C++ variable that contain the name. 172 ValueMapper valueMapper; 173 174 /// Map from block to name of C++ label. 175 BlockMapper blockMapper; 176 177 /// The number of values in the current scope. This is used to declare the 178 /// names of values in a scope. 179 std::stack<int64_t> valueInScopeCount; 180 std::stack<int64_t> labelInScopeCount; 181 }; 182 } // namespace 183 184 static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation, 185 Attribute value) { 186 OpResult result = operation->getResult(0); 187 188 // Only emit an assignment as the variable was already declared when printing 189 // the FuncOp. 190 if (emitter.shouldDeclareVariablesAtTop()) { 191 // Skip the assignment if the emitc.constant has no value. 192 if (auto oAttr = value.dyn_cast<emitc::OpaqueAttr>()) { 193 if (oAttr.getValue().empty()) 194 return success(); 195 } 196 197 if (failed(emitter.emitVariableAssignment(result))) 198 return failure(); 199 return emitter.emitAttribute(operation->getLoc(), value); 200 } 201 202 // Emit a variable declaration for an emitc.constant op without value. 203 if (auto oAttr = value.dyn_cast<emitc::OpaqueAttr>()) { 204 if (oAttr.getValue().empty()) 205 // The semicolon gets printed by the emitOperation function. 206 return emitter.emitVariableDeclaration(result, 207 /*trailingSemicolon=*/false); 208 } 209 210 // Emit a variable declaration. 211 if (failed(emitter.emitAssignPrefix(*operation))) 212 return failure(); 213 return emitter.emitAttribute(operation->getLoc(), value); 214 } 215 216 static LogicalResult printOperation(CppEmitter &emitter, 217 emitc::ConstantOp constantOp) { 218 Operation *operation = constantOp.getOperation(); 219 Attribute value = constantOp.value(); 220 221 return printConstantOp(emitter, operation, value); 222 } 223 224 static LogicalResult printOperation(CppEmitter &emitter, 225 arith::ConstantOp constantOp) { 226 Operation *operation = constantOp.getOperation(); 227 Attribute value = constantOp.getValue(); 228 229 return printConstantOp(emitter, operation, value); 230 } 231 232 static LogicalResult printOperation(CppEmitter &emitter, 233 mlir::ConstantOp constantOp) { 234 Operation *operation = constantOp.getOperation(); 235 Attribute value = constantOp.getValue(); 236 237 return printConstantOp(emitter, operation, value); 238 } 239 240 static LogicalResult printOperation(CppEmitter &emitter, BranchOp branchOp) { 241 raw_ostream &os = emitter.ostream(); 242 Block &successor = *branchOp.getSuccessor(); 243 244 for (auto pair : 245 llvm::zip(branchOp.getOperands(), successor.getArguments())) { 246 Value &operand = std::get<0>(pair); 247 BlockArgument &argument = std::get<1>(pair); 248 os << emitter.getOrCreateName(argument) << " = " 249 << emitter.getOrCreateName(operand) << ";\n"; 250 } 251 252 os << "goto "; 253 if (!(emitter.hasBlockLabel(successor))) 254 return branchOp.emitOpError("unable to find label for successor block"); 255 os << emitter.getOrCreateName(successor); 256 return success(); 257 } 258 259 static LogicalResult printOperation(CppEmitter &emitter, 260 CondBranchOp condBranchOp) { 261 raw_indented_ostream &os = emitter.ostream(); 262 Block &trueSuccessor = *condBranchOp.getTrueDest(); 263 Block &falseSuccessor = *condBranchOp.getFalseDest(); 264 265 os << "if (" << emitter.getOrCreateName(condBranchOp.getCondition()) 266 << ") {\n"; 267 268 os.indent(); 269 270 // If condition is true. 271 for (auto pair : llvm::zip(condBranchOp.getTrueOperands(), 272 trueSuccessor.getArguments())) { 273 Value &operand = std::get<0>(pair); 274 BlockArgument &argument = std::get<1>(pair); 275 os << emitter.getOrCreateName(argument) << " = " 276 << emitter.getOrCreateName(operand) << ";\n"; 277 } 278 279 os << "goto "; 280 if (!(emitter.hasBlockLabel(trueSuccessor))) { 281 return condBranchOp.emitOpError("unable to find label for successor block"); 282 } 283 os << emitter.getOrCreateName(trueSuccessor) << ";\n"; 284 os.unindent() << "} else {\n"; 285 os.indent(); 286 // If condition is false. 287 for (auto pair : llvm::zip(condBranchOp.getFalseOperands(), 288 falseSuccessor.getArguments())) { 289 Value &operand = std::get<0>(pair); 290 BlockArgument &argument = std::get<1>(pair); 291 os << emitter.getOrCreateName(argument) << " = " 292 << emitter.getOrCreateName(operand) << ";\n"; 293 } 294 295 os << "goto "; 296 if (!(emitter.hasBlockLabel(falseSuccessor))) { 297 return condBranchOp.emitOpError() 298 << "unable to find label for successor block"; 299 } 300 os << emitter.getOrCreateName(falseSuccessor) << ";\n"; 301 os.unindent() << "}"; 302 return success(); 303 } 304 305 static LogicalResult printOperation(CppEmitter &emitter, mlir::CallOp callOp) { 306 if (failed(emitter.emitAssignPrefix(*callOp.getOperation()))) 307 return failure(); 308 309 raw_ostream &os = emitter.ostream(); 310 os << callOp.getCallee() << "("; 311 if (failed(emitter.emitOperands(*callOp.getOperation()))) 312 return failure(); 313 os << ")"; 314 return success(); 315 } 316 317 static LogicalResult printOperation(CppEmitter &emitter, emitc::CallOp callOp) { 318 raw_ostream &os = emitter.ostream(); 319 Operation &op = *callOp.getOperation(); 320 321 if (failed(emitter.emitAssignPrefix(op))) 322 return failure(); 323 os << callOp.callee(); 324 325 auto emitArgs = [&](Attribute attr) -> LogicalResult { 326 if (auto t = attr.dyn_cast<IntegerAttr>()) { 327 // Index attributes are treated specially as operand index. 328 if (t.getType().isIndex()) { 329 int64_t idx = t.getInt(); 330 if ((idx < 0) || (idx >= op.getNumOperands())) 331 return op.emitOpError("invalid operand index"); 332 if (!emitter.hasValueInScope(op.getOperand(idx))) 333 return op.emitOpError("operand ") 334 << idx << "'s value not defined in scope"; 335 os << emitter.getOrCreateName(op.getOperand(idx)); 336 return success(); 337 } 338 } 339 if (failed(emitter.emitAttribute(op.getLoc(), attr))) 340 return failure(); 341 342 return success(); 343 }; 344 345 if (callOp.template_args()) { 346 os << "<"; 347 if (failed(interleaveCommaWithError(*callOp.template_args(), os, emitArgs))) 348 return failure(); 349 os << ">"; 350 } 351 352 os << "("; 353 354 LogicalResult emittedArgs = 355 callOp.args() ? interleaveCommaWithError(*callOp.args(), os, emitArgs) 356 : emitter.emitOperands(op); 357 if (failed(emittedArgs)) 358 return failure(); 359 os << ")"; 360 return success(); 361 } 362 363 static LogicalResult printOperation(CppEmitter &emitter, 364 emitc::ApplyOp applyOp) { 365 raw_ostream &os = emitter.ostream(); 366 Operation &op = *applyOp.getOperation(); 367 368 if (failed(emitter.emitAssignPrefix(op))) 369 return failure(); 370 os << applyOp.applicableOperator(); 371 os << emitter.getOrCreateName(applyOp.getOperand()); 372 373 return success(); 374 } 375 376 static LogicalResult printOperation(CppEmitter &emitter, 377 emitc::IncludeOp includeOp) { 378 raw_ostream &os = emitter.ostream(); 379 380 os << "#include "; 381 if (includeOp.is_standard_include()) 382 os << "<" << includeOp.include() << ">"; 383 else 384 os << "\"" << includeOp.include() << "\""; 385 386 return success(); 387 } 388 389 static LogicalResult printOperation(CppEmitter &emitter, scf::ForOp forOp) { 390 391 raw_indented_ostream &os = emitter.ostream(); 392 393 OperandRange operands = forOp.getIterOperands(); 394 Block::BlockArgListType iterArgs = forOp.getRegionIterArgs(); 395 Operation::result_range results = forOp.getResults(); 396 397 if (!emitter.shouldDeclareVariablesAtTop()) { 398 for (OpResult result : results) { 399 if (failed(emitter.emitVariableDeclaration(result, 400 /*trailingSemicolon=*/true))) 401 return failure(); 402 } 403 } 404 405 for (auto pair : llvm::zip(iterArgs, operands)) { 406 if (failed(emitter.emitType(forOp.getLoc(), std::get<0>(pair).getType()))) 407 return failure(); 408 os << " " << emitter.getOrCreateName(std::get<0>(pair)) << " = "; 409 os << emitter.getOrCreateName(std::get<1>(pair)) << ";"; 410 os << "\n"; 411 } 412 413 os << "for ("; 414 if (failed( 415 emitter.emitType(forOp.getLoc(), forOp.getInductionVar().getType()))) 416 return failure(); 417 os << " "; 418 os << emitter.getOrCreateName(forOp.getInductionVar()); 419 os << " = "; 420 os << emitter.getOrCreateName(forOp.getLowerBound()); 421 os << "; "; 422 os << emitter.getOrCreateName(forOp.getInductionVar()); 423 os << " < "; 424 os << emitter.getOrCreateName(forOp.getUpperBound()); 425 os << "; "; 426 os << emitter.getOrCreateName(forOp.getInductionVar()); 427 os << " += "; 428 os << emitter.getOrCreateName(forOp.getStep()); 429 os << ") {\n"; 430 os.indent(); 431 432 Region &forRegion = forOp.getRegion(); 433 auto regionOps = forRegion.getOps(); 434 435 // We skip the trailing yield op because this updates the result variables 436 // of the for op in the generated code. Instead we update the iterArgs at 437 // the end of a loop iteration and set the result variables after the for 438 // loop. 439 for (auto it = regionOps.begin(); std::next(it) != regionOps.end(); ++it) { 440 if (failed(emitter.emitOperation(*it, /*trailingSemicolon=*/true))) 441 return failure(); 442 } 443 444 Operation *yieldOp = forRegion.getBlocks().front().getTerminator(); 445 // Copy yield operands into iterArgs at the end of a loop iteration. 446 for (auto pair : llvm::zip(iterArgs, yieldOp->getOperands())) { 447 BlockArgument iterArg = std::get<0>(pair); 448 Value operand = std::get<1>(pair); 449 os << emitter.getOrCreateName(iterArg) << " = " 450 << emitter.getOrCreateName(operand) << ";\n"; 451 } 452 453 os.unindent() << "}"; 454 455 // Copy iterArgs into results after the for loop. 456 for (auto pair : llvm::zip(results, iterArgs)) { 457 OpResult result = std::get<0>(pair); 458 BlockArgument iterArg = std::get<1>(pair); 459 os << "\n" 460 << emitter.getOrCreateName(result) << " = " 461 << emitter.getOrCreateName(iterArg) << ";"; 462 } 463 464 return success(); 465 } 466 467 static LogicalResult printOperation(CppEmitter &emitter, scf::IfOp ifOp) { 468 raw_indented_ostream &os = emitter.ostream(); 469 470 if (!emitter.shouldDeclareVariablesAtTop()) { 471 for (OpResult result : ifOp.getResults()) { 472 if (failed(emitter.emitVariableDeclaration(result, 473 /*trailingSemicolon=*/true))) 474 return failure(); 475 } 476 } 477 478 os << "if ("; 479 if (failed(emitter.emitOperands(*ifOp.getOperation()))) 480 return failure(); 481 os << ") {\n"; 482 os.indent(); 483 484 Region &thenRegion = ifOp.getThenRegion(); 485 for (Operation &op : thenRegion.getOps()) { 486 // Note: This prints a superfluous semicolon if the terminating yield op has 487 // zero results. 488 if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/true))) 489 return failure(); 490 } 491 492 os.unindent() << "}"; 493 494 Region &elseRegion = ifOp.getElseRegion(); 495 if (!elseRegion.empty()) { 496 os << " else {\n"; 497 os.indent(); 498 499 for (Operation &op : elseRegion.getOps()) { 500 // Note: This prints a superfluous semicolon if the terminating yield op 501 // has zero results. 502 if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/true))) 503 return failure(); 504 } 505 506 os.unindent() << "}"; 507 } 508 509 return success(); 510 } 511 512 static LogicalResult printOperation(CppEmitter &emitter, scf::YieldOp yieldOp) { 513 raw_ostream &os = emitter.ostream(); 514 Operation &parentOp = *yieldOp.getOperation()->getParentOp(); 515 516 if (yieldOp.getNumOperands() != parentOp.getNumResults()) { 517 return yieldOp.emitError("number of operands does not to match the number " 518 "of the parent op's results"); 519 } 520 521 if (failed(interleaveWithError( 522 llvm::zip(parentOp.getResults(), yieldOp.getOperands()), 523 [&](auto pair) -> LogicalResult { 524 auto result = std::get<0>(pair); 525 auto operand = std::get<1>(pair); 526 os << emitter.getOrCreateName(result) << " = "; 527 528 if (!emitter.hasValueInScope(operand)) 529 return yieldOp.emitError("operand value not in scope"); 530 os << emitter.getOrCreateName(operand); 531 return success(); 532 }, 533 [&]() { os << ";\n"; }))) 534 return failure(); 535 536 return success(); 537 } 538 539 static LogicalResult printOperation(CppEmitter &emitter, ReturnOp returnOp) { 540 raw_ostream &os = emitter.ostream(); 541 os << "return"; 542 switch (returnOp.getNumOperands()) { 543 case 0: 544 return success(); 545 case 1: 546 os << " " << emitter.getOrCreateName(returnOp.getOperand(0)); 547 return success(emitter.hasValueInScope(returnOp.getOperand(0))); 548 default: 549 os << " std::make_tuple("; 550 if (failed(emitter.emitOperandsAndAttributes(*returnOp.getOperation()))) 551 return failure(); 552 os << ")"; 553 return success(); 554 } 555 } 556 557 static LogicalResult printOperation(CppEmitter &emitter, ModuleOp moduleOp) { 558 CppEmitter::Scope scope(emitter); 559 560 for (Operation &op : moduleOp) { 561 if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/false))) 562 return failure(); 563 } 564 return success(); 565 } 566 567 static LogicalResult printOperation(CppEmitter &emitter, FuncOp functionOp) { 568 // We need to declare variables at top if the function has multiple blocks. 569 if (!emitter.shouldDeclareVariablesAtTop() && 570 functionOp.getBlocks().size() > 1) { 571 return functionOp.emitOpError( 572 "with multiple blocks needs variables declared at top"); 573 } 574 575 CppEmitter::Scope scope(emitter); 576 raw_indented_ostream &os = emitter.ostream(); 577 if (failed(emitter.emitTypes(functionOp.getLoc(), 578 functionOp.getType().getResults()))) 579 return failure(); 580 os << " " << functionOp.getName(); 581 582 os << "("; 583 if (failed(interleaveCommaWithError( 584 functionOp.getArguments(), os, 585 [&](BlockArgument arg) -> LogicalResult { 586 if (failed(emitter.emitType(functionOp.getLoc(), arg.getType()))) 587 return failure(); 588 os << " " << emitter.getOrCreateName(arg); 589 return success(); 590 }))) 591 return failure(); 592 os << ") {\n"; 593 os.indent(); 594 if (emitter.shouldDeclareVariablesAtTop()) { 595 // Declare all variables that hold op results including those from nested 596 // regions. 597 WalkResult result = 598 functionOp.walk<WalkOrder::PreOrder>([&](Operation *op) -> WalkResult { 599 for (OpResult result : op->getResults()) { 600 if (failed(emitter.emitVariableDeclaration( 601 result, /*trailingSemicolon=*/true))) { 602 return WalkResult( 603 op->emitError("unable to declare result variable for op")); 604 } 605 } 606 return WalkResult::advance(); 607 }); 608 if (result.wasInterrupted()) 609 return failure(); 610 } 611 612 Region::BlockListType &blocks = functionOp.getBlocks(); 613 // Create label names for basic blocks. 614 for (Block &block : blocks) { 615 emitter.getOrCreateName(block); 616 } 617 618 // Declare variables for basic block arguments. 619 for (auto it = std::next(blocks.begin()); it != blocks.end(); ++it) { 620 Block &block = *it; 621 for (BlockArgument &arg : block.getArguments()) { 622 if (emitter.hasValueInScope(arg)) 623 return functionOp.emitOpError(" block argument #") 624 << arg.getArgNumber() << " is out of scope"; 625 if (failed( 626 emitter.emitType(block.getParentOp()->getLoc(), arg.getType()))) { 627 return failure(); 628 } 629 os << " " << emitter.getOrCreateName(arg) << ";\n"; 630 } 631 } 632 633 for (Block &block : blocks) { 634 // Only print a label if there is more than one block. 635 if (blocks.size() > 1) { 636 if (failed(emitter.emitLabel(block))) 637 return failure(); 638 } 639 for (Operation &op : block.getOperations()) { 640 // When generating code for an scf.if or std.cond_br op no semicolon needs 641 // to be printed after the closing brace. 642 // When generating code for an scf.for op, printing a trailing semicolon 643 // is handled within the printOperation function. 644 bool trailingSemicolon = !isa<scf::IfOp, scf::ForOp, CondBranchOp>(op); 645 646 if (failed(emitter.emitOperation( 647 op, /*trailingSemicolon=*/trailingSemicolon))) 648 return failure(); 649 } 650 } 651 os.unindent() << "}\n"; 652 return success(); 653 } 654 655 CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop) 656 : os(os), declareVariablesAtTop(declareVariablesAtTop) { 657 valueInScopeCount.push(0); 658 labelInScopeCount.push(0); 659 } 660 661 /// Return the existing or a new name for a Value. 662 StringRef CppEmitter::getOrCreateName(Value val) { 663 if (!valueMapper.count(val)) 664 valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top())); 665 return *valueMapper.begin(val); 666 } 667 668 /// Return the existing or a new label for a Block. 669 StringRef CppEmitter::getOrCreateName(Block &block) { 670 if (!blockMapper.count(&block)) 671 blockMapper.insert(&block, formatv("label{0}", ++labelInScopeCount.top())); 672 return *blockMapper.begin(&block); 673 } 674 675 bool CppEmitter::shouldMapToUnsigned(IntegerType::SignednessSemantics val) { 676 switch (val) { 677 case IntegerType::Signless: 678 return false; 679 case IntegerType::Signed: 680 return false; 681 case IntegerType::Unsigned: 682 return true; 683 } 684 llvm_unreachable("Unexpected IntegerType::SignednessSemantics"); 685 } 686 687 bool CppEmitter::hasValueInScope(Value val) { return valueMapper.count(val); } 688 689 bool CppEmitter::hasBlockLabel(Block &block) { 690 return blockMapper.count(&block); 691 } 692 693 LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) { 694 auto printInt = [&](const APInt &val, bool isUnsigned) { 695 if (val.getBitWidth() == 1) { 696 if (val.getBoolValue()) 697 os << "true"; 698 else 699 os << "false"; 700 } else { 701 SmallString<128> strValue; 702 val.toString(strValue, 10, !isUnsigned, false); 703 os << strValue; 704 } 705 }; 706 707 auto printFloat = [&](const APFloat &val) { 708 if (val.isFinite()) { 709 SmallString<128> strValue; 710 // Use default values of toString except don't truncate zeros. 711 val.toString(strValue, 0, 0, false); 712 switch (llvm::APFloatBase::SemanticsToEnum(val.getSemantics())) { 713 case llvm::APFloatBase::S_IEEEsingle: 714 os << "(float)"; 715 break; 716 case llvm::APFloatBase::S_IEEEdouble: 717 os << "(double)"; 718 break; 719 default: 720 break; 721 }; 722 os << strValue; 723 } else if (val.isNaN()) { 724 os << "NAN"; 725 } else if (val.isInfinity()) { 726 if (val.isNegative()) 727 os << "-"; 728 os << "INFINITY"; 729 } 730 }; 731 732 // Print floating point attributes. 733 if (auto fAttr = attr.dyn_cast<FloatAttr>()) { 734 printFloat(fAttr.getValue()); 735 return success(); 736 } 737 if (auto dense = attr.dyn_cast<DenseFPElementsAttr>()) { 738 os << '{'; 739 interleaveComma(dense, os, [&](const APFloat &val) { printFloat(val); }); 740 os << '}'; 741 return success(); 742 } 743 744 // Print integer attributes. 745 if (auto iAttr = attr.dyn_cast<IntegerAttr>()) { 746 if (auto iType = iAttr.getType().dyn_cast<IntegerType>()) { 747 printInt(iAttr.getValue(), shouldMapToUnsigned(iType.getSignedness())); 748 return success(); 749 } 750 if (auto iType = iAttr.getType().dyn_cast<IndexType>()) { 751 printInt(iAttr.getValue(), false); 752 return success(); 753 } 754 } 755 if (auto dense = attr.dyn_cast<DenseIntElementsAttr>()) { 756 if (auto iType = dense.getType() 757 .cast<TensorType>() 758 .getElementType() 759 .dyn_cast<IntegerType>()) { 760 os << '{'; 761 interleaveComma(dense, os, [&](const APInt &val) { 762 printInt(val, shouldMapToUnsigned(iType.getSignedness())); 763 }); 764 os << '}'; 765 return success(); 766 } 767 if (auto iType = dense.getType() 768 .cast<TensorType>() 769 .getElementType() 770 .dyn_cast<IndexType>()) { 771 os << '{'; 772 interleaveComma(dense, os, 773 [&](const APInt &val) { printInt(val, false); }); 774 os << '}'; 775 return success(); 776 } 777 } 778 779 // Print opaque attributes. 780 if (auto oAttr = attr.dyn_cast<emitc::OpaqueAttr>()) { 781 os << oAttr.getValue(); 782 return success(); 783 } 784 785 // Print symbolic reference attributes. 786 if (auto sAttr = attr.dyn_cast<SymbolRefAttr>()) { 787 if (sAttr.getNestedReferences().size() > 1) 788 return emitError(loc, "attribute has more than 1 nested reference"); 789 os << sAttr.getRootReference().getValue(); 790 return success(); 791 } 792 793 // Print type attributes. 794 if (auto type = attr.dyn_cast<TypeAttr>()) 795 return emitType(loc, type.getValue()); 796 797 return emitError(loc, "cannot emit attribute of type ") << attr.getType(); 798 } 799 800 LogicalResult CppEmitter::emitOperands(Operation &op) { 801 auto emitOperandName = [&](Value result) -> LogicalResult { 802 if (!hasValueInScope(result)) 803 return op.emitOpError() << "operand value not in scope"; 804 os << getOrCreateName(result); 805 return success(); 806 }; 807 return interleaveCommaWithError(op.getOperands(), os, emitOperandName); 808 } 809 810 LogicalResult 811 CppEmitter::emitOperandsAndAttributes(Operation &op, 812 ArrayRef<StringRef> exclude) { 813 if (failed(emitOperands(op))) 814 return failure(); 815 // Insert comma in between operands and non-filtered attributes if needed. 816 if (op.getNumOperands() > 0) { 817 for (NamedAttribute attr : op.getAttrs()) { 818 if (!llvm::is_contained(exclude, attr.getName().strref())) { 819 os << ", "; 820 break; 821 } 822 } 823 } 824 // Emit attributes. 825 auto emitNamedAttribute = [&](NamedAttribute attr) -> LogicalResult { 826 if (llvm::is_contained(exclude, attr.getName().strref())) 827 return success(); 828 os << "/* " << attr.getName().getValue() << " */"; 829 if (failed(emitAttribute(op.getLoc(), attr.getValue()))) 830 return failure(); 831 return success(); 832 }; 833 return interleaveCommaWithError(op.getAttrs(), os, emitNamedAttribute); 834 } 835 836 LogicalResult CppEmitter::emitVariableAssignment(OpResult result) { 837 if (!hasValueInScope(result)) { 838 return result.getDefiningOp()->emitOpError( 839 "result variable for the operation has not been declared"); 840 } 841 os << getOrCreateName(result) << " = "; 842 return success(); 843 } 844 845 LogicalResult CppEmitter::emitVariableDeclaration(OpResult result, 846 bool trailingSemicolon) { 847 if (hasValueInScope(result)) { 848 return result.getDefiningOp()->emitError( 849 "result variable for the operation already declared"); 850 } 851 if (failed(emitType(result.getOwner()->getLoc(), result.getType()))) 852 return failure(); 853 os << " " << getOrCreateName(result); 854 if (trailingSemicolon) 855 os << ";\n"; 856 return success(); 857 } 858 859 LogicalResult CppEmitter::emitAssignPrefix(Operation &op) { 860 switch (op.getNumResults()) { 861 case 0: 862 break; 863 case 1: { 864 OpResult result = op.getResult(0); 865 if (shouldDeclareVariablesAtTop()) { 866 if (failed(emitVariableAssignment(result))) 867 return failure(); 868 } else { 869 if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/false))) 870 return failure(); 871 os << " = "; 872 } 873 break; 874 } 875 default: 876 if (!shouldDeclareVariablesAtTop()) { 877 for (OpResult result : op.getResults()) { 878 if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/true))) 879 return failure(); 880 } 881 } 882 os << "std::tie("; 883 interleaveComma(op.getResults(), os, 884 [&](Value result) { os << getOrCreateName(result); }); 885 os << ") = "; 886 } 887 return success(); 888 } 889 890 LogicalResult CppEmitter::emitLabel(Block &block) { 891 if (!hasBlockLabel(block)) 892 return block.getParentOp()->emitError("label for block not found"); 893 // FIXME: Add feature in `raw_indented_ostream` to ignore indent for block 894 // label instead of using `getOStream`. 895 os.getOStream() << getOrCreateName(block) << ":\n"; 896 return success(); 897 } 898 899 LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { 900 LogicalResult status = 901 llvm::TypeSwitch<Operation *, LogicalResult>(&op) 902 // EmitC ops. 903 .Case<emitc::ApplyOp, emitc::CallOp, emitc::ConstantOp, 904 emitc::IncludeOp>( 905 [&](auto op) { return printOperation(*this, op); }) 906 // SCF ops. 907 .Case<scf::ForOp, scf::IfOp, scf::YieldOp>( 908 [&](auto op) { return printOperation(*this, op); }) 909 // Standard ops. 910 .Case<BranchOp, mlir::CallOp, CondBranchOp, mlir::ConstantOp, FuncOp, 911 ModuleOp, ReturnOp>( 912 [&](auto op) { return printOperation(*this, op); }) 913 // Arithmetic ops. 914 .Case<arith::ConstantOp>( 915 [&](auto op) { return printOperation(*this, op); }) 916 .Default([&](Operation *) { 917 return op.emitOpError("unable to find printer for op"); 918 }); 919 920 if (failed(status)) 921 return failure(); 922 os << (trailingSemicolon ? ";\n" : "\n"); 923 return success(); 924 } 925 926 LogicalResult CppEmitter::emitType(Location loc, Type type) { 927 if (auto iType = type.dyn_cast<IntegerType>()) { 928 switch (iType.getWidth()) { 929 case 1: 930 return (os << "bool"), success(); 931 case 8: 932 case 16: 933 case 32: 934 case 64: 935 if (shouldMapToUnsigned(iType.getSignedness())) 936 return (os << "uint" << iType.getWidth() << "_t"), success(); 937 else 938 return (os << "int" << iType.getWidth() << "_t"), success(); 939 default: 940 return emitError(loc, "cannot emit integer type ") << type; 941 } 942 } 943 if (auto fType = type.dyn_cast<FloatType>()) { 944 switch (fType.getWidth()) { 945 case 32: 946 return (os << "float"), success(); 947 case 64: 948 return (os << "double"), success(); 949 default: 950 return emitError(loc, "cannot emit float type ") << type; 951 } 952 } 953 if (auto iType = type.dyn_cast<IndexType>()) 954 return (os << "size_t"), success(); 955 if (auto tType = type.dyn_cast<TensorType>()) { 956 if (!tType.hasRank()) 957 return emitError(loc, "cannot emit unranked tensor type"); 958 if (!tType.hasStaticShape()) 959 return emitError(loc, "cannot emit tensor type with non static shape"); 960 os << "Tensor<"; 961 if (failed(emitType(loc, tType.getElementType()))) 962 return failure(); 963 auto shape = tType.getShape(); 964 for (auto dimSize : shape) { 965 os << ", "; 966 os << dimSize; 967 } 968 os << ">"; 969 return success(); 970 } 971 if (auto tType = type.dyn_cast<TupleType>()) 972 return emitTupleType(loc, tType.getTypes()); 973 if (auto oType = type.dyn_cast<emitc::OpaqueType>()) { 974 os << oType.getValue(); 975 return success(); 976 } 977 return emitError(loc, "cannot emit type ") << type; 978 } 979 980 LogicalResult CppEmitter::emitTypes(Location loc, ArrayRef<Type> types) { 981 switch (types.size()) { 982 case 0: 983 os << "void"; 984 return success(); 985 case 1: 986 return emitType(loc, types.front()); 987 default: 988 return emitTupleType(loc, types); 989 } 990 } 991 992 LogicalResult CppEmitter::emitTupleType(Location loc, ArrayRef<Type> types) { 993 os << "std::tuple<"; 994 if (failed(interleaveCommaWithError( 995 types, os, [&](Type type) { return emitType(loc, type); }))) 996 return failure(); 997 os << ">"; 998 return success(); 999 } 1000 1001 LogicalResult emitc::translateToCpp(Operation *op, raw_ostream &os, 1002 bool declareVariablesAtTop) { 1003 CppEmitter emitter(os, declareVariablesAtTop); 1004 return emitter.emitOperation(*op, /*trailingSemicolon=*/false); 1005 } 1006