//===- TranslateToCpp.cpp - Translating to C++ calls ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/IndentedOstream.h"
#include "mlir/Target/Cpp/CppEmitter.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/ScopedHashTable.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#include <stack>

#define DEBUG_TYPE "translate-to-cpp"

using namespace mlir;
using namespace mlir::emitc;
using llvm::formatv;

/// Convenience functions to produce interleaved output with functions returning
/// a LogicalResult. This is different than those in STLExtras as functions used
/// on each element doesn't return a string.
34 template <typename ForwardIterator, typename UnaryFunctor, 35 typename NullaryFunctor> 36 inline LogicalResult 37 interleaveWithError(ForwardIterator begin, ForwardIterator end, 38 UnaryFunctor eachFn, NullaryFunctor betweenFn) { 39 if (begin == end) 40 return success(); 41 if (failed(eachFn(*begin))) 42 return failure(); 43 ++begin; 44 for (; begin != end; ++begin) { 45 betweenFn(); 46 if (failed(eachFn(*begin))) 47 return failure(); 48 } 49 return success(); 50 } 51 52 template <typename Container, typename UnaryFunctor, typename NullaryFunctor> 53 inline LogicalResult interleaveWithError(const Container &c, 54 UnaryFunctor eachFn, 55 NullaryFunctor betweenFn) { 56 return interleaveWithError(c.begin(), c.end(), eachFn, betweenFn); 57 } 58 59 template <typename Container, typename UnaryFunctor> 60 inline LogicalResult interleaveCommaWithError(const Container &c, 61 raw_ostream &os, 62 UnaryFunctor eachFn) { 63 return interleaveWithError(c.begin(), c.end(), eachFn, [&]() { os << ", "; }); 64 } 65 66 namespace { 67 /// Emitter that uses dialect specific emitters to emit C++ code. 68 struct CppEmitter { 69 explicit CppEmitter(raw_ostream &os, bool declareVariablesAtTop); 70 71 /// Emits attribute or returns failure. 72 LogicalResult emitAttribute(Location loc, Attribute attr); 73 74 /// Emits operation 'op' with/without training semicolon or returns failure. 75 LogicalResult emitOperation(Operation &op, bool trailingSemicolon); 76 77 /// Emits type 'type' or returns failure. 78 LogicalResult emitType(Location loc, Type type); 79 80 /// Emits array of types as a std::tuple of the emitted types. 81 /// - emits void for an empty array; 82 /// - emits the type of the only element for arrays of size one; 83 /// - emits a std::tuple otherwise; 84 LogicalResult emitTypes(Location loc, ArrayRef<Type> types); 85 86 /// Emits array of types as a std::tuple of the emitted types independently of 87 /// the array size. 
88 LogicalResult emitTupleType(Location loc, ArrayRef<Type> types); 89 90 /// Emits an assignment for a variable which has been declared previously. 91 LogicalResult emitVariableAssignment(OpResult result); 92 93 /// Emits a variable declaration for a result of an operation. 94 LogicalResult emitVariableDeclaration(OpResult result, 95 bool trailingSemicolon); 96 97 /// Emits the variable declaration and assignment prefix for 'op'. 98 /// - emits separate variable followed by std::tie for multi-valued operation; 99 /// - emits single type followed by variable for single result; 100 /// - emits nothing if no value produced by op; 101 /// Emits final '=' operator where a type is produced. Returns failure if 102 /// any result type could not be converted. 103 LogicalResult emitAssignPrefix(Operation &op); 104 105 /// Emits a label for the block. 106 LogicalResult emitLabel(Block &block); 107 108 /// Emits the operands and atttributes of the operation. All operands are 109 /// emitted first and then all attributes in alphabetical order. 110 LogicalResult emitOperandsAndAttributes(Operation &op, 111 ArrayRef<StringRef> exclude = {}); 112 113 /// Emits the operands of the operation. All operands are emitted in order. 114 LogicalResult emitOperands(Operation &op); 115 116 /// Return the existing or a new name for a Value. 117 StringRef getOrCreateName(Value val); 118 119 /// Return the existing or a new label of a Block. 120 StringRef getOrCreateName(Block &block); 121 122 /// Whether to map an mlir integer to a unsigned integer in C++. 123 bool shouldMapToUnsigned(IntegerType::SignednessSemantics val); 124 125 /// RAII helper function to manage entering/exiting C++ scopes. 
126 struct Scope { 127 Scope(CppEmitter &emitter) 128 : valueMapperScope(emitter.valueMapper), 129 blockMapperScope(emitter.blockMapper), emitter(emitter) { 130 emitter.valueInScopeCount.push(emitter.valueInScopeCount.top()); 131 emitter.labelInScopeCount.push(emitter.labelInScopeCount.top()); 132 } 133 ~Scope() { 134 emitter.valueInScopeCount.pop(); 135 emitter.labelInScopeCount.pop(); 136 } 137 138 private: 139 llvm::ScopedHashTableScope<Value, std::string> valueMapperScope; 140 llvm::ScopedHashTableScope<Block *, std::string> blockMapperScope; 141 CppEmitter &emitter; 142 }; 143 144 /// Returns wether the Value is assigned to a C++ variable in the scope. 145 bool hasValueInScope(Value val); 146 147 // Returns whether a label is assigned to the block. 148 bool hasBlockLabel(Block &block); 149 150 /// Returns the output stream. 151 raw_indented_ostream &ostream() { return os; }; 152 153 /// Returns if all variables for op results and basic block arguments need to 154 /// be declared at the beginning of a function. 155 bool shouldDeclareVariablesAtTop() { return declareVariablesAtTop; }; 156 157 private: 158 using ValueMapper = llvm::ScopedHashTable<Value, std::string>; 159 using BlockMapper = llvm::ScopedHashTable<Block *, std::string>; 160 161 /// Output stream to emit to. 162 raw_indented_ostream os; 163 164 /// Boolean to enforce that all variables for op results and block 165 /// arguments are declared at the beginning of the function. This also 166 /// includes results from ops located in nested regions. 167 bool declareVariablesAtTop; 168 169 /// Map from value to name of C++ variable that contain the name. 170 ValueMapper valueMapper; 171 172 /// Map from block to name of C++ label. 173 BlockMapper blockMapper; 174 175 /// The number of values in the current scope. This is used to declare the 176 /// names of values in a scope. 
177 std::stack<int64_t> valueInScopeCount; 178 std::stack<int64_t> labelInScopeCount; 179 }; 180 } // namespace 181 182 static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation, 183 Attribute value) { 184 OpResult result = operation->getResult(0); 185 186 // Only emit an assignment as the variable was already declared when printing 187 // the FuncOp. 188 if (emitter.shouldDeclareVariablesAtTop()) { 189 // Skip the assignment if the emitc.constant has no value. 190 if (auto oAttr = value.dyn_cast<emitc::OpaqueAttr>()) { 191 if (oAttr.getValue().empty()) 192 return success(); 193 } 194 195 if (failed(emitter.emitVariableAssignment(result))) 196 return failure(); 197 return emitter.emitAttribute(operation->getLoc(), value); 198 } 199 200 // Emit a variable declaration for an emitc.constant op without value. 201 if (auto oAttr = value.dyn_cast<emitc::OpaqueAttr>()) { 202 if (oAttr.getValue().empty()) 203 // The semicolon gets printed by the emitOperation function. 204 return emitter.emitVariableDeclaration(result, 205 /*trailingSemicolon=*/false); 206 } 207 208 // Emit a variable declaration. 
209 if (failed(emitter.emitAssignPrefix(*operation))) 210 return failure(); 211 return emitter.emitAttribute(operation->getLoc(), value); 212 } 213 214 static LogicalResult printOperation(CppEmitter &emitter, 215 emitc::ConstantOp constantOp) { 216 Operation *operation = constantOp.getOperation(); 217 Attribute value = constantOp.value(); 218 219 return printConstantOp(emitter, operation, value); 220 } 221 222 static LogicalResult printOperation(CppEmitter &emitter, 223 mlir::ConstantOp constantOp) { 224 Operation *operation = constantOp.getOperation(); 225 Attribute value = constantOp.value(); 226 227 return printConstantOp(emitter, operation, value); 228 } 229 230 static LogicalResult printOperation(CppEmitter &emitter, BranchOp branchOp) { 231 raw_ostream &os = emitter.ostream(); 232 Block &successor = *branchOp.getSuccessor(); 233 234 for (auto pair : 235 llvm::zip(branchOp.getOperands(), successor.getArguments())) { 236 Value &operand = std::get<0>(pair); 237 BlockArgument &argument = std::get<1>(pair); 238 os << emitter.getOrCreateName(argument) << " = " 239 << emitter.getOrCreateName(operand) << ";\n"; 240 } 241 242 os << "goto "; 243 if (!(emitter.hasBlockLabel(successor))) 244 return branchOp.emitOpError("unable to find label for successor block"); 245 os << emitter.getOrCreateName(successor); 246 return success(); 247 } 248 249 static LogicalResult printOperation(CppEmitter &emitter, 250 CondBranchOp condBranchOp) { 251 raw_indented_ostream &os = emitter.ostream(); 252 Block &trueSuccessor = *condBranchOp.getTrueDest(); 253 Block &falseSuccessor = *condBranchOp.getFalseDest(); 254 255 os << "if (" << emitter.getOrCreateName(condBranchOp.getCondition()) 256 << ") {\n"; 257 258 os.indent(); 259 260 // If condition is true. 
261 for (auto pair : llvm::zip(condBranchOp.getTrueOperands(), 262 trueSuccessor.getArguments())) { 263 Value &operand = std::get<0>(pair); 264 BlockArgument &argument = std::get<1>(pair); 265 os << emitter.getOrCreateName(argument) << " = " 266 << emitter.getOrCreateName(operand) << ";\n"; 267 } 268 269 os << "goto "; 270 if (!(emitter.hasBlockLabel(trueSuccessor))) { 271 return condBranchOp.emitOpError("unable to find label for successor block"); 272 } 273 os << emitter.getOrCreateName(trueSuccessor) << ";\n"; 274 os.unindent() << "} else {\n"; 275 os.indent(); 276 // If condition is false. 277 for (auto pair : llvm::zip(condBranchOp.getFalseOperands(), 278 falseSuccessor.getArguments())) { 279 Value &operand = std::get<0>(pair); 280 BlockArgument &argument = std::get<1>(pair); 281 os << emitter.getOrCreateName(argument) << " = " 282 << emitter.getOrCreateName(operand) << ";\n"; 283 } 284 285 os << "goto "; 286 if (!(emitter.hasBlockLabel(falseSuccessor))) { 287 return condBranchOp.emitOpError() 288 << "unable to find label for successor block"; 289 } 290 os << emitter.getOrCreateName(falseSuccessor) << ";\n"; 291 os.unindent() << "}"; 292 return success(); 293 } 294 295 static LogicalResult printOperation(CppEmitter &emitter, mlir::CallOp callOp) { 296 if (failed(emitter.emitAssignPrefix(*callOp.getOperation()))) 297 return failure(); 298 299 raw_ostream &os = emitter.ostream(); 300 os << callOp.getCallee() << "("; 301 if (failed(emitter.emitOperands(*callOp.getOperation()))) 302 return failure(); 303 os << ")"; 304 return success(); 305 } 306 307 static LogicalResult printOperation(CppEmitter &emitter, emitc::CallOp callOp) { 308 raw_ostream &os = emitter.ostream(); 309 Operation &op = *callOp.getOperation(); 310 311 if (failed(emitter.emitAssignPrefix(op))) 312 return failure(); 313 os << callOp.callee(); 314 315 auto emitArgs = [&](Attribute attr) -> LogicalResult { 316 if (auto t = attr.dyn_cast<IntegerAttr>()) { 317 // Index attributes are treated specially 
as operand index. 318 if (t.getType().isIndex()) { 319 int64_t idx = t.getInt(); 320 if ((idx < 0) || (idx >= op.getNumOperands())) 321 return op.emitOpError("invalid operand index"); 322 if (!emitter.hasValueInScope(op.getOperand(idx))) 323 return op.emitOpError("operand ") 324 << idx << "'s value not defined in scope"; 325 os << emitter.getOrCreateName(op.getOperand(idx)); 326 return success(); 327 } 328 } 329 if (failed(emitter.emitAttribute(op.getLoc(), attr))) 330 return failure(); 331 332 return success(); 333 }; 334 335 if (callOp.template_args()) { 336 os << "<"; 337 if (failed(interleaveCommaWithError(*callOp.template_args(), os, emitArgs))) 338 return failure(); 339 os << ">"; 340 } 341 342 os << "("; 343 344 LogicalResult emittedArgs = 345 callOp.args() ? interleaveCommaWithError(*callOp.args(), os, emitArgs) 346 : emitter.emitOperands(op); 347 if (failed(emittedArgs)) 348 return failure(); 349 os << ")"; 350 return success(); 351 } 352 353 static LogicalResult printOperation(CppEmitter &emitter, 354 emitc::ApplyOp applyOp) { 355 raw_ostream &os = emitter.ostream(); 356 Operation &op = *applyOp.getOperation(); 357 358 if (failed(emitter.emitAssignPrefix(op))) 359 return failure(); 360 os << applyOp.applicableOperator(); 361 os << emitter.getOrCreateName(applyOp.getOperand()); 362 363 return success(); 364 } 365 366 static LogicalResult printOperation(CppEmitter &emitter, 367 emitc::IncludeOp includeOp) { 368 raw_ostream &os = emitter.ostream(); 369 370 os << "#include "; 371 if (includeOp.is_standard_include()) 372 os << "<" << includeOp.include() << ">"; 373 else 374 os << "\"" << includeOp.include() << "\""; 375 376 return success(); 377 } 378 379 static LogicalResult printOperation(CppEmitter &emitter, scf::ForOp forOp) { 380 381 raw_indented_ostream &os = emitter.ostream(); 382 383 OperandRange operands = forOp.getIterOperands(); 384 Block::BlockArgListType iterArgs = forOp.getRegionIterArgs(); 385 Operation::result_range results = 
forOp.getResults(); 386 387 if (!emitter.shouldDeclareVariablesAtTop()) { 388 for (OpResult result : results) { 389 if (failed(emitter.emitVariableDeclaration(result, 390 /*trailingSemicolon=*/true))) 391 return failure(); 392 } 393 } 394 395 for (auto pair : llvm::zip(iterArgs, operands)) { 396 if (failed(emitter.emitType(forOp.getLoc(), std::get<0>(pair).getType()))) 397 return failure(); 398 os << " " << emitter.getOrCreateName(std::get<0>(pair)) << " = "; 399 os << emitter.getOrCreateName(std::get<1>(pair)) << ";"; 400 os << "\n"; 401 } 402 403 os << "for ("; 404 if (failed( 405 emitter.emitType(forOp.getLoc(), forOp.getInductionVar().getType()))) 406 return failure(); 407 os << " "; 408 os << emitter.getOrCreateName(forOp.getInductionVar()); 409 os << " = "; 410 os << emitter.getOrCreateName(forOp.lowerBound()); 411 os << "; "; 412 os << emitter.getOrCreateName(forOp.getInductionVar()); 413 os << " < "; 414 os << emitter.getOrCreateName(forOp.upperBound()); 415 os << "; "; 416 os << emitter.getOrCreateName(forOp.getInductionVar()); 417 os << " += "; 418 os << emitter.getOrCreateName(forOp.step()); 419 os << ") {\n"; 420 os.indent(); 421 422 Region &forRegion = forOp.region(); 423 auto regionOps = forRegion.getOps(); 424 425 // We skip the trailing yield op because this updates the result variables 426 // of the for op in the generated code. Instead we update the iterArgs at 427 // the end of a loop iteration and set the result variables after the for 428 // loop. 429 for (auto it = regionOps.begin(); std::next(it) != regionOps.end(); ++it) { 430 if (failed(emitter.emitOperation(*it, /*trailingSemicolon=*/true))) 431 return failure(); 432 } 433 434 Operation *yieldOp = forRegion.getBlocks().front().getTerminator(); 435 // Copy yield operands into iterArgs at the end of a loop iteration. 
436 for (auto pair : llvm::zip(iterArgs, yieldOp->getOperands())) { 437 BlockArgument iterArg = std::get<0>(pair); 438 Value operand = std::get<1>(pair); 439 os << emitter.getOrCreateName(iterArg) << " = " 440 << emitter.getOrCreateName(operand) << ";\n"; 441 } 442 443 os.unindent() << "}"; 444 445 // Copy iterArgs into results after the for loop. 446 for (auto pair : llvm::zip(results, iterArgs)) { 447 OpResult result = std::get<0>(pair); 448 BlockArgument iterArg = std::get<1>(pair); 449 os << "\n" 450 << emitter.getOrCreateName(result) << " = " 451 << emitter.getOrCreateName(iterArg) << ";"; 452 } 453 454 return success(); 455 } 456 457 static LogicalResult printOperation(CppEmitter &emitter, scf::IfOp ifOp) { 458 raw_indented_ostream &os = emitter.ostream(); 459 460 if (!emitter.shouldDeclareVariablesAtTop()) { 461 for (OpResult result : ifOp.getResults()) { 462 if (failed(emitter.emitVariableDeclaration(result, 463 /*trailingSemicolon=*/true))) 464 return failure(); 465 } 466 } 467 468 os << "if ("; 469 if (failed(emitter.emitOperands(*ifOp.getOperation()))) 470 return failure(); 471 os << ") {\n"; 472 os.indent(); 473 474 Region &thenRegion = ifOp.thenRegion(); 475 for (Operation &op : thenRegion.getOps()) { 476 // Note: This prints a superfluous semicolon if the terminating yield op has 477 // zero results. 478 if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/true))) 479 return failure(); 480 } 481 482 os.unindent() << "}"; 483 484 Region &elseRegion = ifOp.elseRegion(); 485 if (!elseRegion.empty()) { 486 os << " else {\n"; 487 os.indent(); 488 489 for (Operation &op : elseRegion.getOps()) { 490 // Note: This prints a superfluous semicolon if the terminating yield op 491 // has zero results. 
492 if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/true))) 493 return failure(); 494 } 495 496 os.unindent() << "}"; 497 } 498 499 return success(); 500 } 501 502 static LogicalResult printOperation(CppEmitter &emitter, scf::YieldOp yieldOp) { 503 raw_ostream &os = emitter.ostream(); 504 Operation &parentOp = *yieldOp.getOperation()->getParentOp(); 505 506 if (yieldOp.getNumOperands() != parentOp.getNumResults()) { 507 return yieldOp.emitError("number of operands does not to match the number " 508 "of the parent op's results"); 509 } 510 511 if (failed(interleaveWithError( 512 llvm::zip(parentOp.getResults(), yieldOp.getOperands()), 513 [&](auto pair) -> LogicalResult { 514 auto result = std::get<0>(pair); 515 auto operand = std::get<1>(pair); 516 os << emitter.getOrCreateName(result) << " = "; 517 518 if (!emitter.hasValueInScope(operand)) 519 return yieldOp.emitError("operand value not in scope"); 520 os << emitter.getOrCreateName(operand); 521 return success(); 522 }, 523 [&]() { os << ";\n"; }))) 524 return failure(); 525 526 return success(); 527 } 528 529 static LogicalResult printOperation(CppEmitter &emitter, ReturnOp returnOp) { 530 raw_ostream &os = emitter.ostream(); 531 os << "return"; 532 switch (returnOp.getNumOperands()) { 533 case 0: 534 return success(); 535 case 1: 536 os << " " << emitter.getOrCreateName(returnOp.getOperand(0)); 537 return success(emitter.hasValueInScope(returnOp.getOperand(0))); 538 default: 539 os << " std::make_tuple("; 540 if (failed(emitter.emitOperandsAndAttributes(*returnOp.getOperation()))) 541 return failure(); 542 os << ")"; 543 return success(); 544 } 545 } 546 547 static LogicalResult printOperation(CppEmitter &emitter, ModuleOp moduleOp) { 548 CppEmitter::Scope scope(emitter); 549 550 for (Operation &op : moduleOp) { 551 if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/false))) 552 return failure(); 553 } 554 return success(); 555 } 556 557 static LogicalResult printOperation(CppEmitter 
&emitter, FuncOp functionOp) { 558 // We need to declare variables at top if the function has multiple blocks. 559 if (!emitter.shouldDeclareVariablesAtTop() && 560 functionOp.getBlocks().size() > 1) { 561 return functionOp.emitOpError( 562 "with multiple blocks needs variables declared at top"); 563 } 564 565 CppEmitter::Scope scope(emitter); 566 raw_indented_ostream &os = emitter.ostream(); 567 if (failed(emitter.emitTypes(functionOp.getLoc(), 568 functionOp.getType().getResults()))) 569 return failure(); 570 os << " " << functionOp.getName(); 571 572 os << "("; 573 if (failed(interleaveCommaWithError( 574 functionOp.getArguments(), os, 575 [&](BlockArgument arg) -> LogicalResult { 576 if (failed(emitter.emitType(functionOp.getLoc(), arg.getType()))) 577 return failure(); 578 os << " " << emitter.getOrCreateName(arg); 579 return success(); 580 }))) 581 return failure(); 582 os << ") {\n"; 583 os.indent(); 584 if (emitter.shouldDeclareVariablesAtTop()) { 585 // Declare all variables that hold op results including those from nested 586 // regions. 587 WalkResult result = 588 functionOp.walk<WalkOrder::PreOrder>([&](Operation *op) -> WalkResult { 589 for (OpResult result : op->getResults()) { 590 if (failed(emitter.emitVariableDeclaration( 591 result, /*trailingSemicolon=*/true))) { 592 return WalkResult( 593 op->emitError("unable to declare result variable for op")); 594 } 595 } 596 return WalkResult::advance(); 597 }); 598 if (result.wasInterrupted()) 599 return failure(); 600 } 601 602 Region::BlockListType &blocks = functionOp.getBlocks(); 603 // Create label names for basic blocks. 604 for (Block &block : blocks) { 605 emitter.getOrCreateName(block); 606 } 607 608 // Declare variables for basic block arguments. 
609 for (auto it = std::next(blocks.begin()); it != blocks.end(); ++it) { 610 Block &block = *it; 611 for (BlockArgument &arg : block.getArguments()) { 612 if (emitter.hasValueInScope(arg)) 613 return functionOp.emitOpError(" block argument #") 614 << arg.getArgNumber() << " is out of scope"; 615 if (failed( 616 emitter.emitType(block.getParentOp()->getLoc(), arg.getType()))) { 617 return failure(); 618 } 619 os << " " << emitter.getOrCreateName(arg) << ";\n"; 620 } 621 } 622 623 for (Block &block : blocks) { 624 // Only print a label if there is more than one block. 625 if (blocks.size() > 1) { 626 if (failed(emitter.emitLabel(block))) 627 return failure(); 628 } 629 for (Operation &op : block.getOperations()) { 630 // When generating code for an scf.if or std.cond_br op no semicolon needs 631 // to be printed after the closing brace. 632 // When generating code for an scf.for op, printing a trailing semicolon 633 // is handled within the printOperation function. 634 bool trailingSemicolon = !isa<scf::IfOp, scf::ForOp, CondBranchOp>(op); 635 636 if (failed(emitter.emitOperation( 637 op, /*trailingSemicolon=*/trailingSemicolon))) 638 return failure(); 639 } 640 } 641 os.unindent() << "}\n"; 642 return success(); 643 } 644 645 CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop) 646 : os(os), declareVariablesAtTop(declareVariablesAtTop) { 647 valueInScopeCount.push(0); 648 labelInScopeCount.push(0); 649 } 650 651 /// Return the existing or a new name for a Value. 652 StringRef CppEmitter::getOrCreateName(Value val) { 653 if (!valueMapper.count(val)) 654 valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top())); 655 return *valueMapper.begin(val); 656 } 657 658 /// Return the existing or a new label for a Block. 
659 StringRef CppEmitter::getOrCreateName(Block &block) { 660 if (!blockMapper.count(&block)) 661 blockMapper.insert(&block, formatv("label{0}", ++labelInScopeCount.top())); 662 return *blockMapper.begin(&block); 663 } 664 665 bool CppEmitter::shouldMapToUnsigned(IntegerType::SignednessSemantics val) { 666 switch (val) { 667 case IntegerType::Signless: 668 return false; 669 case IntegerType::Signed: 670 return false; 671 case IntegerType::Unsigned: 672 return true; 673 } 674 llvm_unreachable("Unexpected IntegerType::SignednessSemantics"); 675 } 676 677 bool CppEmitter::hasValueInScope(Value val) { return valueMapper.count(val); } 678 679 bool CppEmitter::hasBlockLabel(Block &block) { 680 return blockMapper.count(&block); 681 } 682 683 LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) { 684 auto printInt = [&](APInt val, bool isUnsigned) { 685 if (val.getBitWidth() == 1) { 686 if (val.getBoolValue()) 687 os << "true"; 688 else 689 os << "false"; 690 } else { 691 SmallString<128> strValue; 692 val.toString(strValue, 10, !isUnsigned, false); 693 os << strValue; 694 } 695 }; 696 697 auto printFloat = [&](APFloat val) { 698 if (val.isFinite()) { 699 SmallString<128> strValue; 700 // Use default values of toString except don't truncate zeros. 701 val.toString(strValue, 0, 0, false); 702 switch (llvm::APFloatBase::SemanticsToEnum(val.getSemantics())) { 703 case llvm::APFloatBase::S_IEEEsingle: 704 os << "(float)"; 705 break; 706 case llvm::APFloatBase::S_IEEEdouble: 707 os << "(double)"; 708 break; 709 default: 710 break; 711 }; 712 os << strValue; 713 } else if (val.isNaN()) { 714 os << "NAN"; 715 } else if (val.isInfinity()) { 716 if (val.isNegative()) 717 os << "-"; 718 os << "INFINITY"; 719 } 720 }; 721 722 // Print floating point attributes. 
723 if (auto fAttr = attr.dyn_cast<FloatAttr>()) { 724 printFloat(fAttr.getValue()); 725 return success(); 726 } 727 if (auto dense = attr.dyn_cast<DenseFPElementsAttr>()) { 728 os << '{'; 729 interleaveComma(dense, os, [&](APFloat val) { printFloat(val); }); 730 os << '}'; 731 return success(); 732 } 733 734 // Print integer attributes. 735 if (auto iAttr = attr.dyn_cast<IntegerAttr>()) { 736 if (auto iType = iAttr.getType().dyn_cast<IntegerType>()) { 737 printInt(iAttr.getValue(), shouldMapToUnsigned(iType.getSignedness())); 738 return success(); 739 } 740 if (auto iType = iAttr.getType().dyn_cast<IndexType>()) { 741 printInt(iAttr.getValue(), false); 742 return success(); 743 } 744 } 745 if (auto dense = attr.dyn_cast<DenseIntElementsAttr>()) { 746 if (auto iType = dense.getType() 747 .cast<TensorType>() 748 .getElementType() 749 .dyn_cast<IntegerType>()) { 750 os << '{'; 751 interleaveComma(dense, os, [&](APInt val) { 752 printInt(val, shouldMapToUnsigned(iType.getSignedness())); 753 }); 754 os << '}'; 755 return success(); 756 } 757 if (auto iType = dense.getType() 758 .cast<TensorType>() 759 .getElementType() 760 .dyn_cast<IndexType>()) { 761 os << '{'; 762 interleaveComma(dense, os, [&](APInt val) { printInt(val, false); }); 763 os << '}'; 764 return success(); 765 } 766 } 767 768 // Print opaque attributes. 769 if (auto oAttr = attr.dyn_cast<emitc::OpaqueAttr>()) { 770 os << oAttr.getValue(); 771 return success(); 772 } 773 774 // Print symbolic reference attributes. 775 if (auto sAttr = attr.dyn_cast<SymbolRefAttr>()) { 776 if (sAttr.getNestedReferences().size() > 1) 777 return emitError(loc, "attribute has more than 1 nested reference"); 778 os << sAttr.getRootReference().getValue(); 779 return success(); 780 } 781 782 // Print type attributes. 
783 if (auto type = attr.dyn_cast<TypeAttr>()) 784 return emitType(loc, type.getValue()); 785 786 return emitError(loc, "cannot emit attribute of type ") << attr.getType(); 787 } 788 789 LogicalResult CppEmitter::emitOperands(Operation &op) { 790 auto emitOperandName = [&](Value result) -> LogicalResult { 791 if (!hasValueInScope(result)) 792 return op.emitOpError() << "operand value not in scope"; 793 os << getOrCreateName(result); 794 return success(); 795 }; 796 return interleaveCommaWithError(op.getOperands(), os, emitOperandName); 797 } 798 799 LogicalResult 800 CppEmitter::emitOperandsAndAttributes(Operation &op, 801 ArrayRef<StringRef> exclude) { 802 if (failed(emitOperands(op))) 803 return failure(); 804 // Insert comma in between operands and non-filtered attributes if needed. 805 if (op.getNumOperands() > 0) { 806 for (NamedAttribute attr : op.getAttrs()) { 807 if (!llvm::is_contained(exclude, attr.first.strref())) { 808 os << ", "; 809 break; 810 } 811 } 812 } 813 // Emit attributes. 
814 auto emitNamedAttribute = [&](NamedAttribute attr) -> LogicalResult { 815 if (llvm::is_contained(exclude, attr.first.strref())) 816 return success(); 817 os << "/* " << attr.first << " */"; 818 if (failed(emitAttribute(op.getLoc(), attr.second))) 819 return failure(); 820 return success(); 821 }; 822 return interleaveCommaWithError(op.getAttrs(), os, emitNamedAttribute); 823 } 824 825 LogicalResult CppEmitter::emitVariableAssignment(OpResult result) { 826 if (!hasValueInScope(result)) { 827 return result.getDefiningOp()->emitOpError( 828 "result variable for the operation has not been declared"); 829 } 830 os << getOrCreateName(result) << " = "; 831 return success(); 832 } 833 834 LogicalResult CppEmitter::emitVariableDeclaration(OpResult result, 835 bool trailingSemicolon) { 836 if (hasValueInScope(result)) { 837 return result.getDefiningOp()->emitError( 838 "result variable for the operation already declared"); 839 } 840 if (failed(emitType(result.getOwner()->getLoc(), result.getType()))) 841 return failure(); 842 os << " " << getOrCreateName(result); 843 if (trailingSemicolon) 844 os << ";\n"; 845 return success(); 846 } 847 848 LogicalResult CppEmitter::emitAssignPrefix(Operation &op) { 849 switch (op.getNumResults()) { 850 case 0: 851 break; 852 case 1: { 853 OpResult result = op.getResult(0); 854 if (shouldDeclareVariablesAtTop()) { 855 if (failed(emitVariableAssignment(result))) 856 return failure(); 857 } else { 858 if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/false))) 859 return failure(); 860 os << " = "; 861 } 862 break; 863 } 864 default: 865 if (!shouldDeclareVariablesAtTop()) { 866 for (OpResult result : op.getResults()) { 867 if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/true))) 868 return failure(); 869 } 870 } 871 os << "std::tie("; 872 interleaveComma(op.getResults(), os, 873 [&](Value result) { os << getOrCreateName(result); }); 874 os << ") = "; 875 } 876 return success(); 877 } 878 879 LogicalResult 
CppEmitter::emitLabel(Block &block) { 880 if (!hasBlockLabel(block)) 881 return block.getParentOp()->emitError("label for block not found"); 882 // FIXME: Add feature in `raw_indented_ostream` to ignore indent for block 883 // label instead of using `getOStream`. 884 os.getOStream() << getOrCreateName(block) << ":\n"; 885 return success(); 886 } 887 888 LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { 889 LogicalResult status = 890 llvm::TypeSwitch<Operation *, LogicalResult>(&op) 891 // EmitC ops. 892 .Case<emitc::ApplyOp, emitc::CallOp, emitc::ConstantOp, 893 emitc::IncludeOp>( 894 [&](auto op) { return printOperation(*this, op); }) 895 // SCF ops. 896 .Case<scf::ForOp, scf::IfOp, scf::YieldOp>( 897 [&](auto op) { return printOperation(*this, op); }) 898 // Standard ops. 899 .Case<BranchOp, mlir::CallOp, CondBranchOp, mlir::ConstantOp, FuncOp, 900 ModuleOp, ReturnOp>( 901 [&](auto op) { return printOperation(*this, op); }) 902 .Default([&](Operation *) { 903 return op.emitOpError("unable to find printer for op"); 904 }); 905 906 if (failed(status)) 907 return failure(); 908 os << (trailingSemicolon ? 
";\n" : "\n"); 909 return success(); 910 } 911 912 LogicalResult CppEmitter::emitType(Location loc, Type type) { 913 if (auto iType = type.dyn_cast<IntegerType>()) { 914 switch (iType.getWidth()) { 915 case 1: 916 return (os << "bool"), success(); 917 case 8: 918 case 16: 919 case 32: 920 case 64: 921 if (shouldMapToUnsigned(iType.getSignedness())) 922 return (os << "uint" << iType.getWidth() << "_t"), success(); 923 else 924 return (os << "int" << iType.getWidth() << "_t"), success(); 925 default: 926 return emitError(loc, "cannot emit integer type ") << type; 927 } 928 } 929 if (auto fType = type.dyn_cast<FloatType>()) { 930 switch (fType.getWidth()) { 931 case 32: 932 return (os << "float"), success(); 933 case 64: 934 return (os << "double"), success(); 935 default: 936 return emitError(loc, "cannot emit float type ") << type; 937 } 938 } 939 if (auto iType = type.dyn_cast<IndexType>()) 940 return (os << "size_t"), success(); 941 if (auto tType = type.dyn_cast<TensorType>()) { 942 if (!tType.hasRank()) 943 return emitError(loc, "cannot emit unranked tensor type"); 944 if (!tType.hasStaticShape()) 945 return emitError(loc, "cannot emit tensor type with non static shape"); 946 os << "Tensor<"; 947 if (failed(emitType(loc, tType.getElementType()))) 948 return failure(); 949 auto shape = tType.getShape(); 950 for (auto dimSize : shape) { 951 os << ", "; 952 os << dimSize; 953 } 954 os << ">"; 955 return success(); 956 } 957 if (auto tType = type.dyn_cast<TupleType>()) 958 return emitTupleType(loc, tType.getTypes()); 959 if (auto oType = type.dyn_cast<emitc::OpaqueType>()) { 960 os << oType.getValue(); 961 return success(); 962 } 963 return emitError(loc, "cannot emit type ") << type; 964 } 965 966 LogicalResult CppEmitter::emitTypes(Location loc, ArrayRef<Type> types) { 967 switch (types.size()) { 968 case 0: 969 os << "void"; 970 return success(); 971 case 1: 972 return emitType(loc, types.front()); 973 default: 974 return emitTupleType(loc, types); 975 } 976 } 
977 978 LogicalResult CppEmitter::emitTupleType(Location loc, ArrayRef<Type> types) { 979 os << "std::tuple<"; 980 if (failed(interleaveCommaWithError( 981 types, os, [&](Type type) { return emitType(loc, type); }))) 982 return failure(); 983 os << ">"; 984 return success(); 985 } 986 987 LogicalResult emitc::translateToCpp(Operation *op, raw_ostream &os, 988 bool declareVariablesAtTop) { 989 CppEmitter emitter(os, declareVariablesAtTop); 990 return emitter.emitOperation(*op, /*trailingSemicolon=*/false); 991 } 992