1 //===- LLVMDialect.cpp - LLVM IR Ops and Dialect registration -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the types and operation details for the LLVM IR dialect in 10 // MLIR, and the LLVM IR dialect. It also registers the dialect. 11 // 12 //===----------------------------------------------------------------------===// 13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 14 #include "mlir/IR/Builders.h" 15 #include "mlir/IR/DialectImplementation.h" 16 #include "mlir/IR/FunctionImplementation.h" 17 #include "mlir/IR/MLIRContext.h" 18 #include "mlir/IR/Module.h" 19 #include "mlir/IR/StandardTypes.h" 20 21 #include "llvm/ADT/StringSwitch.h" 22 #include "llvm/AsmParser/Parser.h" 23 #include "llvm/IR/Attributes.h" 24 #include "llvm/IR/Function.h" 25 #include "llvm/IR/Type.h" 26 #include "llvm/Support/Mutex.h" 27 #include "llvm/Support/SourceMgr.h" 28 29 using namespace mlir; 30 using namespace mlir::LLVM; 31 32 #include "mlir/Dialect/LLVMIR/LLVMOpsEnums.cpp.inc" 33 34 //===----------------------------------------------------------------------===// 35 // Printing/parsing for LLVM::CmpOp. 36 //===----------------------------------------------------------------------===// 37 static void printICmpOp(OpAsmPrinter &p, ICmpOp &op) { 38 p << op.getOperationName() << " \"" << stringifyICmpPredicate(op.predicate()) 39 << "\" " << op.getOperand(0) << ", " << op.getOperand(1); 40 p.printOptionalAttrDict(op.getAttrs(), {"predicate"}); 41 p << " : " << op.lhs().getType(); 42 } 43 44 static void printFCmpOp(OpAsmPrinter &p, FCmpOp &op) { 45 p << op.getOperationName() << " \"" << stringifyFCmpPredicate(op.predicate()) 46 << "\" " << op.getOperand(0) << ", " << op.getOperand(1); 47 p.printOptionalAttrDict(op.getAttrs(), {"predicate"}); 48 p << " : " << op.lhs().getType(); 49 } 50 51 // <operation> ::= `llvm.icmp` string-literal ssa-use `,` ssa-use 52 // attribute-dict? `:` type 53 // <operation> ::= `llvm.fcmp` string-literal ssa-use `,` ssa-use 54 // attribute-dict? `:` type 55 template <typename CmpPredicateType> 56 static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) { 57 Builder &builder = parser.getBuilder(); 58 59 StringAttr predicateAttr; 60 OpAsmParser::OperandType lhs, rhs; 61 Type type; 62 llvm::SMLoc predicateLoc, trailingTypeLoc; 63 if (parser.getCurrentLocation(&predicateLoc) || 64 parser.parseAttribute(predicateAttr, "predicate", result.attributes) || 65 parser.parseOperand(lhs) || parser.parseComma() || 66 parser.parseOperand(rhs) || 67 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 68 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) || 69 parser.resolveOperand(lhs, type, result.operands) || 70 parser.resolveOperand(rhs, type, result.operands)) 71 return failure(); 72 73 // Replace the string attribute `predicate` with an integer attribute. 74 int64_t predicateValue = 0; 75 if (std::is_same<CmpPredicateType, ICmpPredicate>()) { 76 Optional<ICmpPredicate> predicate = 77 symbolizeICmpPredicate(predicateAttr.getValue()); 78 if (!predicate) 79 return parser.emitError(predicateLoc) 80 << "'" << predicateAttr.getValue() 81 << "' is an incorrect value of the 'predicate' attribute"; 82 predicateValue = static_cast<int64_t>(predicate.getValue()); 83 } else { 84 Optional<FCmpPredicate> predicate = 85 symbolizeFCmpPredicate(predicateAttr.getValue()); 86 if (!predicate) 87 return parser.emitError(predicateLoc) 88 << "'" << predicateAttr.getValue() 89 << "' is an incorrect value of the 'predicate' attribute"; 90 predicateValue = static_cast<int64_t>(predicate.getValue()); 91 } 92 93 result.attributes[0].second = 94 parser.getBuilder().getI64IntegerAttr(predicateValue); 95 96 // The result type is either i1 or a vector type <? x i1> if the inputs are 97 // vectors. 98 auto *dialect = builder.getContext()->getRegisteredDialect<LLVMDialect>(); 99 auto resultType = LLVMType::getInt1Ty(dialect); 100 auto argType = type.dyn_cast<LLVM::LLVMType>(); 101 if (!argType) 102 return parser.emitError(trailingTypeLoc, "expected LLVM IR dialect type"); 103 if (argType.getUnderlyingType()->isVectorTy()) 104 resultType = LLVMType::getVectorTy( 105 resultType, argType.getUnderlyingType()->getVectorNumElements()); 106 107 result.addTypes({resultType}); 108 return success(); 109 } 110 111 //===----------------------------------------------------------------------===// 112 // Printing/parsing for LLVM::AllocaOp. 113 //===----------------------------------------------------------------------===// 114 115 static void printAllocaOp(OpAsmPrinter &p, AllocaOp &op) { 116 auto elemTy = op.getType().cast<LLVM::LLVMType>().getPointerElementTy(); 117 118 auto funcTy = FunctionType::get({op.arraySize().getType()}, {op.getType()}, 119 op.getContext()); 120 121 p << op.getOperationName() << ' ' << op.arraySize() << " x " << elemTy; 122 if (op.alignment().hasValue() && op.alignment()->getSExtValue() != 0) 123 p.printOptionalAttrDict(op.getAttrs()); 124 else 125 p.printOptionalAttrDict(op.getAttrs(), {"alignment"}); 126 p << " : " << funcTy; 127 } 128 129 // <operation> ::= `llvm.alloca` ssa-use `x` type attribute-dict? 130 // `:` type `,` type 131 static ParseResult parseAllocaOp(OpAsmParser &parser, OperationState &result) { 132 OpAsmParser::OperandType arraySize; 133 Type type, elemType; 134 llvm::SMLoc trailingTypeLoc; 135 if (parser.parseOperand(arraySize) || parser.parseKeyword("x") || 136 parser.parseType(elemType) || 137 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 138 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type)) 139 return failure(); 140 141 // Extract the result type from the trailing function type. 142 auto funcType = type.dyn_cast<FunctionType>(); 143 if (!funcType || funcType.getNumInputs() != 1 || 144 funcType.getNumResults() != 1) 145 return parser.emitError( 146 trailingTypeLoc, 147 "expected trailing function type with one argument and one result"); 148 149 if (parser.resolveOperand(arraySize, funcType.getInput(0), result.operands)) 150 return failure(); 151 152 result.addTypes({funcType.getResult(0)}); 153 return success(); 154 } 155 156 //===----------------------------------------------------------------------===// 157 // LLVM::BrOp 158 //===----------------------------------------------------------------------===// 159 160 Optional<OperandRange> BrOp::getSuccessorOperands(unsigned index) { 161 assert(index == 0 && "invalid successor index"); 162 return getOperands(); 163 } 164 165 bool BrOp::canEraseSuccessorOperand() { return true; } 166 167 //===----------------------------------------------------------------------===// 168 // LLVM::CondBrOp 169 //===----------------------------------------------------------------------===// 170 171 Optional<OperandRange> CondBrOp::getSuccessorOperands(unsigned index) { 172 assert(index < getNumSuccessors() && "invalid successor index"); 173 return index == 0 ? trueDestOperands() : falseDestOperands(); 174 } 175 176 bool CondBrOp::canEraseSuccessorOperand() { return true; } 177 178 //===----------------------------------------------------------------------===// 179 // Printing/parsing for LLVM::LoadOp. 180 //===----------------------------------------------------------------------===// 181 182 static void printLoadOp(OpAsmPrinter &p, LoadOp &op) { 183 p << op.getOperationName() << ' ' << op.addr(); 184 p.printOptionalAttrDict(op.getAttrs()); 185 p << " : " << op.addr().getType(); 186 } 187 188 // Extract the pointee type from the LLVM pointer type wrapped in MLIR. Return 189 // the resulting type wrapped in MLIR, or nullptr on error. 190 static Type getLoadStoreElementType(OpAsmParser &parser, Type type, 191 llvm::SMLoc trailingTypeLoc) { 192 auto llvmTy = type.dyn_cast<LLVM::LLVMType>(); 193 if (!llvmTy) 194 return parser.emitError(trailingTypeLoc, "expected LLVM IR dialect type"), 195 nullptr; 196 if (!llvmTy.getUnderlyingType()->isPointerTy()) 197 return parser.emitError(trailingTypeLoc, "expected LLVM pointer type"), 198 nullptr; 199 return llvmTy.getPointerElementTy(); 200 } 201 202 // <operation> ::= `llvm.load` ssa-use attribute-dict? `:` type 203 static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) { 204 OpAsmParser::OperandType addr; 205 Type type; 206 llvm::SMLoc trailingTypeLoc; 207 208 if (parser.parseOperand(addr) || 209 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 210 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) || 211 parser.resolveOperand(addr, type, result.operands)) 212 return failure(); 213 214 Type elemTy = getLoadStoreElementType(parser, type, trailingTypeLoc); 215 216 result.addTypes(elemTy); 217 return success(); 218 } 219 220 //===----------------------------------------------------------------------===// 221 // Printing/parsing for LLVM::StoreOp. 222 //===----------------------------------------------------------------------===// 223 224 static void printStoreOp(OpAsmPrinter &p, StoreOp &op) { 225 p << op.getOperationName() << ' ' << op.value() << ", " << op.addr(); 226 p.printOptionalAttrDict(op.getAttrs()); 227 p << " : " << op.addr().getType(); 228 } 229 230 // <operation> ::= `llvm.store` ssa-use `,` ssa-use attribute-dict? `:` type 231 static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) { 232 OpAsmParser::OperandType addr, value; 233 Type type; 234 llvm::SMLoc trailingTypeLoc; 235 236 if (parser.parseOperand(value) || parser.parseComma() || 237 parser.parseOperand(addr) || 238 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 239 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type)) 240 return failure(); 241 242 Type elemTy = getLoadStoreElementType(parser, type, trailingTypeLoc); 243 if (!elemTy) 244 return failure(); 245 246 if (parser.resolveOperand(value, elemTy, result.operands) || 247 parser.resolveOperand(addr, type, result.operands)) 248 return failure(); 249 250 return success(); 251 } 252 253 ///===---------------------------------------------------------------------===// 254 /// LLVM::InvokeOp 255 ///===---------------------------------------------------------------------===// 256 257 Optional<OperandRange> InvokeOp::getSuccessorOperands(unsigned index) { 258 assert(index < getNumSuccessors() && "invalid successor index"); 259 return index == 0 ? normalDestOperands() : unwindDestOperands(); 260 } 261 262 bool InvokeOp::canEraseSuccessorOperand() { return true; } 263 264 static LogicalResult verify(InvokeOp op) { 265 if (op.getNumResults() > 1) 266 return op.emitOpError("must have 0 or 1 result"); 267 268 Block *unwindDest = op.unwindDest(); 269 if (unwindDest->empty()) 270 return op.emitError( 271 "must have at least one operation in unwind destination"); 272 273 // In unwind destination, first operation must be LandingpadOp 274 if (!isa<LandingpadOp>(unwindDest->front())) 275 return op.emitError("first operation in unwind destination should be a " 276 "llvm.landingpad operation"); 277 278 return success(); 279 } 280 281 static void printInvokeOp(OpAsmPrinter &p, InvokeOp op) { 282 auto callee = op.callee(); 283 bool isDirect = callee.hasValue(); 284 285 p << op.getOperationName() << ' '; 286 287 // Either function name or pointer 288 if (isDirect) 289 p.printSymbolName(callee.getValue()); 290 else 291 p << op.getOperand(0); 292 293 p << '(' << op.getOperands().drop_front(isDirect ? 0 : 1) << ')'; 294 p << " to "; 295 p.printSuccessorAndUseList(op.normalDest(), op.normalDestOperands()); 296 p << " unwind "; 297 p.printSuccessorAndUseList(op.unwindDest(), op.unwindDestOperands()); 298 299 p.printOptionalAttrDict(op.getAttrs(), 300 {InvokeOp::getOperandSegmentSizeAttr(), "callee"}); 301 p << " : "; 302 p.printFunctionalType( 303 llvm::drop_begin(op.getOperandTypes(), isDirect ? 0 : 1), 304 op.getResultTypes()); 305 } 306 307 /// <operation> ::= `llvm.invoke` (function-id | ssa-use) `(` ssa-use-list `)` 308 /// `to` bb-id (`[` ssa-use-and-type-list `]`)? 309 /// `unwind` bb-id (`[` ssa-use-and-type-list `]`)? 310 /// attribute-dict? `:` function-type 311 static ParseResult parseInvokeOp(OpAsmParser &parser, OperationState &result) { 312 SmallVector<OpAsmParser::OperandType, 8> operands; 313 FunctionType funcType; 314 SymbolRefAttr funcAttr; 315 llvm::SMLoc trailingTypeLoc; 316 Block *normalDest, *unwindDest; 317 SmallVector<Value, 4> normalOperands, unwindOperands; 318 Builder &builder = parser.getBuilder(); 319 320 // Parse an operand list that will, in practice, contain 0 or 1 operand. In 321 // case of an indirect call, there will be 1 operand before `(`. In case of a 322 // direct call, there will be no operands and the parser will stop at the 323 // function identifier without complaining. 324 if (parser.parseOperandList(operands)) 325 return failure(); 326 bool isDirect = operands.empty(); 327 328 // Optionally parse a function identifier. 329 if (isDirect && parser.parseAttribute(funcAttr, "callee", result.attributes)) 330 return failure(); 331 332 if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || 333 parser.parseKeyword("to") || 334 parser.parseSuccessorAndUseList(normalDest, normalOperands) || 335 parser.parseKeyword("unwind") || 336 parser.parseSuccessorAndUseList(unwindDest, unwindOperands) || 337 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 338 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(funcType)) 339 return failure(); 340 341 if (isDirect) { 342 // Make sure types match. 343 if (parser.resolveOperands(operands, funcType.getInputs(), 344 parser.getNameLoc(), result.operands)) 345 return failure(); 346 result.addTypes(funcType.getResults()); 347 } else { 348 // Construct the LLVM IR Dialect function type that the first operand 349 // should match. 350 if (funcType.getNumResults() > 1) 351 return parser.emitError(trailingTypeLoc, 352 "expected function with 0 or 1 result"); 353 354 auto *llvmDialect = 355 builder.getContext()->getRegisteredDialect<LLVM::LLVMDialect>(); 356 LLVM::LLVMType llvmResultType; 357 if (funcType.getNumResults() == 0) { 358 llvmResultType = LLVM::LLVMType::getVoidTy(llvmDialect); 359 } else { 360 llvmResultType = funcType.getResult(0).dyn_cast<LLVM::LLVMType>(); 361 if (!llvmResultType) 362 return parser.emitError(trailingTypeLoc, 363 "expected result to have LLVM type"); 364 } 365 366 SmallVector<LLVM::LLVMType, 8> argTypes; 367 argTypes.reserve(funcType.getNumInputs()); 368 for (Type ty : funcType.getInputs()) { 369 if (auto argType = ty.dyn_cast<LLVM::LLVMType>()) 370 argTypes.push_back(argType); 371 else 372 return parser.emitError(trailingTypeLoc, 373 "expected LLVM types as inputs"); 374 } 375 376 auto llvmFuncType = LLVM::LLVMType::getFunctionTy(llvmResultType, argTypes, 377 /*isVarArg=*/false); 378 auto wrappedFuncType = llvmFuncType.getPointerTo(); 379 380 auto funcArguments = llvm::makeArrayRef(operands).drop_front(); 381 382 // Make sure that the first operand (indirect callee) matches the wrapped 383 // LLVM IR function type, and that the types of the other call operands 384 // match the types of the function arguments. 385 if (parser.resolveOperand(operands[0], wrappedFuncType, result.operands) || 386 parser.resolveOperands(funcArguments, funcType.getInputs(), 387 parser.getNameLoc(), result.operands)) 388 return failure(); 389 390 result.addTypes(llvmResultType); 391 } 392 result.addSuccessors({normalDest, unwindDest}); 393 result.addOperands(normalOperands); 394 result.addOperands(unwindOperands); 395 396 result.addAttribute( 397 InvokeOp::getOperandSegmentSizeAttr(), 398 builder.getI32VectorAttr({static_cast<int32_t>(operands.size()), 399 static_cast<int32_t>(normalOperands.size()), 400 static_cast<int32_t>(unwindOperands.size())})); 401 return success(); 402 } 403 404 ///===----------------------------------------------------------------------===// 405 /// Verifying/Printing/Parsing for LLVM::LandingpadOp. 406 ///===----------------------------------------------------------------------===// 407 408 static LogicalResult verify(LandingpadOp op) { 409 Value value; 410 if (LLVMFuncOp func = op.getParentOfType<LLVMFuncOp>()) { 411 if (!func.personality().hasValue()) 412 return op.emitError( 413 "llvm.landingpad needs to be in a function with a personality"); 414 } 415 416 if (!op.cleanup() && op.getOperands().empty()) 417 return op.emitError("landingpad instruction expects at least one clause or " 418 "cleanup attribute"); 419 420 for (unsigned idx = 0, ie = op.getNumOperands(); idx < ie; idx++) { 421 value = op.getOperand(idx); 422 bool isFilter = value.getType().cast<LLVMType>().isArrayTy(); 423 if (isFilter) { 424 // FIXME: Verify filter clauses when arrays are appropriately handled 425 } else { 426 // catch - global addresses only. 427 // Bitcast ops should have global addresses as their args. 428 if (auto bcOp = dyn_cast_or_null<BitcastOp>(value.getDefiningOp())) { 429 if (auto addrOp = 430 dyn_cast_or_null<AddressOfOp>(bcOp.arg().getDefiningOp())) 431 continue; 432 return op.emitError("constant clauses expected") 433 .attachNote(bcOp.getLoc()) 434 << "global addresses expected as operand to " 435 "bitcast used in clauses for landingpad"; 436 } 437 // NullOp and AddressOfOp allowed 438 if (dyn_cast_or_null<NullOp>(value.getDefiningOp())) 439 continue; 440 if (dyn_cast_or_null<AddressOfOp>(value.getDefiningOp())) 441 continue; 442 return op.emitError("clause #") 443 << idx << " is not a known constant - null, addressof, bitcast"; 444 } 445 } 446 return success(); 447 } 448 449 static void printLandingpadOp(OpAsmPrinter &p, LandingpadOp &op) { 450 p << op.getOperationName() << (op.cleanup() ? " cleanup " : " "); 451 452 // Clauses 453 for (auto value : op.getOperands()) { 454 // Similar to llvm - if clause is an array type then it is filter 455 // clause else catch clause 456 bool isArrayTy = value.getType().cast<LLVMType>().isArrayTy(); 457 p << '(' << (isArrayTy ? "filter " : "catch ") << value << " : " 458 << value.getType() << ") "; 459 } 460 461 p.printOptionalAttrDict(op.getAttrs(), {"cleanup"}); 462 463 p << ": " << op.getType(); 464 } 465 466 /// <operation> ::= `llvm.landingpad` `cleanup`? 467 /// ((`catch` | `filter`) operand-type ssa-use)* attribute-dict? 468 static ParseResult parseLandingpadOp(OpAsmParser &parser, 469 OperationState &result) { 470 // Check for cleanup 471 if (succeeded(parser.parseOptionalKeyword("cleanup"))) 472 result.addAttribute("cleanup", parser.getBuilder().getUnitAttr()); 473 474 // Parse clauses with types 475 while (succeeded(parser.parseOptionalLParen()) && 476 (succeeded(parser.parseOptionalKeyword("filter")) || 477 succeeded(parser.parseOptionalKeyword("catch")))) { 478 OpAsmParser::OperandType operand; 479 Type ty; 480 if (parser.parseOperand(operand) || parser.parseColon() || 481 parser.parseType(ty) || 482 parser.resolveOperand(operand, ty, result.operands) || 483 parser.parseRParen()) 484 return failure(); 485 } 486 487 Type type; 488 if (parser.parseColon() || parser.parseType(type)) 489 return failure(); 490 491 result.addTypes(type); 492 return success(); 493 } 494 495 //===----------------------------------------------------------------------===// 496 // Printing/parsing for LLVM::CallOp. 497 //===----------------------------------------------------------------------===// 498 499 static void printCallOp(OpAsmPrinter &p, CallOp &op) { 500 auto callee = op.callee(); 501 bool isDirect = callee.hasValue(); 502 503 // Print the direct callee if present as a function attribute, or an indirect 504 // callee (first operand) otherwise. 505 p << op.getOperationName() << ' '; 506 if (isDirect) 507 p.printSymbolName(callee.getValue()); 508 else 509 p << op.getOperand(0); 510 511 p << '(' << op.getOperands().drop_front(isDirect ? 0 : 1) << ')'; 512 p.printOptionalAttrDict(op.getAttrs(), {"callee"}); 513 514 // Reconstruct the function MLIR function type from operand and result types. 515 SmallVector<Type, 8> argTypes( 516 llvm::drop_begin(op.getOperandTypes(), isDirect ? 0 : 1)); 517 518 p << " : " 519 << FunctionType::get(argTypes, op.getResultTypes(), op.getContext()); 520 } 521 522 // <operation> ::= `llvm.call` (function-id | ssa-use) `(` ssa-use-list `)` 523 // attribute-dict? `:` function-type 524 static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) { 525 SmallVector<OpAsmParser::OperandType, 8> operands; 526 Type type; 527 SymbolRefAttr funcAttr; 528 llvm::SMLoc trailingTypeLoc; 529 530 // Parse an operand list that will, in practice, contain 0 or 1 operand. In 531 // case of an indirect call, there will be 1 operand before `(`. In case of a 532 // direct call, there will be no operands and the parser will stop at the 533 // function identifier without complaining. 534 if (parser.parseOperandList(operands)) 535 return failure(); 536 bool isDirect = operands.empty(); 537 538 // Optionally parse a function identifier. 539 if (isDirect) 540 if (parser.parseAttribute(funcAttr, "callee", result.attributes)) 541 return failure(); 542 543 if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || 544 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 545 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type)) 546 return failure(); 547 548 auto funcType = type.dyn_cast<FunctionType>(); 549 if (!funcType) 550 return parser.emitError(trailingTypeLoc, "expected function type"); 551 if (isDirect) { 552 // Make sure types match. 553 if (parser.resolveOperands(operands, funcType.getInputs(), 554 parser.getNameLoc(), result.operands)) 555 return failure(); 556 result.addTypes(funcType.getResults()); 557 } else { 558 // Construct the LLVM IR Dialect function type that the first operand 559 // should match. 560 if (funcType.getNumResults() > 1) 561 return parser.emitError(trailingTypeLoc, 562 "expected function with 0 or 1 result"); 563 564 Builder &builder = parser.getBuilder(); 565 auto *llvmDialect = 566 builder.getContext()->getRegisteredDialect<LLVM::LLVMDialect>(); 567 LLVM::LLVMType llvmResultType; 568 if (funcType.getNumResults() == 0) { 569 llvmResultType = LLVM::LLVMType::getVoidTy(llvmDialect); 570 } else { 571 llvmResultType = funcType.getResult(0).dyn_cast<LLVM::LLVMType>(); 572 if (!llvmResultType) 573 return parser.emitError(trailingTypeLoc, 574 "expected result to have LLVM type"); 575 } 576 577 SmallVector<LLVM::LLVMType, 8> argTypes; 578 argTypes.reserve(funcType.getNumInputs()); 579 for (int i = 0, e = funcType.getNumInputs(); i < e; ++i) { 580 auto argType = funcType.getInput(i).dyn_cast<LLVM::LLVMType>(); 581 if (!argType) 582 return parser.emitError(trailingTypeLoc, 583 "expected LLVM types as inputs"); 584 argTypes.push_back(argType); 585 } 586 auto llvmFuncType = LLVM::LLVMType::getFunctionTy(llvmResultType, argTypes, 587 /*isVarArg=*/false); 588 auto wrappedFuncType = llvmFuncType.getPointerTo(); 589 590 auto funcArguments = 591 ArrayRef<OpAsmParser::OperandType>(operands).drop_front(); 592 593 // Make sure that the first operand (indirect callee) matches the wrapped 594 // LLVM IR function type, and that the types of the other call operands 595 // match the types of the function arguments. 596 if (parser.resolveOperand(operands[0], wrappedFuncType, result.operands) || 597 parser.resolveOperands(funcArguments, funcType.getInputs(), 598 parser.getNameLoc(), result.operands)) 599 return failure(); 600 601 result.addTypes(llvmResultType); 602 } 603 604 return success(); 605 } 606 607 //===----------------------------------------------------------------------===// 608 // Printing/parsing for LLVM::ExtractElementOp. 609 //===----------------------------------------------------------------------===// 610 // Expects vector to be of wrapped LLVM vector type and position to be of 611 // wrapped LLVM i32 type. 612 void LLVM::ExtractElementOp::build(Builder *b, OperationState &result, 613 Value vector, Value position, 614 ArrayRef<NamedAttribute> attrs) { 615 auto wrappedVectorType = vector.getType().cast<LLVM::LLVMType>(); 616 auto llvmType = wrappedVectorType.getVectorElementType(); 617 build(b, result, llvmType, vector, position); 618 result.addAttributes(attrs); 619 } 620 621 static void printExtractElementOp(OpAsmPrinter &p, ExtractElementOp &op) { 622 p << op.getOperationName() << ' ' << op.vector() << "[" << op.position() 623 << " : " << op.position().getType() << "]"; 624 p.printOptionalAttrDict(op.getAttrs()); 625 p << " : " << op.vector().getType(); 626 } 627 628 // <operation> ::= `llvm.extractelement` ssa-use `, ` ssa-use 629 // attribute-dict? `:` type 630 static ParseResult parseExtractElementOp(OpAsmParser &parser, 631 OperationState &result) { 632 llvm::SMLoc loc; 633 OpAsmParser::OperandType vector, position; 634 Type type, positionType; 635 if (parser.getCurrentLocation(&loc) || parser.parseOperand(vector) || 636 parser.parseLSquare() || parser.parseOperand(position) || 637 parser.parseColonType(positionType) || parser.parseRSquare() || 638 parser.parseOptionalAttrDict(result.attributes) || 639 parser.parseColonType(type) || 640 parser.resolveOperand(vector, type, result.operands) || 641 parser.resolveOperand(position, positionType, result.operands)) 642 return failure(); 643 auto wrappedVectorType = type.dyn_cast<LLVM::LLVMType>(); 644 if (!wrappedVectorType || 645 !wrappedVectorType.getUnderlyingType()->isVectorTy()) 646 return parser.emitError( 647 loc, "expected LLVM IR dialect vector type for operand #1"); 648 result.addTypes(wrappedVectorType.getVectorElementType()); 649 return success(); 650 } 651 652 //===----------------------------------------------------------------------===// 653 // Printing/parsing for LLVM::ExtractValueOp. 654 //===----------------------------------------------------------------------===// 655 656 static void printExtractValueOp(OpAsmPrinter &p, ExtractValueOp &op) { 657 p << op.getOperationName() << ' ' << op.container() << op.position(); 658 p.printOptionalAttrDict(op.getAttrs(), {"position"}); 659 p << " : " << op.container().getType(); 660 } 661 662 // Extract the type at `position` in the wrapped LLVM IR aggregate type 663 // `containerType`. Position is an integer array attribute where each value 664 // is a zero-based position of the element in the aggregate type. Return the 665 // resulting type wrapped in MLIR, or nullptr on error. 666 static LLVM::LLVMType getInsertExtractValueElementType(OpAsmParser &parser, 667 Type containerType, 668 ArrayAttr positionAttr, 669 llvm::SMLoc attributeLoc, 670 llvm::SMLoc typeLoc) { 671 auto wrappedContainerType = containerType.dyn_cast<LLVM::LLVMType>(); 672 if (!wrappedContainerType) 673 return parser.emitError(typeLoc, "expected LLVM IR Dialect type"), nullptr; 674 675 // Infer the element type from the structure type: iteratively step inside the 676 // type by taking the element type, indexed by the position attribute for 677 // structures. Check the position index before accessing, it is supposed to 678 // be in bounds. 679 for (Attribute subAttr : positionAttr) { 680 auto positionElementAttr = subAttr.dyn_cast<IntegerAttr>(); 681 if (!positionElementAttr) 682 return parser.emitError(attributeLoc, 683 "expected an array of integer literals"), 684 nullptr; 685 int position = positionElementAttr.getInt(); 686 auto *llvmContainerType = wrappedContainerType.getUnderlyingType(); 687 if (llvmContainerType->isArrayTy()) { 688 if (position < 0 || static_cast<unsigned>(position) >= 689 llvmContainerType->getArrayNumElements()) 690 return parser.emitError(attributeLoc, "position out of bounds"), 691 nullptr; 692 wrappedContainerType = wrappedContainerType.getArrayElementType(); 693 } else if (llvmContainerType->isStructTy()) { 694 if (position < 0 || static_cast<unsigned>(position) >= 695 llvmContainerType->getStructNumElements()) 696 return parser.emitError(attributeLoc, "position out of bounds"), 697 nullptr; 698 wrappedContainerType = 699 wrappedContainerType.getStructElementType(position); 700 } else { 701 return parser.emitError(typeLoc, 702 "expected wrapped LLVM IR structure/array type"), 703 nullptr; 704 } 705 } 706 return wrappedContainerType; 707 } 708 709 // <operation> ::= `llvm.extractvalue` ssa-use 710 // `[` integer-literal (`,` integer-literal)* `]` 711 // attribute-dict? `:` type 712 static ParseResult parseExtractValueOp(OpAsmParser &parser, 713 OperationState &result) { 714 OpAsmParser::OperandType container; 715 Type containerType; 716 ArrayAttr positionAttr; 717 llvm::SMLoc attributeLoc, trailingTypeLoc; 718 719 if (parser.parseOperand(container) || 720 parser.getCurrentLocation(&attributeLoc) || 721 parser.parseAttribute(positionAttr, "position", result.attributes) || 722 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 723 parser.getCurrentLocation(&trailingTypeLoc) || 724 parser.parseType(containerType) || 725 parser.resolveOperand(container, containerType, result.operands)) 726 return failure(); 727 728 auto elementType = getInsertExtractValueElementType( 729 parser, containerType, positionAttr, attributeLoc, trailingTypeLoc); 730 if (!elementType) 731 return failure(); 732 733 result.addTypes(elementType); 734 return success(); 735 } 736 737 //===----------------------------------------------------------------------===// 738 // Printing/parsing for LLVM::InsertElementOp. 739 //===----------------------------------------------------------------------===// 740 741 static void printInsertElementOp(OpAsmPrinter &p, InsertElementOp &op) { 742 p << op.getOperationName() << ' ' << op.value() << ", " << op.vector() << "[" 743 << op.position() << " : " << op.position().getType() << "]"; 744 p.printOptionalAttrDict(op.getAttrs()); 745 p << " : " << op.vector().getType(); 746 } 747 748 // <operation> ::= `llvm.insertelement` ssa-use `,` ssa-use `,` ssa-use 749 // attribute-dict? `:` type 750 static ParseResult parseInsertElementOp(OpAsmParser &parser, 751 OperationState &result) { 752 llvm::SMLoc loc; 753 OpAsmParser::OperandType vector, value, position; 754 Type vectorType, positionType; 755 if (parser.getCurrentLocation(&loc) || parser.parseOperand(value) || 756 parser.parseComma() || parser.parseOperand(vector) || 757 parser.parseLSquare() || parser.parseOperand(position) || 758 parser.parseColonType(positionType) || parser.parseRSquare() || 759 parser.parseOptionalAttrDict(result.attributes) || 760 parser.parseColonType(vectorType)) 761 return failure(); 762 763 auto wrappedVectorType = vectorType.dyn_cast<LLVM::LLVMType>(); 764 if (!wrappedVectorType || 765 !wrappedVectorType.getUnderlyingType()->isVectorTy()) 766 return parser.emitError( 767 loc, "expected LLVM IR dialect vector type for operand #1"); 768 auto valueType = wrappedVectorType.getVectorElementType(); 769 if (!valueType) 770 return failure(); 771 772 if (parser.resolveOperand(vector, vectorType, result.operands) || 773 parser.resolveOperand(value, valueType, result.operands) || 774 parser.resolveOperand(position, positionType, result.operands)) 775 return failure(); 776 777 result.addTypes(vectorType); 778 return success(); 779 } 780 781 //===----------------------------------------------------------------------===// 782 // Printing/parsing for LLVM::InsertValueOp. 783 //===----------------------------------------------------------------------===// 784 785 static void printInsertValueOp(OpAsmPrinter &p, InsertValueOp &op) { 786 p << op.getOperationName() << ' ' << op.value() << ", " << op.container() 787 << op.position(); 788 p.printOptionalAttrDict(op.getAttrs(), {"position"}); 789 p << " : " << op.container().getType(); 790 } 791 792 // <operation> ::= `llvm.insertvaluevalue` ssa-use `,` ssa-use 793 // `[` integer-literal (`,` integer-literal)* `]` 794 // attribute-dict? `:` type 795 static ParseResult parseInsertValueOp(OpAsmParser &parser, 796 OperationState &result) { 797 OpAsmParser::OperandType container, value; 798 Type containerType; 799 ArrayAttr positionAttr; 800 llvm::SMLoc attributeLoc, trailingTypeLoc; 801 802 if (parser.parseOperand(value) || parser.parseComma() || 803 parser.parseOperand(container) || 804 parser.getCurrentLocation(&attributeLoc) || 805 parser.parseAttribute(positionAttr, "position", result.attributes) || 806 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 807 parser.getCurrentLocation(&trailingTypeLoc) || 808 parser.parseType(containerType)) 809 return failure(); 810 811 auto valueType = getInsertExtractValueElementType( 812 parser, containerType, positionAttr, attributeLoc, trailingTypeLoc); 813 if (!valueType) 814 return failure(); 815 816 if (parser.resolveOperand(container, containerType, result.operands) || 817 parser.resolveOperand(value, valueType, result.operands)) 818 return failure(); 819 820 result.addTypes(containerType); 821 return success(); 822 } 823 824 //===----------------------------------------------------------------------===// 825 // Printing/parsing for LLVM::ReturnOp. 826 //===----------------------------------------------------------------------===// 827 828 static void printReturnOp(OpAsmPrinter &p, ReturnOp &op) { 829 p << op.getOperationName(); 830 p.printOptionalAttrDict(op.getAttrs()); 831 assert(op.getNumOperands() <= 1); 832 833 if (op.getNumOperands() == 0) 834 return; 835 836 p << ' ' << op.getOperand(0) << " : " << op.getOperand(0).getType(); 837 } 838 839 // <operation> ::= `llvm.return` ssa-use-list attribute-dict? `:` 840 // type-list-no-parens 841 static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &result) { 842 SmallVector<OpAsmParser::OperandType, 1> operands; 843 Type type; 844 845 if (parser.parseOperandList(operands) || 846 parser.parseOptionalAttrDict(result.attributes)) 847 return failure(); 848 if (operands.empty()) 849 return success(); 850 851 if (parser.parseColonType(type) || 852 parser.resolveOperand(operands[0], type, result.operands)) 853 return failure(); 854 return success(); 855 } 856 857 //===----------------------------------------------------------------------===// 858 // Verifier for LLVM::AddressOfOp. 859 //===----------------------------------------------------------------------===// 860 861 GlobalOp AddressOfOp::getGlobal() { 862 Operation *module = getParentOp(); 863 while (module && !satisfiesLLVMModule(module)) 864 module = module->getParentOp(); 865 assert(module && "unexpected operation outside of a module"); 866 return dyn_cast_or_null<LLVM::GlobalOp>( 867 mlir::SymbolTable::lookupSymbolIn(module, global_name())); 868 } 869 870 static LogicalResult verify(AddressOfOp op) { 871 auto global = op.getGlobal(); 872 if (!global) 873 return op.emitOpError( 874 "must reference a global defined by 'llvm.mlir.global'"); 875 876 if (global.getType().getPointerTo(global.addr_space().getZExtValue()) != 877 op.getResult().getType()) 878 return op.emitOpError( 879 "the type must be a pointer to the type of the referred global"); 880 881 return success(); 882 } 883 884 //===----------------------------------------------------------------------===// 885 // Builder, printer and verifier for LLVM::GlobalOp. 886 //===----------------------------------------------------------------------===// 887 888 /// Returns the name used for the linkage attribute. This *must* correspond to 889 /// the name of the attribute in ODS. 890 static StringRef getLinkageAttrName() { return "linkage"; } 891 892 void GlobalOp::build(Builder *builder, OperationState &result, LLVMType type, 893 bool isConstant, Linkage linkage, StringRef name, 894 Attribute value, unsigned addrSpace, 895 ArrayRef<NamedAttribute> attrs) { 896 result.addAttribute(SymbolTable::getSymbolAttrName(), 897 builder->getStringAttr(name)); 898 result.addAttribute("type", TypeAttr::get(type)); 899 if (isConstant) 900 result.addAttribute("constant", builder->getUnitAttr()); 901 if (value) 902 result.addAttribute("value", value); 903 result.addAttribute(getLinkageAttrName(), builder->getI64IntegerAttr( 904 static_cast<int64_t>(linkage))); 905 if (addrSpace != 0) 906 result.addAttribute("addr_space", builder->getI32IntegerAttr(addrSpace)); 907 result.attributes.append(attrs.begin(), attrs.end()); 908 result.addRegion(); 909 } 910 911 static void printGlobalOp(OpAsmPrinter &p, GlobalOp op) { 912 p << op.getOperationName() << ' ' << stringifyLinkage(op.linkage()) << ' '; 913 if (op.constant()) 914 p << "constant "; 915 p.printSymbolName(op.sym_name()); 916 p << '('; 917 if (auto value = op.getValueOrNull()) 918 p.printAttribute(value); 919 p << ')'; 920 p.printOptionalAttrDict(op.getAttrs(), 921 {SymbolTable::getSymbolAttrName(), "type", "constant", 922 "value", getLinkageAttrName()}); 923 924 // Print the trailing type unless it's a string global. 925 if (op.getValueOrNull().dyn_cast_or_null<StringAttr>()) 926 return; 927 p << " : " << op.type(); 928 929 Region &initializer = op.getInitializerRegion(); 930 if (!initializer.empty()) 931 p.printRegion(initializer, /*printEntryBlockArgs=*/false); 932 } 933 934 //===----------------------------------------------------------------------===// 935 // Verifier for LLVM::DialectCastOp. 936 //===----------------------------------------------------------------------===// 937 938 static LogicalResult verify(DialectCastOp op) { 939 auto verifyMLIRCastType = [&op](Type type) -> LogicalResult { 940 if (auto llvmType = type.dyn_cast<LLVM::LLVMType>()) { 941 if (llvmType.isVectorTy()) 942 llvmType = llvmType.getVectorElementType(); 943 if (llvmType.isIntegerTy() || llvmType.isHalfTy() || 944 llvmType.isFloatTy() || llvmType.isDoubleTy()) { 945 return success(); 946 } 947 return op.emitOpError("type must be non-index integer types, float " 948 "types, or vector of mentioned types."); 949 } 950 if (auto vectorType = type.dyn_cast<VectorType>()) { 951 if (vectorType.getShape().size() > 1) 952 return op.emitOpError("only 1-d vector is allowed"); 953 type = vectorType.getElementType(); 954 } 955 if (type.isSignlessIntOrFloat()) 956 return success(); 957 // Note that memrefs are not supported. We currently don't have a use case 958 // for it, but even if we do, there are challenges: 959 // * if we allow memrefs to cast from/to memref descriptors, then the 960 // semantics of the cast op depends on the implementation detail of the 961 // descriptor. 962 // * if we allow memrefs to cast from/to bare pointers, some users might 963 // alternatively want metadata that only present in the descriptor. 964 // 965 // TODO(timshen): re-evaluate the memref cast design when it's needed. 966 return op.emitOpError("type must be non-index integer types, float types, " 967 "or vector of mentioned types."); 968 }; 969 return failure(failed(verifyMLIRCastType(op.in().getType())) || 970 failed(verifyMLIRCastType(op.getType()))); 971 } 972 973 // Parses one of the keywords provided in the list `keywords` and returns the 974 // position of the parsed keyword in the list. If none of the keywords from the 975 // list is parsed, returns -1. 976 static int parseOptionalKeywordAlternative(OpAsmParser &parser, 977 ArrayRef<StringRef> keywords) { 978 for (auto en : llvm::enumerate(keywords)) { 979 if (succeeded(parser.parseOptionalKeyword(en.value()))) 980 return en.index(); 981 } 982 return -1; 983 } 984 985 namespace { 986 template <typename Ty> struct EnumTraits {}; 987 988 #define REGISTER_ENUM_TYPE(Ty) \ 989 template <> struct EnumTraits<Ty> { \ 990 static StringRef stringify(Ty value) { return stringify##Ty(value); } \ 991 static unsigned getMaxEnumVal() { return getMaxEnumValFor##Ty(); } \ 992 } 993 994 REGISTER_ENUM_TYPE(Linkage); 995 } // end namespace 996 997 template <typename EnumTy> 998 static ParseResult parseOptionalLLVMKeyword(OpAsmParser &parser, 999 OperationState &result, 1000 StringRef name) { 1001 SmallVector<StringRef, 10> names; 1002 for (unsigned i = 0, e = getMaxEnumValForLinkage(); i <= e; ++i) 1003 names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i))); 1004 1005 int index = parseOptionalKeywordAlternative(parser, names); 1006 if (index == -1) 1007 return failure(); 1008 result.addAttribute(name, parser.getBuilder().getI64IntegerAttr(index)); 1009 return success(); 1010 } 1011 1012 // operation ::= `llvm.mlir.global` linkage `constant`? `@` identifier 1013 // `(` attribute? `)` attribute-list? (`:` type)? region? 1014 // 1015 // The type can be omitted for string attributes, in which case it will be 1016 // inferred from the value of the string as [strlen(value) x i8]. 1017 static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) { 1018 if (failed(parseOptionalLLVMKeyword<Linkage>(parser, result, 1019 getLinkageAttrName()))) 1020 return parser.emitError(parser.getCurrentLocation(), "expected linkage"); 1021 1022 if (succeeded(parser.parseOptionalKeyword("constant"))) 1023 result.addAttribute("constant", parser.getBuilder().getUnitAttr()); 1024 1025 StringAttr name; 1026 if (parser.parseSymbolName(name, SymbolTable::getSymbolAttrName(), 1027 result.attributes) || 1028 parser.parseLParen()) 1029 return failure(); 1030 1031 Attribute value; 1032 if (parser.parseOptionalRParen()) { 1033 if (parser.parseAttribute(value, "value", result.attributes) || 1034 parser.parseRParen()) 1035 return failure(); 1036 } 1037 1038 SmallVector<Type, 1> types; 1039 if (parser.parseOptionalAttrDict(result.attributes) || 1040 parser.parseOptionalColonTypeList(types)) 1041 return failure(); 1042 1043 if (types.size() > 1) 1044 return parser.emitError(parser.getNameLoc(), "expected zero or one type"); 1045 1046 Region &initRegion = *result.addRegion(); 1047 if (types.empty()) { 1048 if (auto strAttr = value.dyn_cast_or_null<StringAttr>()) { 1049 MLIRContext *context = parser.getBuilder().getContext(); 1050 auto *dialect = context->getRegisteredDialect<LLVMDialect>(); 1051 auto arrayType = LLVM::LLVMType::getArrayTy( 1052 LLVM::LLVMType::getInt8Ty(dialect), strAttr.getValue().size()); 1053 types.push_back(arrayType); 1054 } else { 1055 return parser.emitError(parser.getNameLoc(), 1056 "type can only be omitted for string globals"); 1057 } 1058 } else if (parser.parseOptionalRegion(initRegion, /*arguments=*/{}, 1059 /*argTypes=*/{})) { 1060 return failure(); 1061 } 1062 1063 result.addAttribute("type", TypeAttr::get(types[0])); 1064 return success(); 1065 } 1066 1067 static LogicalResult verify(GlobalOp op) { 1068 if (!llvm::PointerType::isValidElementType(op.getType().getUnderlyingType())) 1069 return op.emitOpError( 1070 "expects type to be a valid element type for an LLVM pointer"); 1071 if (op.getParentOp() && !satisfiesLLVMModule(op.getParentOp())) 1072 return op.emitOpError("must appear at the module level"); 1073 1074 if (auto strAttr = op.getValueOrNull().dyn_cast_or_null<StringAttr>()) { 1075 auto type = op.getType(); 1076 if (!type.getUnderlyingType()->isArrayTy() || 1077 !type.getArrayElementType().getUnderlyingType()->isIntegerTy(8) || 1078 type.getArrayNumElements() != strAttr.getValue().size()) 1079 return op.emitOpError( 1080 "requires an i8 array type of the length equal to that of the string " 1081 "attribute"); 1082 } 1083 1084 if (Block *b = op.getInitializerBlock()) { 1085 ReturnOp ret = cast<ReturnOp>(b->getTerminator()); 1086 if (ret.operand_type_begin() == ret.operand_type_end()) 1087 return op.emitOpError("initializer region cannot return void"); 1088 if (*ret.operand_type_begin() != op.getType()) 1089 return op.emitOpError("initializer region type ") 1090 << *ret.operand_type_begin() << " does not match global type " 1091 << op.getType(); 1092 1093 if (op.getValueOrNull()) 1094 return op.emitOpError("cannot have both initializer value and region"); 1095 } 1096 return success(); 1097 } 1098 1099 //===----------------------------------------------------------------------===// 1100 // Printing/parsing for LLVM::ShuffleVectorOp. 1101 //===----------------------------------------------------------------------===// 1102 // Expects vector to be of wrapped LLVM vector type and position to be of 1103 // wrapped LLVM i32 type. 1104 void LLVM::ShuffleVectorOp::build(Builder *b, OperationState &result, Value v1, 1105 Value v2, ArrayAttr mask, 1106 ArrayRef<NamedAttribute> attrs) { 1107 auto wrappedContainerType1 = v1.getType().cast<LLVM::LLVMType>(); 1108 auto vType = LLVMType::getVectorTy( 1109 wrappedContainerType1.getVectorElementType(), mask.size()); 1110 build(b, result, vType, v1, v2, mask); 1111 result.addAttributes(attrs); 1112 } 1113 1114 static void printShuffleVectorOp(OpAsmPrinter &p, ShuffleVectorOp &op) { 1115 p << op.getOperationName() << ' ' << op.v1() << ", " << op.v2() << " " 1116 << op.mask(); 1117 p.printOptionalAttrDict(op.getAttrs(), {"mask"}); 1118 p << " : " << op.v1().getType() << ", " << op.v2().getType(); 1119 } 1120 1121 // <operation> ::= `llvm.shufflevector` ssa-use `, ` ssa-use 1122 // `[` integer-literal (`,` integer-literal)* `]` 1123 // attribute-dict? `:` type 1124 static ParseResult parseShuffleVectorOp(OpAsmParser &parser, 1125 OperationState &result) { 1126 llvm::SMLoc loc; 1127 OpAsmParser::OperandType v1, v2; 1128 ArrayAttr maskAttr; 1129 Type typeV1, typeV2; 1130 if (parser.getCurrentLocation(&loc) || parser.parseOperand(v1) || 1131 parser.parseComma() || parser.parseOperand(v2) || 1132 parser.parseAttribute(maskAttr, "mask", result.attributes) || 1133 parser.parseOptionalAttrDict(result.attributes) || 1134 parser.parseColonType(typeV1) || parser.parseComma() || 1135 parser.parseType(typeV2) || 1136 parser.resolveOperand(v1, typeV1, result.operands) || 1137 parser.resolveOperand(v2, typeV2, result.operands)) 1138 return failure(); 1139 auto wrappedContainerType1 = typeV1.dyn_cast<LLVM::LLVMType>(); 1140 if (!wrappedContainerType1 || 1141 !wrappedContainerType1.getUnderlyingType()->isVectorTy()) 1142 return parser.emitError( 1143 loc, "expected LLVM IR dialect vector type for operand #1"); 1144 auto vType = LLVMType::getVectorTy( 1145 wrappedContainerType1.getVectorElementType(), maskAttr.size()); 1146 result.addTypes(vType); 1147 return success(); 1148 } 1149 1150 //===----------------------------------------------------------------------===// 1151 // Implementations for LLVM::LLVMFuncOp. 1152 //===----------------------------------------------------------------------===// 1153 1154 // Add the entry block to the function. 1155 Block *LLVMFuncOp::addEntryBlock() { 1156 assert(empty() && "function already has an entry block"); 1157 assert(!isVarArg() && "unimplemented: non-external variadic functions"); 1158 1159 auto *entry = new Block; 1160 push_back(entry); 1161 1162 LLVMType type = getType(); 1163 for (unsigned i = 0, e = type.getFunctionNumParams(); i < e; ++i) 1164 entry->addArgument(type.getFunctionParamType(i)); 1165 return entry; 1166 } 1167 1168 void LLVMFuncOp::build(Builder *builder, OperationState &result, StringRef name, 1169 LLVMType type, LLVM::Linkage linkage, 1170 ArrayRef<NamedAttribute> attrs, 1171 ArrayRef<NamedAttributeList> argAttrs) { 1172 result.addRegion(); 1173 result.addAttribute(SymbolTable::getSymbolAttrName(), 1174 builder->getStringAttr(name)); 1175 result.addAttribute("type", TypeAttr::get(type)); 1176 result.addAttribute(getLinkageAttrName(), builder->getI64IntegerAttr( 1177 static_cast<int64_t>(linkage))); 1178 result.attributes.append(attrs.begin(), attrs.end()); 1179 if (argAttrs.empty()) 1180 return; 1181 1182 unsigned numInputs = type.getUnderlyingType()->getFunctionNumParams(); 1183 assert(numInputs == argAttrs.size() && 1184 "expected as many argument attribute lists as arguments"); 1185 SmallString<8> argAttrName; 1186 for (unsigned i = 0; i < numInputs; ++i) 1187 if (auto argDict = argAttrs[i].getDictionary()) 1188 result.addAttribute(getArgAttrName(i, argAttrName), argDict); 1189 } 1190 1191 // Builds an LLVM function type from the given lists of input and output types. 1192 // Returns a null type if any of the types provided are non-LLVM types, or if 1193 // there is more than one output type. 1194 static Type buildLLVMFunctionType(OpAsmParser &parser, llvm::SMLoc loc, 1195 ArrayRef<Type> inputs, ArrayRef<Type> outputs, 1196 impl::VariadicFlag variadicFlag) { 1197 Builder &b = parser.getBuilder(); 1198 if (outputs.size() > 1) { 1199 parser.emitError(loc, "failed to construct function type: expected zero or " 1200 "one function result"); 1201 return {}; 1202 } 1203 1204 // Convert inputs to LLVM types, exit early on error. 1205 SmallVector<LLVMType, 4> llvmInputs; 1206 for (auto t : inputs) { 1207 auto llvmTy = t.dyn_cast<LLVMType>(); 1208 if (!llvmTy) { 1209 parser.emitError(loc, "failed to construct function type: expected LLVM " 1210 "type for function arguments"); 1211 return {}; 1212 } 1213 llvmInputs.push_back(llvmTy); 1214 } 1215 1216 // Get the dialect from the input type, if any exist. Look it up in the 1217 // context otherwise. 1218 LLVMDialect *dialect = 1219 llvmInputs.empty() ? b.getContext()->getRegisteredDialect<LLVMDialect>() 1220 : &llvmInputs.front().getDialect(); 1221 1222 // No output is denoted as "void" in LLVM type system. 1223 LLVMType llvmOutput = outputs.empty() ? LLVMType::getVoidTy(dialect) 1224 : outputs.front().dyn_cast<LLVMType>(); 1225 if (!llvmOutput) { 1226 parser.emitError(loc, "failed to construct function type: expected LLVM " 1227 "type for function results"); 1228 return {}; 1229 } 1230 return LLVMType::getFunctionTy(llvmOutput, llvmInputs, 1231 variadicFlag.isVariadic()); 1232 } 1233 1234 // Parses an LLVM function. 1235 // 1236 // operation ::= `llvm.func` linkage? function-signature function-attributes? 1237 // function-body 1238 // 1239 static ParseResult parseLLVMFuncOp(OpAsmParser &parser, 1240 OperationState &result) { 1241 // Default to external linkage if no keyword is provided. 1242 if (failed(parseOptionalLLVMKeyword<Linkage>(parser, result, 1243 getLinkageAttrName()))) 1244 result.addAttribute(getLinkageAttrName(), 1245 parser.getBuilder().getI64IntegerAttr( 1246 static_cast<int64_t>(LLVM::Linkage::External))); 1247 1248 StringAttr nameAttr; 1249 SmallVector<OpAsmParser::OperandType, 8> entryArgs; 1250 SmallVector<SmallVector<NamedAttribute, 2>, 1> argAttrs; 1251 SmallVector<SmallVector<NamedAttribute, 2>, 1> resultAttrs; 1252 SmallVector<Type, 8> argTypes; 1253 SmallVector<Type, 4> resultTypes; 1254 bool isVariadic; 1255 1256 auto signatureLocation = parser.getCurrentLocation(); 1257 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), 1258 result.attributes) || 1259 impl::parseFunctionSignature(parser, /*allowVariadic=*/true, entryArgs, 1260 argTypes, argAttrs, isVariadic, resultTypes, 1261 resultAttrs)) 1262 return failure(); 1263 1264 auto type = 1265 buildLLVMFunctionType(parser, signatureLocation, argTypes, resultTypes, 1266 impl::VariadicFlag(isVariadic)); 1267 if (!type) 1268 return failure(); 1269 result.addAttribute(impl::getTypeAttrName(), TypeAttr::get(type)); 1270 1271 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) 1272 return failure(); 1273 impl::addArgAndResultAttrs(parser.getBuilder(), result, argAttrs, 1274 resultAttrs); 1275 1276 auto *body = result.addRegion(); 1277 return parser.parseOptionalRegion( 1278 *body, entryArgs, entryArgs.empty() ? ArrayRef<Type>() : argTypes); 1279 } 1280 1281 // Print the LLVMFuncOp. Collects argument and result types and passes them to 1282 // helper functions. Drops "void" result since it cannot be parsed back. Skips 1283 // the external linkage since it is the default value. 1284 static void printLLVMFuncOp(OpAsmPrinter &p, LLVMFuncOp op) { 1285 p << op.getOperationName() << ' '; 1286 if (op.linkage() != LLVM::Linkage::External) 1287 p << stringifyLinkage(op.linkage()) << ' '; 1288 p.printSymbolName(op.getName()); 1289 1290 LLVMType fnType = op.getType(); 1291 SmallVector<Type, 8> argTypes; 1292 SmallVector<Type, 1> resTypes; 1293 argTypes.reserve(fnType.getFunctionNumParams()); 1294 for (unsigned i = 0, e = fnType.getFunctionNumParams(); i < e; ++i) 1295 argTypes.push_back(fnType.getFunctionParamType(i)); 1296 1297 LLVMType returnType = fnType.getFunctionResultType(); 1298 if (!returnType.isVoidTy()) 1299 resTypes.push_back(returnType); 1300 1301 impl::printFunctionSignature(p, op, argTypes, op.isVarArg(), resTypes); 1302 impl::printFunctionAttributes(p, op, argTypes.size(), resTypes.size(), 1303 {getLinkageAttrName()}); 1304 1305 // Print the body if this is not an external function. 1306 Region &body = op.body(); 1307 if (!body.empty()) 1308 p.printRegion(body, /*printEntryBlockArgs=*/false, 1309 /*printBlockTerminators=*/true); 1310 } 1311 1312 // Hook for OpTrait::FunctionLike, called after verifying that the 'type' 1313 // attribute is present. This can check for preconditions of the 1314 // getNumArguments hook not failing. 1315 LogicalResult LLVMFuncOp::verifyType() { 1316 auto llvmType = getTypeAttr().getValue().dyn_cast_or_null<LLVMType>(); 1317 if (!llvmType || !llvmType.getUnderlyingType()->isFunctionTy()) 1318 return emitOpError("requires '" + getTypeAttrName() + 1319 "' attribute of wrapped LLVM function type"); 1320 1321 return success(); 1322 } 1323 1324 // Hook for OpTrait::FunctionLike, returns the number of function arguments. 1325 // Depends on the type attribute being correct as checked by verifyType 1326 unsigned LLVMFuncOp::getNumFuncArguments() { 1327 return getType().getUnderlyingType()->getFunctionNumParams(); 1328 } 1329 1330 // Hook for OpTrait::FunctionLike, returns the number of function results. 1331 // Depends on the type attribute being correct as checked by verifyType 1332 unsigned LLVMFuncOp::getNumFuncResults() { 1333 // We model LLVM functions that return void as having zero results, 1334 // and all others as having one result. 1335 // If we modeled a void return as one result, then it would be possible to 1336 // attach an MLIR result attribute to it, and it isn't clear what semantics we 1337 // would assign to that. 1338 if (getType().getFunctionResultType().isVoidTy()) 1339 return 0; 1340 return 1; 1341 } 1342 1343 // Verifies LLVM- and implementation-specific properties of the LLVM func Op: 1344 // - functions don't have 'common' linkage 1345 // - external functions have 'external' or 'extern_weak' linkage; 1346 // - vararg is (currently) only supported for external functions; 1347 // - entry block arguments are of LLVM types and match the function signature. 1348 static LogicalResult verify(LLVMFuncOp op) { 1349 if (op.linkage() == LLVM::Linkage::Common) 1350 return op.emitOpError() 1351 << "functions cannot have '" 1352 << stringifyLinkage(LLVM::Linkage::Common) << "' linkage"; 1353 1354 if (op.isExternal()) { 1355 if (op.linkage() != LLVM::Linkage::External && 1356 op.linkage() != LLVM::Linkage::ExternWeak) 1357 return op.emitOpError() 1358 << "external functions must have '" 1359 << stringifyLinkage(LLVM::Linkage::External) << "' or '" 1360 << stringifyLinkage(LLVM::Linkage::ExternWeak) << "' linkage"; 1361 return success(); 1362 } 1363 1364 if (op.isVarArg()) 1365 return op.emitOpError("only external functions can be variadic"); 1366 1367 auto *funcType = cast<llvm::FunctionType>(op.getType().getUnderlyingType()); 1368 unsigned numArguments = funcType->getNumParams(); 1369 Block &entryBlock = op.front(); 1370 for (unsigned i = 0; i < numArguments; ++i) { 1371 Type argType = entryBlock.getArgument(i).getType(); 1372 auto argLLVMType = argType.dyn_cast<LLVMType>(); 1373 if (!argLLVMType) 1374 return op.emitOpError("entry block argument #") 1375 << i << " is not of LLVM type"; 1376 if (funcType->getParamType(i) != argLLVMType.getUnderlyingType()) 1377 return op.emitOpError("the type of entry block argument #") 1378 << i << " does not match the function signature"; 1379 } 1380 1381 return success(); 1382 } 1383 1384 //===----------------------------------------------------------------------===// 1385 // Verification for LLVM::NullOp. 1386 //===----------------------------------------------------------------------===// 1387 1388 // Only LLVM pointer types are supported. 1389 static LogicalResult verify(LLVM::NullOp op) { 1390 auto llvmType = op.getType().dyn_cast<LLVM::LLVMType>(); 1391 if (!llvmType || !llvmType.isPointerTy()) 1392 return op.emitOpError("expected LLVM IR pointer type"); 1393 return success(); 1394 } 1395 1396 //===----------------------------------------------------------------------===// 1397 // Utility functions for parsing atomic ops 1398 //===----------------------------------------------------------------------===// 1399 1400 // Helper function to parse a keyword into the specified attribute named by 1401 // `attrName`. The keyword must match one of the string values defined by the 1402 // AtomicBinOp enum. The resulting I64 attribute is added to the `result` 1403 // state. 1404 static ParseResult parseAtomicBinOp(OpAsmParser &parser, OperationState &result, 1405 StringRef attrName) { 1406 llvm::SMLoc loc; 1407 StringRef keyword; 1408 if (parser.getCurrentLocation(&loc) || parser.parseKeyword(&keyword)) 1409 return failure(); 1410 1411 // Replace the keyword `keyword` with an integer attribute. 1412 auto kind = symbolizeAtomicBinOp(keyword); 1413 if (!kind) { 1414 return parser.emitError(loc) 1415 << "'" << keyword << "' is an incorrect value of the '" << attrName 1416 << "' attribute"; 1417 } 1418 1419 auto value = static_cast<int64_t>(kind.getValue()); 1420 auto attr = parser.getBuilder().getI64IntegerAttr(value); 1421 result.addAttribute(attrName, attr); 1422 1423 return success(); 1424 } 1425 1426 // Helper function to parse a keyword into the specified attribute named by 1427 // `attrName`. The keyword must match one of the string values defined by the 1428 // AtomicOrdering enum. The resulting I64 attribute is added to the `result` 1429 // state. 1430 static ParseResult parseAtomicOrdering(OpAsmParser &parser, 1431 OperationState &result, 1432 StringRef attrName) { 1433 llvm::SMLoc loc; 1434 StringRef ordering; 1435 if (parser.getCurrentLocation(&loc) || parser.parseKeyword(&ordering)) 1436 return failure(); 1437 1438 // Replace the keyword `ordering` with an integer attribute. 1439 auto kind = symbolizeAtomicOrdering(ordering); 1440 if (!kind) { 1441 return parser.emitError(loc) 1442 << "'" << ordering << "' is an incorrect value of the '" << attrName 1443 << "' attribute"; 1444 } 1445 1446 auto value = static_cast<int64_t>(kind.getValue()); 1447 auto attr = parser.getBuilder().getI64IntegerAttr(value); 1448 result.addAttribute(attrName, attr); 1449 1450 return success(); 1451 } 1452 1453 //===----------------------------------------------------------------------===// 1454 // Printer, parser and verifier for LLVM::AtomicRMWOp. 1455 //===----------------------------------------------------------------------===// 1456 1457 static void printAtomicRMWOp(OpAsmPrinter &p, AtomicRMWOp &op) { 1458 p << op.getOperationName() << ' ' << stringifyAtomicBinOp(op.bin_op()) << ' ' 1459 << op.ptr() << ", " << op.val() << ' ' 1460 << stringifyAtomicOrdering(op.ordering()) << ' '; 1461 p.printOptionalAttrDict(op.getAttrs(), {"bin_op", "ordering"}); 1462 p << " : " << op.res().getType(); 1463 } 1464 1465 // <operation> ::= `llvm.atomicrmw` keyword ssa-use `,` ssa-use keyword 1466 // attribute-dict? `:` type 1467 static ParseResult parseAtomicRMWOp(OpAsmParser &parser, 1468 OperationState &result) { 1469 LLVMType type; 1470 OpAsmParser::OperandType ptr, val; 1471 if (parseAtomicBinOp(parser, result, "bin_op") || parser.parseOperand(ptr) || 1472 parser.parseComma() || parser.parseOperand(val) || 1473 parseAtomicOrdering(parser, result, "ordering") || 1474 parser.parseOptionalAttrDict(result.attributes) || 1475 parser.parseColonType(type) || 1476 parser.resolveOperand(ptr, type.getPointerTo(), result.operands) || 1477 parser.resolveOperand(val, type, result.operands)) 1478 return failure(); 1479 1480 result.addTypes(type); 1481 return success(); 1482 } 1483 1484 static LogicalResult verify(AtomicRMWOp op) { 1485 auto ptrType = op.ptr().getType().cast<LLVM::LLVMType>(); 1486 if (!ptrType.isPointerTy()) 1487 return op.emitOpError("expected LLVM IR pointer type for operand #0"); 1488 auto valType = op.val().getType().cast<LLVM::LLVMType>(); 1489 if (valType != ptrType.getPointerElementTy()) 1490 return op.emitOpError("expected LLVM IR element type for operand #0 to " 1491 "match type for operand #1"); 1492 auto resType = op.res().getType().cast<LLVM::LLVMType>(); 1493 if (resType != valType) 1494 return op.emitOpError( 1495 "expected LLVM IR result type to match type for operand #1"); 1496 if (op.bin_op() == AtomicBinOp::fadd || op.bin_op() == AtomicBinOp::fsub) { 1497 if (!valType.getUnderlyingType()->isFloatingPointTy()) 1498 return op.emitOpError("expected LLVM IR floating point type"); 1499 } else if (op.bin_op() == AtomicBinOp::xchg) { 1500 if (!valType.isIntegerTy(8) && !valType.isIntegerTy(16) && 1501 !valType.isIntegerTy(32) && !valType.isIntegerTy(64) && 1502 !valType.isHalfTy() && !valType.isFloatTy() && !valType.isDoubleTy()) 1503 return op.emitOpError("unexpected LLVM IR type for 'xchg' bin_op"); 1504 } else { 1505 if (!valType.isIntegerTy(8) && !valType.isIntegerTy(16) && 1506 !valType.isIntegerTy(32) && !valType.isIntegerTy(64)) 1507 return op.emitOpError("expected LLVM IR integer type"); 1508 } 1509 return success(); 1510 } 1511 1512 //===----------------------------------------------------------------------===// 1513 // Printer, parser and verifier for LLVM::AtomicCmpXchgOp. 1514 //===----------------------------------------------------------------------===// 1515 1516 static void printAtomicCmpXchgOp(OpAsmPrinter &p, AtomicCmpXchgOp &op) { 1517 p << op.getOperationName() << ' ' << op.ptr() << ", " << op.cmp() << ", " 1518 << op.val() << ' ' << stringifyAtomicOrdering(op.success_ordering()) << ' ' 1519 << stringifyAtomicOrdering(op.failure_ordering()); 1520 p.printOptionalAttrDict(op.getAttrs(), 1521 {"success_ordering", "failure_ordering"}); 1522 p << " : " << op.val().getType(); 1523 } 1524 1525 // <operation> ::= `llvm.cmpxchg` ssa-use `,` ssa-use `,` ssa-use 1526 // keyword keyword attribute-dict? `:` type 1527 static ParseResult parseAtomicCmpXchgOp(OpAsmParser &parser, 1528 OperationState &result) { 1529 auto &builder = parser.getBuilder(); 1530 LLVMType type; 1531 OpAsmParser::OperandType ptr, cmp, val; 1532 if (parser.parseOperand(ptr) || parser.parseComma() || 1533 parser.parseOperand(cmp) || parser.parseComma() || 1534 parser.parseOperand(val) || 1535 parseAtomicOrdering(parser, result, "success_ordering") || 1536 parseAtomicOrdering(parser, result, "failure_ordering") || 1537 parser.parseOptionalAttrDict(result.attributes) || 1538 parser.parseColonType(type) || 1539 parser.resolveOperand(ptr, type.getPointerTo(), result.operands) || 1540 parser.resolveOperand(cmp, type, result.operands) || 1541 parser.resolveOperand(val, type, result.operands)) 1542 return failure(); 1543 1544 auto *dialect = builder.getContext()->getRegisteredDialect<LLVMDialect>(); 1545 auto boolType = LLVMType::getInt1Ty(dialect); 1546 auto resultType = LLVMType::getStructTy(type, boolType); 1547 result.addTypes(resultType); 1548 1549 return success(); 1550 } 1551 1552 static LogicalResult verify(AtomicCmpXchgOp op) { 1553 auto ptrType = op.ptr().getType().cast<LLVM::LLVMType>(); 1554 if (!ptrType.isPointerTy()) 1555 return op.emitOpError("expected LLVM IR pointer type for operand #0"); 1556 auto cmpType = op.cmp().getType().cast<LLVM::LLVMType>(); 1557 auto valType = op.val().getType().cast<LLVM::LLVMType>(); 1558 if (cmpType != ptrType.getPointerElementTy() || cmpType != valType) 1559 return op.emitOpError("expected LLVM IR element type for operand #0 to " 1560 "match type for all other operands"); 1561 if (!valType.isPointerTy() && !valType.isIntegerTy(8) && 1562 !valType.isIntegerTy(16) && !valType.isIntegerTy(32) && 1563 !valType.isIntegerTy(64) && !valType.isHalfTy() && !valType.isFloatTy() && 1564 !valType.isDoubleTy()) 1565 return op.emitOpError("unexpected LLVM IR type"); 1566 if (op.success_ordering() < AtomicOrdering::monotonic || 1567 op.failure_ordering() < AtomicOrdering::monotonic) 1568 return op.emitOpError("ordering must be at least 'monotonic'"); 1569 if (op.failure_ordering() == AtomicOrdering::release || 1570 op.failure_ordering() == AtomicOrdering::acq_rel) 1571 return op.emitOpError("failure ordering cannot be 'release' or 'acq_rel'"); 1572 return success(); 1573 } 1574 1575 //===----------------------------------------------------------------------===// 1576 // Printer, parser and verifier for LLVM::FenceOp. 1577 //===----------------------------------------------------------------------===// 1578 1579 // <operation> ::= `llvm.fence` (`syncscope(`strAttr`)`)? keyword 1580 // attribute-dict? 1581 static ParseResult parseFenceOp(OpAsmParser &parser, OperationState &result) { 1582 StringAttr sScope; 1583 StringRef syncscopeKeyword = "syncscope"; 1584 if (!failed(parser.parseOptionalKeyword(syncscopeKeyword))) { 1585 if (parser.parseLParen() || 1586 parser.parseAttribute(sScope, syncscopeKeyword, result.attributes) || 1587 parser.parseRParen()) 1588 return failure(); 1589 } else { 1590 result.addAttribute(syncscopeKeyword, 1591 parser.getBuilder().getStringAttr("")); 1592 } 1593 if (parseAtomicOrdering(parser, result, "ordering") || 1594 parser.parseOptionalAttrDict(result.attributes)) 1595 return failure(); 1596 return success(); 1597 } 1598 1599 static void printFenceOp(OpAsmPrinter &p, FenceOp &op) { 1600 StringRef syncscopeKeyword = "syncscope"; 1601 p << op.getOperationName() << ' '; 1602 if (!op.getAttr(syncscopeKeyword).cast<StringAttr>().getValue().empty()) 1603 p << "syncscope(" << op.getAttr(syncscopeKeyword) << ") "; 1604 p << stringifyAtomicOrdering(op.ordering()); 1605 } 1606 1607 static LogicalResult verify(FenceOp &op) { 1608 if (op.ordering() == AtomicOrdering::not_atomic || 1609 op.ordering() == AtomicOrdering::unordered || 1610 op.ordering() == AtomicOrdering::monotonic) 1611 return op.emitOpError("can be given only acquire, release, acq_rel, " 1612 "and seq_cst orderings"); 1613 return success(); 1614 } 1615 1616 //===----------------------------------------------------------------------===// 1617 // LLVMDialect initialization, type parsing, and registration. 1618 //===----------------------------------------------------------------------===// 1619 1620 namespace mlir { 1621 namespace LLVM { 1622 namespace detail { 1623 struct LLVMDialectImpl { 1624 LLVMDialectImpl() : module("LLVMDialectModule", llvmContext) {} 1625 1626 llvm::LLVMContext llvmContext; 1627 llvm::Module module; 1628 1629 /// A set of LLVMTypes that are cached on construction to avoid any lookups or 1630 /// locking. 1631 LLVMType int1Ty, int8Ty, int16Ty, int32Ty, int64Ty, int128Ty; 1632 LLVMType doubleTy, floatTy, halfTy, fp128Ty, x86_fp80Ty; 1633 LLVMType voidTy; 1634 1635 /// A smart mutex to lock access to the llvm context. Unlike MLIR, LLVM is not 1636 /// multi-threaded and requires locked access to prevent race conditions. 1637 llvm::sys::SmartMutex<true> mutex; 1638 }; 1639 } // end namespace detail 1640 } // end namespace LLVM 1641 } // end namespace mlir 1642 1643 LLVMDialect::LLVMDialect(MLIRContext *context) 1644 : Dialect(getDialectNamespace(), context), 1645 impl(new detail::LLVMDialectImpl()) { 1646 addTypes<LLVMType>(); 1647 addOperations< 1648 #define GET_OP_LIST 1649 #include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc" 1650 >(); 1651 1652 // Support unknown operations because not all LLVM operations are registered. 1653 allowUnknownOperations(); 1654 1655 // Cache some of the common LLVM types to avoid the need for lookups/locking. 1656 auto &llvmContext = impl->llvmContext; 1657 /// Integer Types. 1658 impl->int1Ty = LLVMType::get(context, llvm::Type::getInt1Ty(llvmContext)); 1659 impl->int8Ty = LLVMType::get(context, llvm::Type::getInt8Ty(llvmContext)); 1660 impl->int16Ty = LLVMType::get(context, llvm::Type::getInt16Ty(llvmContext)); 1661 impl->int32Ty = LLVMType::get(context, llvm::Type::getInt32Ty(llvmContext)); 1662 impl->int64Ty = LLVMType::get(context, llvm::Type::getInt64Ty(llvmContext)); 1663 impl->int128Ty = LLVMType::get(context, llvm::Type::getInt128Ty(llvmContext)); 1664 /// Float Types. 1665 impl->doubleTy = LLVMType::get(context, llvm::Type::getDoubleTy(llvmContext)); 1666 impl->floatTy = LLVMType::get(context, llvm::Type::getFloatTy(llvmContext)); 1667 impl->halfTy = LLVMType::get(context, llvm::Type::getHalfTy(llvmContext)); 1668 impl->fp128Ty = LLVMType::get(context, llvm::Type::getFP128Ty(llvmContext)); 1669 impl->x86_fp80Ty = 1670 LLVMType::get(context, llvm::Type::getX86_FP80Ty(llvmContext)); 1671 /// Other Types. 1672 impl->voidTy = LLVMType::get(context, llvm::Type::getVoidTy(llvmContext)); 1673 } 1674 1675 LLVMDialect::~LLVMDialect() {} 1676 1677 #define GET_OP_CLASSES 1678 #include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc" 1679 1680 llvm::LLVMContext &LLVMDialect::getLLVMContext() { return impl->llvmContext; } 1681 llvm::Module &LLVMDialect::getLLVMModule() { return impl->module; } 1682 1683 /// Parse a type registered to this dialect. 1684 Type LLVMDialect::parseType(DialectAsmParser &parser) const { 1685 StringRef tyData = parser.getFullSymbolSpec(); 1686 1687 // LLVM is not thread-safe, so lock access to it. 1688 llvm::sys::SmartScopedLock<true> lock(impl->mutex); 1689 1690 llvm::SMDiagnostic errorMessage; 1691 llvm::Type *type = llvm::parseType(tyData, errorMessage, impl->module); 1692 if (!type) 1693 return (parser.emitError(parser.getNameLoc(), errorMessage.getMessage()), 1694 nullptr); 1695 return LLVMType::get(getContext(), type); 1696 } 1697 1698 /// Print a type registered to this dialect. 1699 void LLVMDialect::printType(Type type, DialectAsmPrinter &os) const { 1700 auto llvmType = type.dyn_cast<LLVMType>(); 1701 assert(llvmType && "printing wrong type"); 1702 assert(llvmType.getUnderlyingType() && "no underlying LLVM type"); 1703 llvmType.getUnderlyingType()->print(os.getStream()); 1704 } 1705 1706 /// Verify LLVMIR function argument attributes. 1707 LogicalResult LLVMDialect::verifyRegionArgAttribute(Operation *op, 1708 unsigned regionIdx, 1709 unsigned argIdx, 1710 NamedAttribute argAttr) { 1711 // Check that llvm.noalias is a boolean attribute. 1712 if (argAttr.first == "llvm.noalias" && !argAttr.second.isa<BoolAttr>()) 1713 return op->emitError() 1714 << "llvm.noalias argument attribute of non boolean type"; 1715 return success(); 1716 } 1717 1718 //===----------------------------------------------------------------------===// 1719 // LLVMType. 1720 //===----------------------------------------------------------------------===// 1721 1722 namespace mlir { 1723 namespace LLVM { 1724 namespace detail { 1725 struct LLVMTypeStorage : public ::mlir::TypeStorage { 1726 LLVMTypeStorage(llvm::Type *ty) : underlyingType(ty) {} 1727 1728 // LLVM types are pointer-unique. 1729 using KeyTy = llvm::Type *; 1730 bool operator==(const KeyTy &key) const { return key == underlyingType; } 1731 1732 static LLVMTypeStorage *construct(TypeStorageAllocator &allocator, 1733 llvm::Type *ty) { 1734 return new (allocator.allocate<LLVMTypeStorage>()) LLVMTypeStorage(ty); 1735 } 1736 1737 llvm::Type *underlyingType; 1738 }; 1739 } // end namespace detail 1740 } // end namespace LLVM 1741 } // end namespace mlir 1742 1743 LLVMType LLVMType::get(MLIRContext *context, llvm::Type *llvmType) { 1744 return Base::get(context, FIRST_LLVM_TYPE, llvmType); 1745 } 1746 1747 /// Get an LLVMType with an llvm type that may cause changes to the underlying 1748 /// llvm context when constructed. 1749 LLVMType LLVMType::getLocked(LLVMDialect *dialect, 1750 function_ref<llvm::Type *()> typeBuilder) { 1751 // Lock access to the llvm context and build the type. 1752 llvm::sys::SmartScopedLock<true> lock(dialect->impl->mutex); 1753 return get(dialect->getContext(), typeBuilder()); 1754 } 1755 1756 LLVMDialect &LLVMType::getDialect() { 1757 return static_cast<LLVMDialect &>(Type::getDialect()); 1758 } 1759 1760 llvm::Type *LLVMType::getUnderlyingType() const { 1761 return getImpl()->underlyingType; 1762 } 1763 1764 /// Array type utilities. 1765 LLVMType LLVMType::getArrayElementType() { 1766 return get(getContext(), getUnderlyingType()->getArrayElementType()); 1767 } 1768 unsigned LLVMType::getArrayNumElements() { 1769 return getUnderlyingType()->getArrayNumElements(); 1770 } 1771 bool LLVMType::isArrayTy() { return getUnderlyingType()->isArrayTy(); } 1772 1773 /// Vector type utilities. 1774 LLVMType LLVMType::getVectorElementType() { 1775 return get(getContext(), getUnderlyingType()->getVectorElementType()); 1776 } 1777 bool LLVMType::isVectorTy() { return getUnderlyingType()->isVectorTy(); } 1778 1779 /// Function type utilities. 1780 LLVMType LLVMType::getFunctionParamType(unsigned argIdx) { 1781 return get(getContext(), getUnderlyingType()->getFunctionParamType(argIdx)); 1782 } 1783 unsigned LLVMType::getFunctionNumParams() { 1784 return getUnderlyingType()->getFunctionNumParams(); 1785 } 1786 LLVMType LLVMType::getFunctionResultType() { 1787 return get( 1788 getContext(), 1789 llvm::cast<llvm::FunctionType>(getUnderlyingType())->getReturnType()); 1790 } 1791 bool LLVMType::isFunctionTy() { return getUnderlyingType()->isFunctionTy(); } 1792 1793 /// Pointer type utilities. 1794 LLVMType LLVMType::getPointerTo(unsigned addrSpace) { 1795 // Lock access to the dialect as this may modify the LLVM context. 1796 return getLocked(&getDialect(), [=] { 1797 return getUnderlyingType()->getPointerTo(addrSpace); 1798 }); 1799 } 1800 LLVMType LLVMType::getPointerElementTy() { 1801 return get(getContext(), getUnderlyingType()->getPointerElementType()); 1802 } 1803 bool LLVMType::isPointerTy() { return getUnderlyingType()->isPointerTy(); } 1804 1805 /// Struct type utilities. 1806 LLVMType LLVMType::getStructElementType(unsigned i) { 1807 return get(getContext(), getUnderlyingType()->getStructElementType(i)); 1808 } 1809 unsigned LLVMType::getStructNumElements() { 1810 return getUnderlyingType()->getStructNumElements(); 1811 } 1812 bool LLVMType::isStructTy() { return getUnderlyingType()->isStructTy(); } 1813 1814 /// Utilities used to generate floating point types. 1815 LLVMType LLVMType::getDoubleTy(LLVMDialect *dialect) { 1816 return dialect->impl->doubleTy; 1817 } 1818 LLVMType LLVMType::getFloatTy(LLVMDialect *dialect) { 1819 return dialect->impl->floatTy; 1820 } 1821 LLVMType LLVMType::getHalfTy(LLVMDialect *dialect) { 1822 return dialect->impl->halfTy; 1823 } 1824 LLVMType LLVMType::getFP128Ty(LLVMDialect *dialect) { 1825 return dialect->impl->fp128Ty; 1826 } 1827 LLVMType LLVMType::getX86_FP80Ty(LLVMDialect *dialect) { 1828 return dialect->impl->x86_fp80Ty; 1829 } 1830 1831 /// Utilities used to generate integer types. 1832 LLVMType LLVMType::getIntNTy(LLVMDialect *dialect, unsigned numBits) { 1833 switch (numBits) { 1834 case 1: 1835 return dialect->impl->int1Ty; 1836 case 8: 1837 return dialect->impl->int8Ty; 1838 case 16: 1839 return dialect->impl->int16Ty; 1840 case 32: 1841 return dialect->impl->int32Ty; 1842 case 64: 1843 return dialect->impl->int64Ty; 1844 case 128: 1845 return dialect->impl->int128Ty; 1846 default: 1847 break; 1848 } 1849 1850 // Lock access to the dialect as this may modify the LLVM context. 1851 return getLocked(dialect, [=] { 1852 return llvm::Type::getIntNTy(dialect->getLLVMContext(), numBits); 1853 }); 1854 } 1855 1856 /// Utilities used to generate other miscellaneous types. 1857 LLVMType LLVMType::getArrayTy(LLVMType elementType, uint64_t numElements) { 1858 // Lock access to the dialect as this may modify the LLVM context. 1859 return getLocked(&elementType.getDialect(), [=] { 1860 return llvm::ArrayType::get(elementType.getUnderlyingType(), numElements); 1861 }); 1862 } 1863 LLVMType LLVMType::getFunctionTy(LLVMType result, ArrayRef<LLVMType> params, 1864 bool isVarArg) { 1865 SmallVector<llvm::Type *, 8> llvmParams; 1866 for (auto param : params) 1867 llvmParams.push_back(param.getUnderlyingType()); 1868 1869 // Lock access to the dialect as this may modify the LLVM context. 1870 return getLocked(&result.getDialect(), [=] { 1871 return llvm::FunctionType::get(result.getUnderlyingType(), llvmParams, 1872 isVarArg); 1873 }); 1874 } 1875 LLVMType LLVMType::getStructTy(LLVMDialect *dialect, 1876 ArrayRef<LLVMType> elements, bool isPacked) { 1877 SmallVector<llvm::Type *, 8> llvmElements; 1878 for (auto elt : elements) 1879 llvmElements.push_back(elt.getUnderlyingType()); 1880 1881 // Lock access to the dialect as this may modify the LLVM context. 1882 return getLocked(dialect, [=] { 1883 return llvm::StructType::get(dialect->getLLVMContext(), llvmElements, 1884 isPacked); 1885 }); 1886 } 1887 inline static SmallVector<llvm::Type *, 8> 1888 toUnderlyingTypes(ArrayRef<LLVMType> elements) { 1889 SmallVector<llvm::Type *, 8> llvmElements; 1890 for (auto elt : elements) 1891 llvmElements.push_back(elt.getUnderlyingType()); 1892 return llvmElements; 1893 } 1894 LLVMType LLVMType::createStructTy(LLVMDialect *dialect, 1895 ArrayRef<LLVMType> elements, 1896 Optional<StringRef> name, bool isPacked) { 1897 StringRef sr = name.hasValue() ? *name : ""; 1898 SmallVector<llvm::Type *, 8> llvmElements(toUnderlyingTypes(elements)); 1899 return getLocked(dialect, [=] { 1900 auto *rv = llvm::StructType::create(dialect->getLLVMContext(), sr); 1901 if (!llvmElements.empty()) 1902 rv->setBody(llvmElements, isPacked); 1903 return rv; 1904 }); 1905 } 1906 LLVMType LLVMType::setStructTyBody(LLVMType structType, 1907 ArrayRef<LLVMType> elements, bool isPacked) { 1908 llvm::StructType *st = 1909 llvm::cast<llvm::StructType>(structType.getUnderlyingType()); 1910 SmallVector<llvm::Type *, 8> llvmElements(toUnderlyingTypes(elements)); 1911 return getLocked(&structType.getDialect(), [=] { 1912 st->setBody(llvmElements, isPacked); 1913 return st; 1914 }); 1915 } 1916 LLVMType LLVMType::getVectorTy(LLVMType elementType, unsigned numElements) { 1917 // Lock access to the dialect as this may modify the LLVM context. 1918 return getLocked(&elementType.getDialect(), [=] { 1919 return llvm::VectorType::get(elementType.getUnderlyingType(), numElements); 1920 }); 1921 } 1922 1923 LLVMType LLVMType::getVoidTy(LLVMDialect *dialect) { 1924 return dialect->impl->voidTy; 1925 } 1926 1927 bool LLVMType::isVoidTy() { return getUnderlyingType()->isVoidTy(); } 1928 1929 //===----------------------------------------------------------------------===// 1930 // Utility functions. 1931 //===----------------------------------------------------------------------===// 1932 1933 Value mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder, 1934 StringRef name, StringRef value, 1935 LLVM::Linkage linkage, 1936 LLVM::LLVMDialect *llvmDialect) { 1937 assert(builder.getInsertionBlock() && 1938 builder.getInsertionBlock()->getParentOp() && 1939 "expected builder to point to a block constrained in an op"); 1940 auto module = 1941 builder.getInsertionBlock()->getParentOp()->getParentOfType<ModuleOp>(); 1942 assert(module && "builder points to an op outside of a module"); 1943 1944 // Create the global at the entry of the module. 1945 OpBuilder moduleBuilder(module.getBodyRegion()); 1946 auto type = LLVM::LLVMType::getArrayTy(LLVM::LLVMType::getInt8Ty(llvmDialect), 1947 value.size()); 1948 auto global = moduleBuilder.create<LLVM::GlobalOp>( 1949 loc, type, /*isConstant=*/true, linkage, name, 1950 builder.getStringAttr(value)); 1951 1952 // Get the pointer to the first character in the global string. 1953 Value globalPtr = builder.create<LLVM::AddressOfOp>(loc, global); 1954 Value cst0 = builder.create<LLVM::ConstantOp>( 1955 loc, LLVM::LLVMType::getInt64Ty(llvmDialect), 1956 builder.getIntegerAttr(builder.getIndexType(), 0)); 1957 return builder.create<LLVM::GEPOp>(loc, 1958 LLVM::LLVMType::getInt8PtrTy(llvmDialect), 1959 globalPtr, ArrayRef<Value>({cst0, cst0})); 1960 } 1961 1962 bool mlir::LLVM::satisfiesLLVMModule(Operation *op) { 1963 return op->hasTrait<OpTrait::SymbolTable>() && 1964 op->hasTrait<OpTrait::IsIsolatedFromAbove>(); 1965 } 1966