1 //===- TranslateToCpp.cpp - Translating to C++ calls ----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/EmitC/IR/EmitC.h" 10 #include "mlir/Dialect/SCF/SCF.h" 11 #include "mlir/Dialect/StandardOps/IR/Ops.h" 12 #include "mlir/IR/BuiltinOps.h" 13 #include "mlir/IR/BuiltinTypes.h" 14 #include "mlir/IR/Dialect.h" 15 #include "mlir/IR/Operation.h" 16 #include "mlir/Support/IndentedOstream.h" 17 #include "mlir/Target/Cpp/CppEmitter.h" 18 #include "llvm/ADT/DenseMap.h" 19 #include "llvm/ADT/StringExtras.h" 20 #include "llvm/ADT/StringMap.h" 21 #include "llvm/ADT/TypeSwitch.h" 22 #include "llvm/Support/Debug.h" 23 #include "llvm/Support/FormatVariadic.h" 24 25 #define DEBUG_TYPE "translate-to-cpp" 26 27 using namespace mlir; 28 using namespace mlir::emitc; 29 using llvm::formatv; 30 31 /// Convenience functions to produce interleaved output with functions returning 32 /// a LogicalResult. This is different than those in STLExtras as functions used 33 /// on each element doesn't return a string. 34 template <typename ForwardIterator, typename UnaryFunctor, 35 typename NullaryFunctor> 36 inline LogicalResult 37 interleaveWithError(ForwardIterator begin, ForwardIterator end, 38 UnaryFunctor eachFn, NullaryFunctor betweenFn) { 39 if (begin == end) 40 return success(); 41 if (failed(eachFn(*begin))) 42 return failure(); 43 ++begin; 44 for (; begin != end; ++begin) { 45 betweenFn(); 46 if (failed(eachFn(*begin))) 47 return failure(); 48 } 49 return success(); 50 } 51 52 template <typename Container, typename UnaryFunctor, typename NullaryFunctor> 53 inline LogicalResult interleaveWithError(const Container &c, 54 UnaryFunctor eachFn, 55 NullaryFunctor betweenFn) { 56 return interleaveWithError(c.begin(), c.end(), eachFn, betweenFn); 57 } 58 59 template <typename Container, typename UnaryFunctor> 60 inline LogicalResult interleaveCommaWithError(const Container &c, 61 raw_ostream &os, 62 UnaryFunctor eachFn) { 63 return interleaveWithError(c.begin(), c.end(), eachFn, [&]() { os << ", "; }); 64 } 65 66 namespace { 67 /// Emitter that uses dialect specific emitters to emit C++ code. 68 struct CppEmitter { 69 explicit CppEmitter(raw_ostream &os, bool declareVariablesAtTop); 70 71 /// Emits attribute or returns failure. 72 LogicalResult emitAttribute(Location loc, Attribute attr); 73 74 /// Emits operation 'op' with/without training semicolon or returns failure. 75 LogicalResult emitOperation(Operation &op, bool trailingSemicolon); 76 77 /// Emits type 'type' or returns failure. 78 LogicalResult emitType(Location loc, Type type); 79 80 /// Emits array of types as a std::tuple of the emitted types. 81 /// - emits void for an empty array; 82 /// - emits the type of the only element for arrays of size one; 83 /// - emits a std::tuple otherwise; 84 LogicalResult emitTypes(Location loc, ArrayRef<Type> types); 85 86 /// Emits array of types as a std::tuple of the emitted types independently of 87 /// the array size. 88 LogicalResult emitTupleType(Location loc, ArrayRef<Type> types); 89 90 /// Emits an assignment for a variable which has been declared previously. 91 LogicalResult emitVariableAssignment(OpResult result); 92 93 /// Emits a variable declaration for a result of an operation. 94 LogicalResult emitVariableDeclaration(OpResult result, 95 bool trailingSemicolon); 96 97 /// Emits the variable declaration and assignment prefix for 'op'. 98 /// - emits separate variable followed by std::tie for multi-valued operation; 99 /// - emits single type followed by variable for single result; 100 /// - emits nothing if no value produced by op; 101 /// Emits final '=' operator where a type is produced. Returns failure if 102 /// any result type could not be converted. 103 LogicalResult emitAssignPrefix(Operation &op); 104 105 /// Emits a label for the block. 106 LogicalResult emitLabel(Block &block); 107 108 /// Emits the operands and atttributes of the operation. All operands are 109 /// emitted first and then all attributes in alphabetical order. 110 LogicalResult emitOperandsAndAttributes(Operation &op, 111 ArrayRef<StringRef> exclude = {}); 112 113 /// Emits the operands of the operation. All operands are emitted in order. 114 LogicalResult emitOperands(Operation &op); 115 116 /// Return the existing or a new name for a Value. 117 StringRef getOrCreateName(Value val); 118 119 /// Return the existing or a new label of a Block. 120 StringRef getOrCreateName(Block &block); 121 122 /// Whether to map an mlir integer to a unsigned integer in C++. 123 bool shouldMapToUnsigned(IntegerType::SignednessSemantics val); 124 125 /// RAII helper function to manage entering/exiting C++ scopes. 126 struct Scope { 127 Scope(CppEmitter &emitter) 128 : valueMapperScope(emitter.valueMapper), 129 blockMapperScope(emitter.blockMapper), emitter(emitter) { 130 emitter.valueInScopeCount.push(emitter.valueInScopeCount.top()); 131 emitter.labelInScopeCount.push(emitter.labelInScopeCount.top()); 132 } 133 ~Scope() { 134 emitter.valueInScopeCount.pop(); 135 emitter.labelInScopeCount.pop(); 136 } 137 138 private: 139 llvm::ScopedHashTableScope<Value, std::string> valueMapperScope; 140 llvm::ScopedHashTableScope<Block *, std::string> blockMapperScope; 141 CppEmitter &emitter; 142 }; 143 144 /// Returns wether the Value is assigned to a C++ variable in the scope. 145 bool hasValueInScope(Value val); 146 147 // Returns whether a label is assigned to the block. 148 bool hasBlockLabel(Block &block); 149 150 /// Returns the output stream. 151 raw_indented_ostream &ostream() { return os; }; 152 153 /// Returns if all variables for op results and basic block arguments need to 154 /// be declared at the beginning of a function. 155 bool shouldDeclareVariablesAtTop() { return declareVariablesAtTop; }; 156 157 private: 158 using ValueMapper = llvm::ScopedHashTable<Value, std::string>; 159 using BlockMapper = llvm::ScopedHashTable<Block *, std::string>; 160 161 /// Output stream to emit to. 162 raw_indented_ostream os; 163 164 /// Boolean to enforce that all variables for op results and block 165 /// arguments are declared at the beginning of the function. This also 166 /// includes results from ops located in nested regions. 167 bool declareVariablesAtTop; 168 169 /// Map from value to name of C++ variable that contain the name. 170 ValueMapper valueMapper; 171 172 /// Map from block to name of C++ label. 173 BlockMapper blockMapper; 174 175 /// The number of values in the current scope. This is used to declare the 176 /// names of values in a scope. 177 std::stack<int64_t> valueInScopeCount; 178 std::stack<int64_t> labelInScopeCount; 179 }; 180 } // namespace 181 182 static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation, 183 Attribute value) { 184 OpResult result = operation->getResult(0); 185 186 // Only emit an assignment as the variable was already declared when printing 187 // the FuncOp. 188 if (emitter.shouldDeclareVariablesAtTop()) { 189 // Skip the assignment if the emitc.constant has no value. 190 if (auto oAttr = value.dyn_cast<emitc::OpaqueAttr>()) { 191 if (oAttr.getValue().empty()) 192 return success(); 193 } 194 195 if (failed(emitter.emitVariableAssignment(result))) 196 return failure(); 197 return emitter.emitAttribute(operation->getLoc(), value); 198 } 199 200 // Emit a variable declaration for an emitc.constant op without value. 201 if (auto oAttr = value.dyn_cast<emitc::OpaqueAttr>()) { 202 if (oAttr.getValue().empty()) 203 // The semicolon gets printed by the emitOperation function. 204 return emitter.emitVariableDeclaration(result, 205 /*trailingSemicolon=*/false); 206 } 207 208 // Emit a variable declaration. 209 if (failed(emitter.emitAssignPrefix(*operation))) 210 return failure(); 211 return emitter.emitAttribute(operation->getLoc(), value); 212 } 213 214 static LogicalResult printOperation(CppEmitter &emitter, 215 emitc::ConstantOp constantOp) { 216 Operation *operation = constantOp.getOperation(); 217 Attribute value = constantOp.value(); 218 219 return printConstantOp(emitter, operation, value); 220 } 221 222 static LogicalResult printOperation(CppEmitter &emitter, 223 arith::ConstantOp constantOp) { 224 Operation *operation = constantOp.getOperation(); 225 Attribute value = constantOp.getValue(); 226 227 return printConstantOp(emitter, operation, value); 228 } 229 230 static LogicalResult printOperation(CppEmitter &emitter, 231 mlir::ConstantOp constantOp) { 232 Operation *operation = constantOp.getOperation(); 233 Attribute value = constantOp.getValue(); 234 235 return printConstantOp(emitter, operation, value); 236 } 237 238 static LogicalResult printOperation(CppEmitter &emitter, BranchOp branchOp) { 239 raw_ostream &os = emitter.ostream(); 240 Block &successor = *branchOp.getSuccessor(); 241 242 for (auto pair : 243 llvm::zip(branchOp.getOperands(), successor.getArguments())) { 244 Value &operand = std::get<0>(pair); 245 BlockArgument &argument = std::get<1>(pair); 246 os << emitter.getOrCreateName(argument) << " = " 247 << emitter.getOrCreateName(operand) << ";\n"; 248 } 249 250 os << "goto "; 251 if (!(emitter.hasBlockLabel(successor))) 252 return branchOp.emitOpError("unable to find label for successor block"); 253 os << emitter.getOrCreateName(successor); 254 return success(); 255 } 256 257 static LogicalResult printOperation(CppEmitter &emitter, 258 CondBranchOp condBranchOp) { 259 raw_indented_ostream &os = emitter.ostream(); 260 Block &trueSuccessor = *condBranchOp.getTrueDest(); 261 Block &falseSuccessor = *condBranchOp.getFalseDest(); 262 263 os << "if (" << emitter.getOrCreateName(condBranchOp.getCondition()) 264 << ") {\n"; 265 266 os.indent(); 267 268 // If condition is true. 269 for (auto pair : llvm::zip(condBranchOp.getTrueOperands(), 270 trueSuccessor.getArguments())) { 271 Value &operand = std::get<0>(pair); 272 BlockArgument &argument = std::get<1>(pair); 273 os << emitter.getOrCreateName(argument) << " = " 274 << emitter.getOrCreateName(operand) << ";\n"; 275 } 276 277 os << "goto "; 278 if (!(emitter.hasBlockLabel(trueSuccessor))) { 279 return condBranchOp.emitOpError("unable to find label for successor block"); 280 } 281 os << emitter.getOrCreateName(trueSuccessor) << ";\n"; 282 os.unindent() << "} else {\n"; 283 os.indent(); 284 // If condition is false. 285 for (auto pair : llvm::zip(condBranchOp.getFalseOperands(), 286 falseSuccessor.getArguments())) { 287 Value &operand = std::get<0>(pair); 288 BlockArgument &argument = std::get<1>(pair); 289 os << emitter.getOrCreateName(argument) << " = " 290 << emitter.getOrCreateName(operand) << ";\n"; 291 } 292 293 os << "goto "; 294 if (!(emitter.hasBlockLabel(falseSuccessor))) { 295 return condBranchOp.emitOpError() 296 << "unable to find label for successor block"; 297 } 298 os << emitter.getOrCreateName(falseSuccessor) << ";\n"; 299 os.unindent() << "}"; 300 return success(); 301 } 302 303 static LogicalResult printOperation(CppEmitter &emitter, mlir::CallOp callOp) { 304 if (failed(emitter.emitAssignPrefix(*callOp.getOperation()))) 305 return failure(); 306 307 raw_ostream &os = emitter.ostream(); 308 os << callOp.getCallee() << "("; 309 if (failed(emitter.emitOperands(*callOp.getOperation()))) 310 return failure(); 311 os << ")"; 312 return success(); 313 } 314 315 static LogicalResult printOperation(CppEmitter &emitter, emitc::CallOp callOp) { 316 raw_ostream &os = emitter.ostream(); 317 Operation &op = *callOp.getOperation(); 318 319 if (failed(emitter.emitAssignPrefix(op))) 320 return failure(); 321 os << callOp.callee(); 322 323 auto emitArgs = [&](Attribute attr) -> LogicalResult { 324 if (auto t = attr.dyn_cast<IntegerAttr>()) { 325 // Index attributes are treated specially as operand index. 326 if (t.getType().isIndex()) { 327 int64_t idx = t.getInt(); 328 if ((idx < 0) || (idx >= op.getNumOperands())) 329 return op.emitOpError("invalid operand index"); 330 if (!emitter.hasValueInScope(op.getOperand(idx))) 331 return op.emitOpError("operand ") 332 << idx << "'s value not defined in scope"; 333 os << emitter.getOrCreateName(op.getOperand(idx)); 334 return success(); 335 } 336 } 337 if (failed(emitter.emitAttribute(op.getLoc(), attr))) 338 return failure(); 339 340 return success(); 341 }; 342 343 if (callOp.template_args()) { 344 os << "<"; 345 if (failed(interleaveCommaWithError(*callOp.template_args(), os, emitArgs))) 346 return failure(); 347 os << ">"; 348 } 349 350 os << "("; 351 352 LogicalResult emittedArgs = 353 callOp.args() ? interleaveCommaWithError(*callOp.args(), os, emitArgs) 354 : emitter.emitOperands(op); 355 if (failed(emittedArgs)) 356 return failure(); 357 os << ")"; 358 return success(); 359 } 360 361 static LogicalResult printOperation(CppEmitter &emitter, 362 emitc::ApplyOp applyOp) { 363 raw_ostream &os = emitter.ostream(); 364 Operation &op = *applyOp.getOperation(); 365 366 if (failed(emitter.emitAssignPrefix(op))) 367 return failure(); 368 os << applyOp.applicableOperator(); 369 os << emitter.getOrCreateName(applyOp.getOperand()); 370 371 return success(); 372 } 373 374 static LogicalResult printOperation(CppEmitter &emitter, 375 emitc::IncludeOp includeOp) { 376 raw_ostream &os = emitter.ostream(); 377 378 os << "#include "; 379 if (includeOp.is_standard_include()) 380 os << "<" << includeOp.include() << ">"; 381 else 382 os << "\"" << includeOp.include() << "\""; 383 384 return success(); 385 } 386 387 static LogicalResult printOperation(CppEmitter &emitter, scf::ForOp forOp) { 388 389 raw_indented_ostream &os = emitter.ostream(); 390 391 OperandRange operands = forOp.getIterOperands(); 392 Block::BlockArgListType iterArgs = forOp.getRegionIterArgs(); 393 Operation::result_range results = forOp.getResults(); 394 395 if (!emitter.shouldDeclareVariablesAtTop()) { 396 for (OpResult result : results) { 397 if (failed(emitter.emitVariableDeclaration(result, 398 /*trailingSemicolon=*/true))) 399 return failure(); 400 } 401 } 402 403 for (auto pair : llvm::zip(iterArgs, operands)) { 404 if (failed(emitter.emitType(forOp.getLoc(), std::get<0>(pair).getType()))) 405 return failure(); 406 os << " " << emitter.getOrCreateName(std::get<0>(pair)) << " = "; 407 os << emitter.getOrCreateName(std::get<1>(pair)) << ";"; 408 os << "\n"; 409 } 410 411 os << "for ("; 412 if (failed( 413 emitter.emitType(forOp.getLoc(), forOp.getInductionVar().getType()))) 414 return failure(); 415 os << " "; 416 os << emitter.getOrCreateName(forOp.getInductionVar()); 417 os << " = "; 418 os << emitter.getOrCreateName(forOp.lowerBound()); 419 os << "; "; 420 os << emitter.getOrCreateName(forOp.getInductionVar()); 421 os << " < "; 422 os << emitter.getOrCreateName(forOp.upperBound()); 423 os << "; "; 424 os << emitter.getOrCreateName(forOp.getInductionVar()); 425 os << " += "; 426 os << emitter.getOrCreateName(forOp.step()); 427 os << ") {\n"; 428 os.indent(); 429 430 Region &forRegion = forOp.region(); 431 auto regionOps = forRegion.getOps(); 432 433 // We skip the trailing yield op because this updates the result variables 434 // of the for op in the generated code. Instead we update the iterArgs at 435 // the end of a loop iteration and set the result variables after the for 436 // loop. 437 for (auto it = regionOps.begin(); std::next(it) != regionOps.end(); ++it) { 438 if (failed(emitter.emitOperation(*it, /*trailingSemicolon=*/true))) 439 return failure(); 440 } 441 442 Operation *yieldOp = forRegion.getBlocks().front().getTerminator(); 443 // Copy yield operands into iterArgs at the end of a loop iteration. 444 for (auto pair : llvm::zip(iterArgs, yieldOp->getOperands())) { 445 BlockArgument iterArg = std::get<0>(pair); 446 Value operand = std::get<1>(pair); 447 os << emitter.getOrCreateName(iterArg) << " = " 448 << emitter.getOrCreateName(operand) << ";\n"; 449 } 450 451 os.unindent() << "}"; 452 453 // Copy iterArgs into results after the for loop. 454 for (auto pair : llvm::zip(results, iterArgs)) { 455 OpResult result = std::get<0>(pair); 456 BlockArgument iterArg = std::get<1>(pair); 457 os << "\n" 458 << emitter.getOrCreateName(result) << " = " 459 << emitter.getOrCreateName(iterArg) << ";"; 460 } 461 462 return success(); 463 } 464 465 static LogicalResult printOperation(CppEmitter &emitter, scf::IfOp ifOp) { 466 raw_indented_ostream &os = emitter.ostream(); 467 468 if (!emitter.shouldDeclareVariablesAtTop()) { 469 for (OpResult result : ifOp.getResults()) { 470 if (failed(emitter.emitVariableDeclaration(result, 471 /*trailingSemicolon=*/true))) 472 return failure(); 473 } 474 } 475 476 os << "if ("; 477 if (failed(emitter.emitOperands(*ifOp.getOperation()))) 478 return failure(); 479 os << ") {\n"; 480 os.indent(); 481 482 Region &thenRegion = ifOp.thenRegion(); 483 for (Operation &op : thenRegion.getOps()) { 484 // Note: This prints a superfluous semicolon if the terminating yield op has 485 // zero results. 486 if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/true))) 487 return failure(); 488 } 489 490 os.unindent() << "}"; 491 492 Region &elseRegion = ifOp.elseRegion(); 493 if (!elseRegion.empty()) { 494 os << " else {\n"; 495 os.indent(); 496 497 for (Operation &op : elseRegion.getOps()) { 498 // Note: This prints a superfluous semicolon if the terminating yield op 499 // has zero results. 500 if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/true))) 501 return failure(); 502 } 503 504 os.unindent() << "}"; 505 } 506 507 return success(); 508 } 509 510 static LogicalResult printOperation(CppEmitter &emitter, scf::YieldOp yieldOp) { 511 raw_ostream &os = emitter.ostream(); 512 Operation &parentOp = *yieldOp.getOperation()->getParentOp(); 513 514 if (yieldOp.getNumOperands() != parentOp.getNumResults()) { 515 return yieldOp.emitError("number of operands does not to match the number " 516 "of the parent op's results"); 517 } 518 519 if (failed(interleaveWithError( 520 llvm::zip(parentOp.getResults(), yieldOp.getOperands()), 521 [&](auto pair) -> LogicalResult { 522 auto result = std::get<0>(pair); 523 auto operand = std::get<1>(pair); 524 os << emitter.getOrCreateName(result) << " = "; 525 526 if (!emitter.hasValueInScope(operand)) 527 return yieldOp.emitError("operand value not in scope"); 528 os << emitter.getOrCreateName(operand); 529 return success(); 530 }, 531 [&]() { os << ";\n"; }))) 532 return failure(); 533 534 return success(); 535 } 536 537 static LogicalResult printOperation(CppEmitter &emitter, ReturnOp returnOp) { 538 raw_ostream &os = emitter.ostream(); 539 os << "return"; 540 switch (returnOp.getNumOperands()) { 541 case 0: 542 return success(); 543 case 1: 544 os << " " << emitter.getOrCreateName(returnOp.getOperand(0)); 545 return success(emitter.hasValueInScope(returnOp.getOperand(0))); 546 default: 547 os << " std::make_tuple("; 548 if (failed(emitter.emitOperandsAndAttributes(*returnOp.getOperation()))) 549 return failure(); 550 os << ")"; 551 return success(); 552 } 553 } 554 555 static LogicalResult printOperation(CppEmitter &emitter, ModuleOp moduleOp) { 556 CppEmitter::Scope scope(emitter); 557 558 for (Operation &op : moduleOp) { 559 if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/false))) 560 return failure(); 561 } 562 return success(); 563 } 564 565 static LogicalResult printOperation(CppEmitter &emitter, FuncOp functionOp) { 566 // We need to declare variables at top if the function has multiple blocks. 567 if (!emitter.shouldDeclareVariablesAtTop() && 568 functionOp.getBlocks().size() > 1) { 569 return functionOp.emitOpError( 570 "with multiple blocks needs variables declared at top"); 571 } 572 573 CppEmitter::Scope scope(emitter); 574 raw_indented_ostream &os = emitter.ostream(); 575 if (failed(emitter.emitTypes(functionOp.getLoc(), 576 functionOp.getType().getResults()))) 577 return failure(); 578 os << " " << functionOp.getName(); 579 580 os << "("; 581 if (failed(interleaveCommaWithError( 582 functionOp.getArguments(), os, 583 [&](BlockArgument arg) -> LogicalResult { 584 if (failed(emitter.emitType(functionOp.getLoc(), arg.getType()))) 585 return failure(); 586 os << " " << emitter.getOrCreateName(arg); 587 return success(); 588 }))) 589 return failure(); 590 os << ") {\n"; 591 os.indent(); 592 if (emitter.shouldDeclareVariablesAtTop()) { 593 // Declare all variables that hold op results including those from nested 594 // regions. 595 WalkResult result = 596 functionOp.walk<WalkOrder::PreOrder>([&](Operation *op) -> WalkResult { 597 for (OpResult result : op->getResults()) { 598 if (failed(emitter.emitVariableDeclaration( 599 result, /*trailingSemicolon=*/true))) { 600 return WalkResult( 601 op->emitError("unable to declare result variable for op")); 602 } 603 } 604 return WalkResult::advance(); 605 }); 606 if (result.wasInterrupted()) 607 return failure(); 608 } 609 610 Region::BlockListType &blocks = functionOp.getBlocks(); 611 // Create label names for basic blocks. 612 for (Block &block : blocks) { 613 emitter.getOrCreateName(block); 614 } 615 616 // Declare variables for basic block arguments. 617 for (auto it = std::next(blocks.begin()); it != blocks.end(); ++it) { 618 Block &block = *it; 619 for (BlockArgument &arg : block.getArguments()) { 620 if (emitter.hasValueInScope(arg)) 621 return functionOp.emitOpError(" block argument #") 622 << arg.getArgNumber() << " is out of scope"; 623 if (failed( 624 emitter.emitType(block.getParentOp()->getLoc(), arg.getType()))) { 625 return failure(); 626 } 627 os << " " << emitter.getOrCreateName(arg) << ";\n"; 628 } 629 } 630 631 for (Block &block : blocks) { 632 // Only print a label if there is more than one block. 633 if (blocks.size() > 1) { 634 if (failed(emitter.emitLabel(block))) 635 return failure(); 636 } 637 for (Operation &op : block.getOperations()) { 638 // When generating code for an scf.if or std.cond_br op no semicolon needs 639 // to be printed after the closing brace. 640 // When generating code for an scf.for op, printing a trailing semicolon 641 // is handled within the printOperation function. 642 bool trailingSemicolon = !isa<scf::IfOp, scf::ForOp, CondBranchOp>(op); 643 644 if (failed(emitter.emitOperation( 645 op, /*trailingSemicolon=*/trailingSemicolon))) 646 return failure(); 647 } 648 } 649 os.unindent() << "}\n"; 650 return success(); 651 } 652 653 CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop) 654 : os(os), declareVariablesAtTop(declareVariablesAtTop) { 655 valueInScopeCount.push(0); 656 labelInScopeCount.push(0); 657 } 658 659 /// Return the existing or a new name for a Value. 660 StringRef CppEmitter::getOrCreateName(Value val) { 661 if (!valueMapper.count(val)) 662 valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top())); 663 return *valueMapper.begin(val); 664 } 665 666 /// Return the existing or a new label for a Block. 667 StringRef CppEmitter::getOrCreateName(Block &block) { 668 if (!blockMapper.count(&block)) 669 blockMapper.insert(&block, formatv("label{0}", ++labelInScopeCount.top())); 670 return *blockMapper.begin(&block); 671 } 672 673 bool CppEmitter::shouldMapToUnsigned(IntegerType::SignednessSemantics val) { 674 switch (val) { 675 case IntegerType::Signless: 676 return false; 677 case IntegerType::Signed: 678 return false; 679 case IntegerType::Unsigned: 680 return true; 681 } 682 llvm_unreachable("Unexpected IntegerType::SignednessSemantics"); 683 } 684 685 bool CppEmitter::hasValueInScope(Value val) { return valueMapper.count(val); } 686 687 bool CppEmitter::hasBlockLabel(Block &block) { 688 return blockMapper.count(&block); 689 } 690 691 LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) { 692 auto printInt = [&](APInt val, bool isUnsigned) { 693 if (val.getBitWidth() == 1) { 694 if (val.getBoolValue()) 695 os << "true"; 696 else 697 os << "false"; 698 } else { 699 SmallString<128> strValue; 700 val.toString(strValue, 10, !isUnsigned, false); 701 os << strValue; 702 } 703 }; 704 705 auto printFloat = [&](APFloat val) { 706 if (val.isFinite()) { 707 SmallString<128> strValue; 708 // Use default values of toString except don't truncate zeros. 709 val.toString(strValue, 0, 0, false); 710 switch (llvm::APFloatBase::SemanticsToEnum(val.getSemantics())) { 711 case llvm::APFloatBase::S_IEEEsingle: 712 os << "(float)"; 713 break; 714 case llvm::APFloatBase::S_IEEEdouble: 715 os << "(double)"; 716 break; 717 default: 718 break; 719 }; 720 os << strValue; 721 } else if (val.isNaN()) { 722 os << "NAN"; 723 } else if (val.isInfinity()) { 724 if (val.isNegative()) 725 os << "-"; 726 os << "INFINITY"; 727 } 728 }; 729 730 // Print floating point attributes. 731 if (auto fAttr = attr.dyn_cast<FloatAttr>()) { 732 printFloat(fAttr.getValue()); 733 return success(); 734 } 735 if (auto dense = attr.dyn_cast<DenseFPElementsAttr>()) { 736 os << '{'; 737 interleaveComma(dense, os, [&](APFloat val) { printFloat(val); }); 738 os << '}'; 739 return success(); 740 } 741 742 // Print integer attributes. 743 if (auto iAttr = attr.dyn_cast<IntegerAttr>()) { 744 if (auto iType = iAttr.getType().dyn_cast<IntegerType>()) { 745 printInt(iAttr.getValue(), shouldMapToUnsigned(iType.getSignedness())); 746 return success(); 747 } 748 if (auto iType = iAttr.getType().dyn_cast<IndexType>()) { 749 printInt(iAttr.getValue(), false); 750 return success(); 751 } 752 } 753 if (auto dense = attr.dyn_cast<DenseIntElementsAttr>()) { 754 if (auto iType = dense.getType() 755 .cast<TensorType>() 756 .getElementType() 757 .dyn_cast<IntegerType>()) { 758 os << '{'; 759 interleaveComma(dense, os, [&](APInt val) { 760 printInt(val, shouldMapToUnsigned(iType.getSignedness())); 761 }); 762 os << '}'; 763 return success(); 764 } 765 if (auto iType = dense.getType() 766 .cast<TensorType>() 767 .getElementType() 768 .dyn_cast<IndexType>()) { 769 os << '{'; 770 interleaveComma(dense, os, [&](APInt val) { printInt(val, false); }); 771 os << '}'; 772 return success(); 773 } 774 } 775 776 // Print opaque attributes. 777 if (auto oAttr = attr.dyn_cast<emitc::OpaqueAttr>()) { 778 os << oAttr.getValue(); 779 return success(); 780 } 781 782 // Print symbolic reference attributes. 783 if (auto sAttr = attr.dyn_cast<SymbolRefAttr>()) { 784 if (sAttr.getNestedReferences().size() > 1) 785 return emitError(loc, "attribute has more than 1 nested reference"); 786 os << sAttr.getRootReference().getValue(); 787 return success(); 788 } 789 790 // Print type attributes. 791 if (auto type = attr.dyn_cast<TypeAttr>()) 792 return emitType(loc, type.getValue()); 793 794 return emitError(loc, "cannot emit attribute of type ") << attr.getType(); 795 } 796 797 LogicalResult CppEmitter::emitOperands(Operation &op) { 798 auto emitOperandName = [&](Value result) -> LogicalResult { 799 if (!hasValueInScope(result)) 800 return op.emitOpError() << "operand value not in scope"; 801 os << getOrCreateName(result); 802 return success(); 803 }; 804 return interleaveCommaWithError(op.getOperands(), os, emitOperandName); 805 } 806 807 LogicalResult 808 CppEmitter::emitOperandsAndAttributes(Operation &op, 809 ArrayRef<StringRef> exclude) { 810 if (failed(emitOperands(op))) 811 return failure(); 812 // Insert comma in between operands and non-filtered attributes if needed. 813 if (op.getNumOperands() > 0) { 814 for (NamedAttribute attr : op.getAttrs()) { 815 if (!llvm::is_contained(exclude, attr.first.strref())) { 816 os << ", "; 817 break; 818 } 819 } 820 } 821 // Emit attributes. 822 auto emitNamedAttribute = [&](NamedAttribute attr) -> LogicalResult { 823 if (llvm::is_contained(exclude, attr.first.strref())) 824 return success(); 825 os << "/* " << attr.first << " */"; 826 if (failed(emitAttribute(op.getLoc(), attr.second))) 827 return failure(); 828 return success(); 829 }; 830 return interleaveCommaWithError(op.getAttrs(), os, emitNamedAttribute); 831 } 832 833 LogicalResult CppEmitter::emitVariableAssignment(OpResult result) { 834 if (!hasValueInScope(result)) { 835 return result.getDefiningOp()->emitOpError( 836 "result variable for the operation has not been declared"); 837 } 838 os << getOrCreateName(result) << " = "; 839 return success(); 840 } 841 842 LogicalResult CppEmitter::emitVariableDeclaration(OpResult result, 843 bool trailingSemicolon) { 844 if (hasValueInScope(result)) { 845 return result.getDefiningOp()->emitError( 846 "result variable for the operation already declared"); 847 } 848 if (failed(emitType(result.getOwner()->getLoc(), result.getType()))) 849 return failure(); 850 os << " " << getOrCreateName(result); 851 if (trailingSemicolon) 852 os << ";\n"; 853 return success(); 854 } 855 856 LogicalResult CppEmitter::emitAssignPrefix(Operation &op) { 857 switch (op.getNumResults()) { 858 case 0: 859 break; 860 case 1: { 861 OpResult result = op.getResult(0); 862 if (shouldDeclareVariablesAtTop()) { 863 if (failed(emitVariableAssignment(result))) 864 return failure(); 865 } else { 866 if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/false))) 867 return failure(); 868 os << " = "; 869 } 870 break; 871 } 872 default: 873 if (!shouldDeclareVariablesAtTop()) { 874 for (OpResult result : op.getResults()) { 875 if (failed(emitVariableDeclaration(result, /*trailingSemicolon=*/true))) 876 return failure(); 877 } 878 } 879 os << "std::tie("; 880 interleaveComma(op.getResults(), os, 881 [&](Value result) { os << getOrCreateName(result); }); 882 os << ") = "; 883 } 884 return success(); 885 } 886 887 LogicalResult CppEmitter::emitLabel(Block &block) { 888 if (!hasBlockLabel(block)) 889 return block.getParentOp()->emitError("label for block not found"); 890 // FIXME: Add feature in `raw_indented_ostream` to ignore indent for block 891 // label instead of using `getOStream`. 892 os.getOStream() << getOrCreateName(block) << ":\n"; 893 return success(); 894 } 895 896 LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { 897 LogicalResult status = 898 llvm::TypeSwitch<Operation *, LogicalResult>(&op) 899 // EmitC ops. 900 .Case<emitc::ApplyOp, emitc::CallOp, emitc::ConstantOp, 901 emitc::IncludeOp>( 902 [&](auto op) { return printOperation(*this, op); }) 903 // SCF ops. 904 .Case<scf::ForOp, scf::IfOp, scf::YieldOp>( 905 [&](auto op) { return printOperation(*this, op); }) 906 // Standard ops. 907 .Case<BranchOp, mlir::CallOp, CondBranchOp, mlir::ConstantOp, FuncOp, 908 ModuleOp, ReturnOp>( 909 [&](auto op) { return printOperation(*this, op); }) 910 // Arithmetic ops. 911 .Case<arith::ConstantOp>( 912 [&](auto op) { return printOperation(*this, op); }) 913 .Default([&](Operation *) { 914 return op.emitOpError("unable to find printer for op"); 915 }); 916 917 if (failed(status)) 918 return failure(); 919 os << (trailingSemicolon ? ";\n" : "\n"); 920 return success(); 921 } 922 923 LogicalResult CppEmitter::emitType(Location loc, Type type) { 924 if (auto iType = type.dyn_cast<IntegerType>()) { 925 switch (iType.getWidth()) { 926 case 1: 927 return (os << "bool"), success(); 928 case 8: 929 case 16: 930 case 32: 931 case 64: 932 if (shouldMapToUnsigned(iType.getSignedness())) 933 return (os << "uint" << iType.getWidth() << "_t"), success(); 934 else 935 return (os << "int" << iType.getWidth() << "_t"), success(); 936 default: 937 return emitError(loc, "cannot emit integer type ") << type; 938 } 939 } 940 if (auto fType = type.dyn_cast<FloatType>()) { 941 switch (fType.getWidth()) { 942 case 32: 943 return (os << "float"), success(); 944 case 64: 945 return (os << "double"), success(); 946 default: 947 return emitError(loc, "cannot emit float type ") << type; 948 } 949 } 950 if (auto iType = type.dyn_cast<IndexType>()) 951 return (os << "size_t"), success(); 952 if (auto tType = type.dyn_cast<TensorType>()) { 953 if (!tType.hasRank()) 954 return emitError(loc, "cannot emit unranked tensor type"); 955 if (!tType.hasStaticShape()) 956 return emitError(loc, "cannot emit tensor type with non static shape"); 957 os << "Tensor<"; 958 if (failed(emitType(loc, tType.getElementType()))) 959 return failure(); 960 auto shape = tType.getShape(); 961 for (auto dimSize : shape) { 962 os << ", "; 963 os << dimSize; 964 } 965 os << ">"; 966 return success(); 967 } 968 if (auto tType = type.dyn_cast<TupleType>()) 969 return emitTupleType(loc, tType.getTypes()); 970 if (auto oType = type.dyn_cast<emitc::OpaqueType>()) { 971 os << oType.getValue(); 972 return success(); 973 } 974 return emitError(loc, "cannot emit type ") << type; 975 } 976 977 LogicalResult CppEmitter::emitTypes(Location loc, ArrayRef<Type> types) { 978 switch (types.size()) { 979 case 0: 980 os << "void"; 981 return success(); 982 case 1: 983 return emitType(loc, types.front()); 984 default: 985 return emitTupleType(loc, types); 986 } 987 } 988 989 LogicalResult CppEmitter::emitTupleType(Location loc, ArrayRef<Type> types) { 990 os << "std::tuple<"; 991 if (failed(interleaveCommaWithError( 992 types, os, [&](Type type) { return emitType(loc, type); }))) 993 return failure(); 994 os << ">"; 995 return success(); 996 } 997 998 LogicalResult emitc::translateToCpp(Operation *op, raw_ostream &os, 999 bool declareVariablesAtTop) { 1000 CppEmitter emitter(os, declareVariablesAtTop); 1001 return emitter.emitOperation(*op, /*trailingSemicolon=*/false); 1002 } 1003