1 //===- LLVMDialect.cpp - LLVM IR Ops and Dialect registration -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the types and operation details for the LLVM IR dialect in 10 // MLIR, and the LLVM IR dialect. It also registers the dialect. 11 // 12 //===----------------------------------------------------------------------===// 13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 14 #include "TypeDetail.h" 15 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 16 #include "mlir/IR/Builders.h" 17 #include "mlir/IR/BuiltinOps.h" 18 #include "mlir/IR/BuiltinTypes.h" 19 #include "mlir/IR/DialectImplementation.h" 20 #include "mlir/IR/FunctionImplementation.h" 21 #include "mlir/IR/MLIRContext.h" 22 23 #include "llvm/ADT/StringSwitch.h" 24 #include "llvm/ADT/TypeSwitch.h" 25 #include "llvm/AsmParser/Parser.h" 26 #include "llvm/Bitcode/BitcodeReader.h" 27 #include "llvm/Bitcode/BitcodeWriter.h" 28 #include "llvm/IR/Attributes.h" 29 #include "llvm/IR/Function.h" 30 #include "llvm/IR/Type.h" 31 #include "llvm/Support/Mutex.h" 32 #include "llvm/Support/SourceMgr.h" 33 34 #include <iostream> 35 #include <numeric> 36 37 using namespace mlir; 38 using namespace mlir::LLVM; 39 using mlir::LLVM::linkage::getMaxEnumValForLinkage; 40 41 #include "mlir/Dialect/LLVMIR/LLVMOpsDialect.cpp.inc" 42 43 static constexpr const char kVolatileAttrName[] = "volatile_"; 44 static constexpr const char kNonTemporalAttrName[] = "nontemporal"; 45 46 #include "mlir/Dialect/LLVMIR/LLVMOpsEnums.cpp.inc" 47 #include "mlir/Dialect/LLVMIR/LLVMOpsInterfaces.cpp.inc" 48 #define GET_ATTRDEF_CLASSES 49 #include "mlir/Dialect/LLVMIR/LLVMOpsAttrDefs.cpp.inc" 50 51 static auto processFMFAttr(ArrayRef<NamedAttribute> attrs) { 52 SmallVector<NamedAttribute, 8> filteredAttrs( 53 llvm::make_filter_range(attrs, [&](NamedAttribute attr) { 54 if (attr.first == "fastmathFlags") { 55 auto defAttr = FMFAttr::get(attr.second.getContext(), {}); 56 return defAttr != attr.second; 57 } 58 return true; 59 })); 60 return filteredAttrs; 61 } 62 63 static ParseResult parseLLVMOpAttrs(OpAsmParser &parser, 64 NamedAttrList &result) { 65 return parser.parseOptionalAttrDict(result); 66 } 67 68 static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op, 69 DictionaryAttr attrs) { 70 printer.printOptionalAttrDict(processFMFAttr(attrs.getValue())); 71 } 72 73 /// Verifies `symbol`'s use in `op` to ensure the symbol is a valid and 74 /// fully defined llvm.func. 75 static LogicalResult verifySymbolAttrUse(FlatSymbolRefAttr symbol, 76 Operation *op, 77 SymbolTableCollection &symbolTable) { 78 StringRef name = symbol.getValue(); 79 auto func = 80 symbolTable.lookupNearestSymbolFrom<LLVMFuncOp>(op, symbol.getAttr()); 81 if (!func) 82 return op->emitOpError("'") 83 << name << "' does not reference a valid LLVM function"; 84 if (func.isExternal()) 85 return op->emitOpError("'") << name << "' does not have a definition"; 86 return success(); 87 } 88 89 //===----------------------------------------------------------------------===// 90 // Printing/parsing for LLVM::CmpOp. 91 //===----------------------------------------------------------------------===// 92 static void printICmpOp(OpAsmPrinter &p, ICmpOp &op) { 93 p << " \"" << stringifyICmpPredicate(op.getPredicate()) << "\" " 94 << op.getOperand(0) << ", " << op.getOperand(1); 95 p.printOptionalAttrDict(op->getAttrs(), {"predicate"}); 96 p << " : " << op.getLhs().getType(); 97 } 98 99 static void printFCmpOp(OpAsmPrinter &p, FCmpOp &op) { 100 p << " \"" << stringifyFCmpPredicate(op.getPredicate()) << "\" " 101 << op.getOperand(0) << ", " << op.getOperand(1); 102 p.printOptionalAttrDict(processFMFAttr(op->getAttrs()), {"predicate"}); 103 p << " : " << op.getLhs().getType(); 104 } 105 106 // <operation> ::= `llvm.icmp` string-literal ssa-use `,` ssa-use 107 // attribute-dict? `:` type 108 // <operation> ::= `llvm.fcmp` string-literal ssa-use `,` ssa-use 109 // attribute-dict? `:` type 110 template <typename CmpPredicateType> 111 static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) { 112 Builder &builder = parser.getBuilder(); 113 114 StringAttr predicateAttr; 115 OpAsmParser::OperandType lhs, rhs; 116 Type type; 117 llvm::SMLoc predicateLoc, trailingTypeLoc; 118 if (parser.getCurrentLocation(&predicateLoc) || 119 parser.parseAttribute(predicateAttr, "predicate", result.attributes) || 120 parser.parseOperand(lhs) || parser.parseComma() || 121 parser.parseOperand(rhs) || 122 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 123 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) || 124 parser.resolveOperand(lhs, type, result.operands) || 125 parser.resolveOperand(rhs, type, result.operands)) 126 return failure(); 127 128 // Replace the string attribute `predicate` with an integer attribute. 129 int64_t predicateValue = 0; 130 if (std::is_same<CmpPredicateType, ICmpPredicate>()) { 131 Optional<ICmpPredicate> predicate = 132 symbolizeICmpPredicate(predicateAttr.getValue()); 133 if (!predicate) 134 return parser.emitError(predicateLoc) 135 << "'" << predicateAttr.getValue() 136 << "' is an incorrect value of the 'predicate' attribute"; 137 predicateValue = static_cast<int64_t>(predicate.getValue()); 138 } else { 139 Optional<FCmpPredicate> predicate = 140 symbolizeFCmpPredicate(predicateAttr.getValue()); 141 if (!predicate) 142 return parser.emitError(predicateLoc) 143 << "'" << predicateAttr.getValue() 144 << "' is an incorrect value of the 'predicate' attribute"; 145 predicateValue = static_cast<int64_t>(predicate.getValue()); 146 } 147 148 result.attributes.set("predicate", 149 parser.getBuilder().getI64IntegerAttr(predicateValue)); 150 151 // The result type is either i1 or a vector type <? x i1> if the inputs are 152 // vectors. 153 Type resultType = IntegerType::get(builder.getContext(), 1); 154 if (!isCompatibleType(type)) 155 return parser.emitError(trailingTypeLoc, 156 "expected LLVM dialect-compatible type"); 157 if (LLVM::isCompatibleVectorType(type)) { 158 if (type.isa<LLVM::LLVMScalableVectorType>()) { 159 resultType = LLVM::LLVMScalableVectorType::get( 160 resultType, LLVM::getVectorNumElements(type).getKnownMinValue()); 161 } else { 162 resultType = LLVM::getFixedVectorType( 163 resultType, LLVM::getVectorNumElements(type).getFixedValue()); 164 } 165 } 166 167 result.addTypes({resultType}); 168 return success(); 169 } 170 171 //===----------------------------------------------------------------------===// 172 // Printing/parsing for LLVM::AllocaOp. 173 //===----------------------------------------------------------------------===// 174 175 static void printAllocaOp(OpAsmPrinter &p, AllocaOp &op) { 176 auto elemTy = op.getType().cast<LLVM::LLVMPointerType>().getElementType(); 177 178 auto funcTy = FunctionType::get( 179 op.getContext(), {op.getArraySize().getType()}, {op.getType()}); 180 181 p << ' ' << op.getArraySize() << " x " << elemTy; 182 if (op.getAlignment().hasValue() && *op.getAlignment() != 0) 183 p.printOptionalAttrDict(op->getAttrs()); 184 else 185 p.printOptionalAttrDict(op->getAttrs(), {"alignment"}); 186 p << " : " << funcTy; 187 } 188 189 // <operation> ::= `llvm.alloca` ssa-use `x` type attribute-dict? 190 // `:` type `,` type 191 static ParseResult parseAllocaOp(OpAsmParser &parser, OperationState &result) { 192 OpAsmParser::OperandType arraySize; 193 Type type, elemType; 194 llvm::SMLoc trailingTypeLoc; 195 if (parser.parseOperand(arraySize) || parser.parseKeyword("x") || 196 parser.parseType(elemType) || 197 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 198 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type)) 199 return failure(); 200 201 Optional<NamedAttribute> alignmentAttr = 202 result.attributes.getNamed("alignment"); 203 if (alignmentAttr.hasValue()) { 204 auto alignmentInt = alignmentAttr.getValue().second.dyn_cast<IntegerAttr>(); 205 if (!alignmentInt) 206 return parser.emitError(parser.getNameLoc(), 207 "expected integer alignment"); 208 if (alignmentInt.getValue().isNullValue()) 209 result.attributes.erase("alignment"); 210 } 211 212 // Extract the result type from the trailing function type. 213 auto funcType = type.dyn_cast<FunctionType>(); 214 if (!funcType || funcType.getNumInputs() != 1 || 215 funcType.getNumResults() != 1) 216 return parser.emitError( 217 trailingTypeLoc, 218 "expected trailing function type with one argument and one result"); 219 220 if (parser.resolveOperand(arraySize, funcType.getInput(0), result.operands)) 221 return failure(); 222 223 result.addTypes({funcType.getResult(0)}); 224 return success(); 225 } 226 227 //===----------------------------------------------------------------------===// 228 // LLVM::BrOp 229 //===----------------------------------------------------------------------===// 230 231 Optional<MutableOperandRange> 232 BrOp::getMutableSuccessorOperands(unsigned index) { 233 assert(index == 0 && "invalid successor index"); 234 return getDestOperandsMutable(); 235 } 236 237 //===----------------------------------------------------------------------===// 238 // LLVM::CondBrOp 239 //===----------------------------------------------------------------------===// 240 241 Optional<MutableOperandRange> 242 CondBrOp::getMutableSuccessorOperands(unsigned index) { 243 assert(index < getNumSuccessors() && "invalid successor index"); 244 return index == 0 ? getTrueDestOperandsMutable() 245 : getFalseDestOperandsMutable(); 246 } 247 248 //===----------------------------------------------------------------------===// 249 // LLVM::SwitchOp 250 //===----------------------------------------------------------------------===// 251 252 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value, 253 Block *defaultDestination, ValueRange defaultOperands, 254 ArrayRef<int32_t> caseValues, BlockRange caseDestinations, 255 ArrayRef<ValueRange> caseOperands, 256 ArrayRef<int32_t> branchWeights) { 257 ElementsAttr caseValuesAttr; 258 if (!caseValues.empty()) 259 caseValuesAttr = builder.getI32VectorAttr(caseValues); 260 261 ElementsAttr weightsAttr; 262 if (!branchWeights.empty()) 263 weightsAttr = builder.getI32VectorAttr(llvm::to_vector<4>(branchWeights)); 264 265 build(builder, result, value, defaultOperands, caseOperands, caseValuesAttr, 266 weightsAttr, defaultDestination, caseDestinations); 267 } 268 269 /// <cases> ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)? 270 /// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )? 271 static ParseResult parseSwitchOpCases( 272 OpAsmParser &parser, ElementsAttr &caseValues, 273 SmallVectorImpl<Block *> &caseDestinations, 274 SmallVectorImpl<SmallVector<OpAsmParser::OperandType>> &caseOperands, 275 SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) { 276 SmallVector<int32_t> values; 277 int32_t value = 0; 278 do { 279 OptionalParseResult integerParseResult = parser.parseOptionalInteger(value); 280 if (values.empty() && !integerParseResult.hasValue()) 281 return success(); 282 283 if (!integerParseResult.hasValue() || integerParseResult.getValue()) 284 return failure(); 285 values.push_back(value); 286 287 Block *destination; 288 SmallVector<OpAsmParser::OperandType> operands; 289 SmallVector<Type> operandTypes; 290 if (parser.parseColon() || parser.parseSuccessor(destination)) 291 return failure(); 292 if (!parser.parseOptionalLParen()) { 293 if (parser.parseRegionArgumentList(operands) || 294 parser.parseColonTypeList(operandTypes) || parser.parseRParen()) 295 return failure(); 296 } 297 caseDestinations.push_back(destination); 298 caseOperands.emplace_back(operands); 299 caseOperandTypes.emplace_back(operandTypes); 300 } while (!parser.parseOptionalComma()); 301 302 caseValues = parser.getBuilder().getI32VectorAttr(values); 303 return success(); 304 } 305 306 static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, 307 ElementsAttr caseValues, 308 SuccessorRange caseDestinations, 309 OperandRangeRange caseOperands, 310 TypeRangeRange caseOperandTypes) { 311 if (!caseValues) 312 return; 313 314 size_t index = 0; 315 llvm::interleave( 316 llvm::zip(caseValues.cast<DenseIntElementsAttr>(), caseDestinations), 317 [&](auto i) { 318 p << " "; 319 p << std::get<0>(i).getLimitedValue(); 320 p << ": "; 321 p.printSuccessorAndUseList(std::get<1>(i), caseOperands[index++]); 322 }, 323 [&] { 324 p << ','; 325 p.printNewline(); 326 }); 327 p.printNewline(); 328 } 329 330 static LogicalResult verify(SwitchOp op) { 331 if ((!op.getCaseValues() && !op.getCaseDestinations().empty()) || 332 (op.getCaseValues() && 333 op.getCaseValues()->size() != 334 static_cast<int64_t>(op.getCaseDestinations().size()))) 335 return op.emitOpError("expects number of case values to match number of " 336 "case destinations"); 337 if (op.getBranchWeights() && 338 op.getBranchWeights()->size() != op.getNumSuccessors()) 339 return op.emitError("expects number of branch weights to match number of " 340 "successors: ") 341 << op.getBranchWeights()->size() << " vs " << op.getNumSuccessors(); 342 return success(); 343 } 344 345 Optional<MutableOperandRange> 346 SwitchOp::getMutableSuccessorOperands(unsigned index) { 347 assert(index < getNumSuccessors() && "invalid successor index"); 348 return index == 0 ? getDefaultOperandsMutable() 349 : getCaseOperandsMutable(index - 1); 350 } 351 352 //===----------------------------------------------------------------------===// 353 // Builder, printer and parser for for LLVM::LoadOp. 354 //===----------------------------------------------------------------------===// 355 356 LogicalResult verifySymbolAttribute( 357 Operation *op, StringRef attributeName, 358 std::function<LogicalResult(Operation *, SymbolRefAttr)> verifySymbolType) { 359 if (Attribute attribute = op->getAttr(attributeName)) { 360 // The attribute is already verified to be a symbol ref array attribute via 361 // a constraint in the operation definition. 362 for (SymbolRefAttr symbolRef : 363 attribute.cast<ArrayAttr>().getAsRange<SymbolRefAttr>()) { 364 StringAttr metadataName = symbolRef.getRootReference(); 365 StringAttr symbolName = symbolRef.getLeafReference(); 366 // We want @metadata::@symbol, not just @symbol 367 if (metadataName == symbolName) { 368 return op->emitOpError() << "expected '" << symbolRef 369 << "' to specify a fully qualified reference"; 370 } 371 auto metadataOp = SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>( 372 op->getParentOp(), metadataName); 373 if (!metadataOp) 374 return op->emitOpError() 375 << "expected '" << symbolRef << "' to reference a metadata op"; 376 Operation *symbolOp = 377 SymbolTable::lookupNearestSymbolFrom(metadataOp, symbolName); 378 if (!symbolOp) 379 return op->emitOpError() 380 << "expected '" << symbolRef << "' to be a valid reference"; 381 if (failed(verifySymbolType(symbolOp, symbolRef))) { 382 return failure(); 383 } 384 } 385 } 386 return success(); 387 } 388 389 // Verifies that metadata ops are wired up properly. 390 template <typename OpTy> 391 static LogicalResult verifyOpMetadata(Operation *op, StringRef attributeName) { 392 auto verifySymbolType = [op](Operation *symbolOp, 393 SymbolRefAttr symbolRef) -> LogicalResult { 394 if (!isa<OpTy>(symbolOp)) { 395 return op->emitOpError() 396 << "expected '" << symbolRef << "' to resolve to a " 397 << OpTy::getOperationName(); 398 } 399 return success(); 400 }; 401 402 return verifySymbolAttribute(op, attributeName, verifySymbolType); 403 } 404 405 static LogicalResult verifyMemoryOpMetadata(Operation *op) { 406 // access_groups 407 if (failed(verifyOpMetadata<LLVM::AccessGroupMetadataOp>( 408 op, LLVMDialect::getAccessGroupsAttrName()))) 409 return failure(); 410 411 // alias_scopes 412 if (failed(verifyOpMetadata<LLVM::AliasScopeMetadataOp>( 413 op, LLVMDialect::getAliasScopesAttrName()))) 414 return failure(); 415 416 // noalias_scopes 417 if (failed(verifyOpMetadata<LLVM::AliasScopeMetadataOp>( 418 op, LLVMDialect::getNoAliasScopesAttrName()))) 419 return failure(); 420 421 return success(); 422 } 423 424 static LogicalResult verify(LoadOp op) { 425 return verifyMemoryOpMetadata(op.getOperation()); 426 } 427 428 void LoadOp::build(OpBuilder &builder, OperationState &result, Type t, 429 Value addr, unsigned alignment, bool isVolatile, 430 bool isNonTemporal) { 431 result.addOperands(addr); 432 result.addTypes(t); 433 if (isVolatile) 434 result.addAttribute(kVolatileAttrName, builder.getUnitAttr()); 435 if (isNonTemporal) 436 result.addAttribute(kNonTemporalAttrName, builder.getUnitAttr()); 437 if (alignment != 0) 438 result.addAttribute("alignment", builder.getI64IntegerAttr(alignment)); 439 } 440 441 static void printLoadOp(OpAsmPrinter &p, LoadOp &op) { 442 p << ' '; 443 if (op.getVolatile_()) 444 p << "volatile "; 445 p << op.getAddr(); 446 p.printOptionalAttrDict(op->getAttrs(), {kVolatileAttrName}); 447 p << " : " << op.getAddr().getType(); 448 } 449 450 // Extract the pointee type from the LLVM pointer type wrapped in MLIR. Return 451 // the resulting type wrapped in MLIR, or nullptr on error. 452 static Type getLoadStoreElementType(OpAsmParser &parser, Type type, 453 llvm::SMLoc trailingTypeLoc) { 454 auto llvmTy = type.dyn_cast<LLVM::LLVMPointerType>(); 455 if (!llvmTy) 456 return parser.emitError(trailingTypeLoc, "expected LLVM pointer type"), 457 nullptr; 458 return llvmTy.getElementType(); 459 } 460 461 // <operation> ::= `llvm.load` `volatile` ssa-use attribute-dict? `:` type 462 static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) { 463 OpAsmParser::OperandType addr; 464 Type type; 465 llvm::SMLoc trailingTypeLoc; 466 467 if (succeeded(parser.parseOptionalKeyword("volatile"))) 468 result.addAttribute(kVolatileAttrName, parser.getBuilder().getUnitAttr()); 469 470 if (parser.parseOperand(addr) || 471 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 472 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) || 473 parser.resolveOperand(addr, type, result.operands)) 474 return failure(); 475 476 Type elemTy = getLoadStoreElementType(parser, type, trailingTypeLoc); 477 478 result.addTypes(elemTy); 479 return success(); 480 } 481 482 //===----------------------------------------------------------------------===// 483 // Builder, printer and parser for LLVM::StoreOp. 484 //===----------------------------------------------------------------------===// 485 486 static LogicalResult verify(StoreOp op) { 487 return verifyMemoryOpMetadata(op.getOperation()); 488 } 489 490 void StoreOp::build(OpBuilder &builder, OperationState &result, Value value, 491 Value addr, unsigned alignment, bool isVolatile, 492 bool isNonTemporal) { 493 result.addOperands({value, addr}); 494 result.addTypes({}); 495 if (isVolatile) 496 result.addAttribute(kVolatileAttrName, builder.getUnitAttr()); 497 if (isNonTemporal) 498 result.addAttribute(kNonTemporalAttrName, builder.getUnitAttr()); 499 if (alignment != 0) 500 result.addAttribute("alignment", builder.getI64IntegerAttr(alignment)); 501 } 502 503 static void printStoreOp(OpAsmPrinter &p, StoreOp &op) { 504 p << ' '; 505 if (op.getVolatile_()) 506 p << "volatile "; 507 p << op.getValue() << ", " << op.getAddr(); 508 p.printOptionalAttrDict(op->getAttrs(), {kVolatileAttrName}); 509 p << " : " << op.getAddr().getType(); 510 } 511 512 // <operation> ::= `llvm.store` `volatile` ssa-use `,` ssa-use 513 // attribute-dict? `:` type 514 static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) { 515 OpAsmParser::OperandType addr, value; 516 Type type; 517 llvm::SMLoc trailingTypeLoc; 518 519 if (succeeded(parser.parseOptionalKeyword("volatile"))) 520 result.addAttribute(kVolatileAttrName, parser.getBuilder().getUnitAttr()); 521 522 if (parser.parseOperand(value) || parser.parseComma() || 523 parser.parseOperand(addr) || 524 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 525 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type)) 526 return failure(); 527 528 Type elemTy = getLoadStoreElementType(parser, type, trailingTypeLoc); 529 if (!elemTy) 530 return failure(); 531 532 if (parser.resolveOperand(value, elemTy, result.operands) || 533 parser.resolveOperand(addr, type, result.operands)) 534 return failure(); 535 536 return success(); 537 } 538 539 ///===---------------------------------------------------------------------===// 540 /// LLVM::InvokeOp 541 ///===---------------------------------------------------------------------===// 542 543 Optional<MutableOperandRange> 544 InvokeOp::getMutableSuccessorOperands(unsigned index) { 545 assert(index < getNumSuccessors() && "invalid successor index"); 546 return index == 0 ? getNormalDestOperandsMutable() 547 : getUnwindDestOperandsMutable(); 548 } 549 550 static LogicalResult verify(InvokeOp op) { 551 if (op.getNumResults() > 1) 552 return op.emitOpError("must have 0 or 1 result"); 553 554 Block *unwindDest = op.getUnwindDest(); 555 if (unwindDest->empty()) 556 return op.emitError( 557 "must have at least one operation in unwind destination"); 558 559 // In unwind destination, first operation must be LandingpadOp 560 if (!isa<LandingpadOp>(unwindDest->front())) 561 return op.emitError("first operation in unwind destination should be a " 562 "llvm.landingpad operation"); 563 564 return success(); 565 } 566 567 static void printInvokeOp(OpAsmPrinter &p, InvokeOp op) { 568 auto callee = op.getCallee(); 569 bool isDirect = callee.hasValue(); 570 571 p << ' '; 572 573 // Either function name or pointer 574 if (isDirect) 575 p.printSymbolName(callee.getValue()); 576 else 577 p << op.getOperand(0); 578 579 p << '(' << op.getOperands().drop_front(isDirect ? 0 : 1) << ')'; 580 p << " to "; 581 p.printSuccessorAndUseList(op.getNormalDest(), op.getNormalDestOperands()); 582 p << " unwind "; 583 p.printSuccessorAndUseList(op.getUnwindDest(), op.getUnwindDestOperands()); 584 585 p.printOptionalAttrDict(op->getAttrs(), 586 {InvokeOp::getOperandSegmentSizeAttr(), "callee"}); 587 p << " : "; 588 p.printFunctionalType( 589 llvm::drop_begin(op.getOperandTypes(), isDirect ? 0 : 1), 590 op.getResultTypes()); 591 } 592 593 /// <operation> ::= `llvm.invoke` (function-id | ssa-use) `(` ssa-use-list `)` 594 /// `to` bb-id (`[` ssa-use-and-type-list `]`)? 595 /// `unwind` bb-id (`[` ssa-use-and-type-list `]`)? 596 /// attribute-dict? `:` function-type 597 static ParseResult parseInvokeOp(OpAsmParser &parser, OperationState &result) { 598 SmallVector<OpAsmParser::OperandType, 8> operands; 599 FunctionType funcType; 600 SymbolRefAttr funcAttr; 601 llvm::SMLoc trailingTypeLoc; 602 Block *normalDest, *unwindDest; 603 SmallVector<Value, 4> normalOperands, unwindOperands; 604 Builder &builder = parser.getBuilder(); 605 606 // Parse an operand list that will, in practice, contain 0 or 1 operand. In 607 // case of an indirect call, there will be 1 operand before `(`. In case of a 608 // direct call, there will be no operands and the parser will stop at the 609 // function identifier without complaining. 610 if (parser.parseOperandList(operands)) 611 return failure(); 612 bool isDirect = operands.empty(); 613 614 // Optionally parse a function identifier. 615 if (isDirect && parser.parseAttribute(funcAttr, "callee", result.attributes)) 616 return failure(); 617 618 if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || 619 parser.parseKeyword("to") || 620 parser.parseSuccessorAndUseList(normalDest, normalOperands) || 621 parser.parseKeyword("unwind") || 622 parser.parseSuccessorAndUseList(unwindDest, unwindOperands) || 623 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 624 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(funcType)) 625 return failure(); 626 627 if (isDirect) { 628 // Make sure types match. 629 if (parser.resolveOperands(operands, funcType.getInputs(), 630 parser.getNameLoc(), result.operands)) 631 return failure(); 632 result.addTypes(funcType.getResults()); 633 } else { 634 // Construct the LLVM IR Dialect function type that the first operand 635 // should match. 636 if (funcType.getNumResults() > 1) 637 return parser.emitError(trailingTypeLoc, 638 "expected function with 0 or 1 result"); 639 640 Type llvmResultType; 641 if (funcType.getNumResults() == 0) { 642 llvmResultType = LLVM::LLVMVoidType::get(builder.getContext()); 643 } else { 644 llvmResultType = funcType.getResult(0); 645 if (!isCompatibleType(llvmResultType)) 646 return parser.emitError(trailingTypeLoc, 647 "expected result to have LLVM type"); 648 } 649 650 SmallVector<Type, 8> argTypes; 651 argTypes.reserve(funcType.getNumInputs()); 652 for (Type ty : funcType.getInputs()) { 653 if (isCompatibleType(ty)) 654 argTypes.push_back(ty); 655 else 656 return parser.emitError(trailingTypeLoc, 657 "expected LLVM types as inputs"); 658 } 659 660 auto llvmFuncType = LLVM::LLVMFunctionType::get(llvmResultType, argTypes); 661 auto wrappedFuncType = LLVM::LLVMPointerType::get(llvmFuncType); 662 663 auto funcArguments = llvm::makeArrayRef(operands).drop_front(); 664 665 // Make sure that the first operand (indirect callee) matches the wrapped 666 // LLVM IR function type, and that the types of the other call operands 667 // match the types of the function arguments. 668 if (parser.resolveOperand(operands[0], wrappedFuncType, result.operands) || 669 parser.resolveOperands(funcArguments, funcType.getInputs(), 670 parser.getNameLoc(), result.operands)) 671 return failure(); 672 673 result.addTypes(llvmResultType); 674 } 675 result.addSuccessors({normalDest, unwindDest}); 676 result.addOperands(normalOperands); 677 result.addOperands(unwindOperands); 678 679 result.addAttribute( 680 InvokeOp::getOperandSegmentSizeAttr(), 681 builder.getI32VectorAttr({static_cast<int32_t>(operands.size()), 682 static_cast<int32_t>(normalOperands.size()), 683 static_cast<int32_t>(unwindOperands.size())})); 684 return success(); 685 } 686 687 ///===----------------------------------------------------------------------===// 688 /// Verifying/Printing/Parsing for LLVM::LandingpadOp. 689 ///===----------------------------------------------------------------------===// 690 691 static LogicalResult verify(LandingpadOp op) { 692 Value value; 693 if (LLVMFuncOp func = op->getParentOfType<LLVMFuncOp>()) { 694 if (!func.getPersonality().hasValue()) 695 return op.emitError( 696 "llvm.landingpad needs to be in a function with a personality"); 697 } 698 699 if (!op.getCleanup() && op.getOperands().empty()) 700 return op.emitError("landingpad instruction expects at least one clause or " 701 "cleanup attribute"); 702 703 for (unsigned idx = 0, ie = op.getNumOperands(); idx < ie; idx++) { 704 value = op.getOperand(idx); 705 bool isFilter = value.getType().isa<LLVMArrayType>(); 706 if (isFilter) { 707 // FIXME: Verify filter clauses when arrays are appropriately handled 708 } else { 709 // catch - global addresses only. 710 // Bitcast ops should have global addresses as their args. 711 if (auto bcOp = value.getDefiningOp<BitcastOp>()) { 712 if (auto addrOp = bcOp.getArg().getDefiningOp<AddressOfOp>()) 713 continue; 714 return op.emitError("constant clauses expected") 715 .attachNote(bcOp.getLoc()) 716 << "global addresses expected as operand to " 717 "bitcast used in clauses for landingpad"; 718 } 719 // NullOp and AddressOfOp allowed 720 if (value.getDefiningOp<NullOp>()) 721 continue; 722 if (value.getDefiningOp<AddressOfOp>()) 723 continue; 724 return op.emitError("clause #") 725 << idx << " is not a known constant - null, addressof, bitcast"; 726 } 727 } 728 return success(); 729 } 730 731 static void printLandingpadOp(OpAsmPrinter &p, LandingpadOp &op) { 732 p << (op.getCleanup() ? " cleanup " : " "); 733 734 // Clauses 735 for (auto value : op.getOperands()) { 736 // Similar to llvm - if clause is an array type then it is filter 737 // clause else catch clause 738 bool isArrayTy = value.getType().isa<LLVMArrayType>(); 739 p << '(' << (isArrayTy ? "filter " : "catch ") << value << " : " 740 << value.getType() << ") "; 741 } 742 743 p.printOptionalAttrDict(op->getAttrs(), {"cleanup"}); 744 745 p << ": " << op.getType(); 746 } 747 748 /// <operation> ::= `llvm.landingpad` `cleanup`? 749 /// ((`catch` | `filter`) operand-type ssa-use)* attribute-dict? 750 static ParseResult parseLandingpadOp(OpAsmParser &parser, 751 OperationState &result) { 752 // Check for cleanup 753 if (succeeded(parser.parseOptionalKeyword("cleanup"))) 754 result.addAttribute("cleanup", parser.getBuilder().getUnitAttr()); 755 756 // Parse clauses with types 757 while (succeeded(parser.parseOptionalLParen()) && 758 (succeeded(parser.parseOptionalKeyword("filter")) || 759 succeeded(parser.parseOptionalKeyword("catch")))) { 760 OpAsmParser::OperandType operand; 761 Type ty; 762 if (parser.parseOperand(operand) || parser.parseColon() || 763 parser.parseType(ty) || 764 parser.resolveOperand(operand, ty, result.operands) || 765 parser.parseRParen()) 766 return failure(); 767 } 768 769 Type type; 770 if (parser.parseColon() || parser.parseType(type)) 771 return failure(); 772 773 result.addTypes(type); 774 return success(); 775 } 776 777 //===----------------------------------------------------------------------===// 778 // Verifying/Printing/parsing for LLVM::CallOp. 779 //===----------------------------------------------------------------------===// 780 781 static LogicalResult verify(CallOp &op) { 782 if (op.getNumResults() > 1) 783 return op.emitOpError("must have 0 or 1 result"); 784 785 // Type for the callee, we'll get it differently depending if it is a direct 786 // or indirect call. 787 Type fnType; 788 789 bool isIndirect = false; 790 791 // If this is an indirect call, the callee attribute is missing. 792 FlatSymbolRefAttr calleeName = op.getCalleeAttr(); 793 if (!calleeName) { 794 isIndirect = true; 795 if (!op.getNumOperands()) 796 return op.emitOpError( 797 "must have either a `callee` attribute or at least an operand"); 798 auto ptrType = op.getOperand(0).getType().dyn_cast<LLVMPointerType>(); 799 if (!ptrType) 800 return op.emitOpError("indirect call expects a pointer as callee: ") 801 << ptrType; 802 fnType = ptrType.getElementType(); 803 } else { 804 Operation *callee = 805 SymbolTable::lookupNearestSymbolFrom(op, calleeName.getAttr()); 806 if (!callee) 807 return op.emitOpError() 808 << "'" << calleeName.getValue() 809 << "' does not reference a symbol in the current scope"; 810 auto fn = dyn_cast<LLVMFuncOp>(callee); 811 if (!fn) 812 return op.emitOpError() << "'" << calleeName.getValue() 813 << "' does not reference a valid LLVM function"; 814 815 fnType = fn.getType(); 816 } 817 818 LLVMFunctionType funcType = fnType.dyn_cast<LLVMFunctionType>(); 819 if (!funcType) 820 return op.emitOpError("callee does not have a functional type: ") << fnType; 821 822 // Verify that the operand and result types match the callee. 823 824 if (!funcType.isVarArg() && 825 funcType.getNumParams() != (op.getNumOperands() - isIndirect)) 826 return op.emitOpError() 827 << "incorrect number of operands (" 828 << (op.getNumOperands() - isIndirect) 829 << ") for callee (expecting: " << funcType.getNumParams() << ")"; 830 831 if (funcType.getNumParams() > (op.getNumOperands() - isIndirect)) 832 return op.emitOpError() << "incorrect number of operands (" 833 << (op.getNumOperands() - isIndirect) 834 << ") for varargs callee (expecting at least: " 835 << funcType.getNumParams() << ")"; 836 837 for (unsigned i = 0, e = funcType.getNumParams(); i != e; ++i) 838 if (op.getOperand(i + isIndirect).getType() != funcType.getParamType(i)) 839 return op.emitOpError() << "operand type mismatch for operand " << i 840 << ": " << op.getOperand(i + isIndirect).getType() 841 << " != " << funcType.getParamType(i); 842 843 if (op.getNumResults() == 0 && 844 !funcType.getReturnType().isa<LLVM::LLVMVoidType>()) 845 return op.emitOpError() << "expected function call to produce a value"; 846 847 if (op.getNumResults() != 0 && 848 funcType.getReturnType().isa<LLVM::LLVMVoidType>()) 849 return op.emitOpError() 850 << "calling function with void result must not produce values"; 851 852 if (op.getNumResults() > 1) 853 return op.emitOpError() 854 << "expected LLVM function call to produce 0 or 1 result"; 855 856 if (op.getNumResults() && 857 op.getResult(0).getType() != funcType.getReturnType()) 858 return op.emitOpError() 859 << "result type mismatch: " << op.getResult(0).getType() 860 << " != " << funcType.getReturnType(); 861 862 return success(); 863 } 864 865 static void printCallOp(OpAsmPrinter &p, CallOp &op) { 866 auto callee = op.getCallee(); 867 bool isDirect = callee.hasValue(); 868 869 // Print the direct callee if present as a function attribute, or an indirect 870 // callee (first operand) otherwise. 871 p << ' '; 872 if (isDirect) 873 p.printSymbolName(callee.getValue()); 874 else 875 p << op.getOperand(0); 876 877 auto args = op.getOperands().drop_front(isDirect ? 0 : 1); 878 p << '(' << args << ')'; 879 p.printOptionalAttrDict(processFMFAttr(op->getAttrs()), {"callee"}); 880 881 // Reconstruct the function MLIR function type from operand and result types. 882 p << " : " 883 << FunctionType::get(op.getContext(), args.getTypes(), op.getResultTypes()); 884 } 885 886 // <operation> ::= `llvm.call` (function-id | ssa-use) `(` ssa-use-list `)` 887 // attribute-dict? `:` function-type 888 static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) { 889 SmallVector<OpAsmParser::OperandType, 8> operands; 890 Type type; 891 SymbolRefAttr funcAttr; 892 llvm::SMLoc trailingTypeLoc; 893 894 // Parse an operand list that will, in practice, contain 0 or 1 operand. In 895 // case of an indirect call, there will be 1 operand before `(`. In case of a 896 // direct call, there will be no operands and the parser will stop at the 897 // function identifier without complaining. 898 if (parser.parseOperandList(operands)) 899 return failure(); 900 bool isDirect = operands.empty(); 901 902 // Optionally parse a function identifier. 903 if (isDirect) 904 if (parser.parseAttribute(funcAttr, "callee", result.attributes)) 905 return failure(); 906 907 if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || 908 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 909 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type)) 910 return failure(); 911 912 auto funcType = type.dyn_cast<FunctionType>(); 913 if (!funcType) 914 return parser.emitError(trailingTypeLoc, "expected function type"); 915 if (funcType.getNumResults() > 1) 916 return parser.emitError(trailingTypeLoc, 917 "expected function with 0 or 1 result"); 918 if (isDirect) { 919 // Make sure types match. 920 if (parser.resolveOperands(operands, funcType.getInputs(), 921 parser.getNameLoc(), result.operands)) 922 return failure(); 923 if (funcType.getNumResults() != 0 && 924 !funcType.getResult(0).isa<LLVM::LLVMVoidType>()) 925 result.addTypes(funcType.getResults()); 926 } else { 927 Builder &builder = parser.getBuilder(); 928 Type llvmResultType; 929 if (funcType.getNumResults() == 0) { 930 llvmResultType = LLVM::LLVMVoidType::get(builder.getContext()); 931 } else { 932 llvmResultType = funcType.getResult(0); 933 if (!isCompatibleType(llvmResultType)) 934 return parser.emitError(trailingTypeLoc, 935 "expected result to have LLVM type"); 936 } 937 938 SmallVector<Type, 8> argTypes; 939 argTypes.reserve(funcType.getNumInputs()); 940 for (int i = 0, e = funcType.getNumInputs(); i < e; ++i) { 941 auto argType = funcType.getInput(i); 942 if (!isCompatibleType(argType)) 943 return parser.emitError(trailingTypeLoc, 944 "expected LLVM types as inputs"); 945 argTypes.push_back(argType); 946 } 947 auto llvmFuncType = LLVM::LLVMFunctionType::get(llvmResultType, argTypes); 948 auto wrappedFuncType = LLVM::LLVMPointerType::get(llvmFuncType); 949 950 auto funcArguments = 951 ArrayRef<OpAsmParser::OperandType>(operands).drop_front(); 952 953 // Make sure that the first operand (indirect callee) matches the wrapped 954 // LLVM IR function type, and that the types of the other call operands 955 // match the types of the function arguments. 956 if (parser.resolveOperand(operands[0], wrappedFuncType, result.operands) || 957 parser.resolveOperands(funcArguments, funcType.getInputs(), 958 parser.getNameLoc(), result.operands)) 959 return failure(); 960 961 if (!llvmResultType.isa<LLVM::LLVMVoidType>()) 962 result.addTypes(llvmResultType); 963 } 964 965 return success(); 966 } 967 968 //===----------------------------------------------------------------------===// 969 // Printing/parsing for LLVM::ExtractElementOp. 970 //===----------------------------------------------------------------------===// 971 // Expects vector to be of wrapped LLVM vector type and position to be of 972 // wrapped LLVM i32 type. 973 void LLVM::ExtractElementOp::build(OpBuilder &b, OperationState &result, 974 Value vector, Value position, 975 ArrayRef<NamedAttribute> attrs) { 976 auto vectorType = vector.getType(); 977 auto llvmType = LLVM::getVectorElementType(vectorType); 978 build(b, result, llvmType, vector, position); 979 result.addAttributes(attrs); 980 } 981 982 static void printExtractElementOp(OpAsmPrinter &p, ExtractElementOp &op) { 983 p << ' ' << op.getVector() << "[" << op.getPosition() << " : " 984 << op.getPosition().getType() << "]"; 985 p.printOptionalAttrDict(op->getAttrs()); 986 p << " : " << op.getVector().getType(); 987 } 988 989 // <operation> ::= `llvm.extractelement` ssa-use `, ` ssa-use 990 // attribute-dict? `:` type 991 static ParseResult parseExtractElementOp(OpAsmParser &parser, 992 OperationState &result) { 993 llvm::SMLoc loc; 994 OpAsmParser::OperandType vector, position; 995 Type type, positionType; 996 if (parser.getCurrentLocation(&loc) || parser.parseOperand(vector) || 997 parser.parseLSquare() || parser.parseOperand(position) || 998 parser.parseColonType(positionType) || parser.parseRSquare() || 999 parser.parseOptionalAttrDict(result.attributes) || 1000 parser.parseColonType(type) || 1001 parser.resolveOperand(vector, type, result.operands) || 1002 parser.resolveOperand(position, positionType, result.operands)) 1003 return failure(); 1004 if (!LLVM::isCompatibleVectorType(type)) 1005 return parser.emitError( 1006 loc, "expected LLVM dialect-compatible vector type for operand #1"); 1007 result.addTypes(LLVM::getVectorElementType(type)); 1008 return success(); 1009 } 1010 1011 static LogicalResult verify(ExtractElementOp op) { 1012 Type vectorType = op.getVector().getType(); 1013 if (!LLVM::isCompatibleVectorType(vectorType)) 1014 return op->emitOpError("expected LLVM dialect-compatible vector type for " 1015 "operand #1, got") 1016 << vectorType; 1017 Type valueType = LLVM::getVectorElementType(vectorType); 1018 if (valueType != op.getRes().getType()) 1019 return op.emitOpError() << "Type mismatch: extracting from " << vectorType 1020 << " should produce " << valueType 1021 << " but this op returns " << op.getRes().getType(); 1022 return success(); 1023 } 1024 1025 //===----------------------------------------------------------------------===// 1026 // Printing/parsing for LLVM::ExtractValueOp. 1027 //===----------------------------------------------------------------------===// 1028 1029 static void printExtractValueOp(OpAsmPrinter &p, ExtractValueOp &op) { 1030 p << ' ' << op.getContainer() << op.getPosition(); 1031 p.printOptionalAttrDict(op->getAttrs(), {"position"}); 1032 p << " : " << op.getContainer().getType(); 1033 } 1034 1035 // Extract the type at `position` in the wrapped LLVM IR aggregate type 1036 // `containerType`. Position is an integer array attribute where each value 1037 // is a zero-based position of the element in the aggregate type. Return the 1038 // resulting type wrapped in MLIR, or nullptr on error. 1039 static Type getInsertExtractValueElementType(OpAsmParser &parser, 1040 Type containerType, 1041 ArrayAttr positionAttr, 1042 llvm::SMLoc attributeLoc, 1043 llvm::SMLoc typeLoc) { 1044 Type llvmType = containerType; 1045 if (!isCompatibleType(containerType)) 1046 return parser.emitError(typeLoc, "expected LLVM IR Dialect type"), nullptr; 1047 1048 // Infer the element type from the structure type: iteratively step inside the 1049 // type by taking the element type, indexed by the position attribute for 1050 // structures. Check the position index before accessing, it is supposed to 1051 // be in bounds. 1052 for (Attribute subAttr : positionAttr) { 1053 auto positionElementAttr = subAttr.dyn_cast<IntegerAttr>(); 1054 if (!positionElementAttr) 1055 return parser.emitError(attributeLoc, 1056 "expected an array of integer literals"), 1057 nullptr; 1058 int position = positionElementAttr.getInt(); 1059 if (auto arrayType = llvmType.dyn_cast<LLVMArrayType>()) { 1060 if (position < 0 || 1061 static_cast<unsigned>(position) >= arrayType.getNumElements()) 1062 return parser.emitError(attributeLoc, "position out of bounds"), 1063 nullptr; 1064 llvmType = arrayType.getElementType(); 1065 } else if (auto structType = llvmType.dyn_cast<LLVMStructType>()) { 1066 if (position < 0 || 1067 static_cast<unsigned>(position) >= structType.getBody().size()) 1068 return parser.emitError(attributeLoc, "position out of bounds"), 1069 nullptr; 1070 llvmType = structType.getBody()[position]; 1071 } else { 1072 return parser.emitError(typeLoc, "expected LLVM IR structure/array type"), 1073 nullptr; 1074 } 1075 } 1076 return llvmType; 1077 } 1078 1079 // Extract the type at `position` in the wrapped LLVM IR aggregate type 1080 // `containerType`. Returns null on failure. 1081 static Type getInsertExtractValueElementType(Type containerType, 1082 ArrayAttr positionAttr, 1083 Operation *op) { 1084 Type llvmType = containerType; 1085 if (!isCompatibleType(containerType)) { 1086 op->emitError("expected LLVM IR Dialect type, got ") << containerType; 1087 return {}; 1088 } 1089 1090 // Infer the element type from the structure type: iteratively step inside the 1091 // type by taking the element type, indexed by the position attribute for 1092 // structures. Check the position index before accessing, it is supposed to 1093 // be in bounds. 1094 for (Attribute subAttr : positionAttr) { 1095 auto positionElementAttr = subAttr.dyn_cast<IntegerAttr>(); 1096 if (!positionElementAttr) { 1097 op->emitOpError("expected an array of integer literals, got: ") 1098 << subAttr; 1099 return {}; 1100 } 1101 int position = positionElementAttr.getInt(); 1102 if (auto arrayType = llvmType.dyn_cast<LLVMArrayType>()) { 1103 if (position < 0 || 1104 static_cast<unsigned>(position) >= arrayType.getNumElements()) { 1105 op->emitOpError("position out of bounds: ") << position; 1106 return {}; 1107 } 1108 llvmType = arrayType.getElementType(); 1109 } else if (auto structType = llvmType.dyn_cast<LLVMStructType>()) { 1110 if (position < 0 || 1111 static_cast<unsigned>(position) >= structType.getBody().size()) { 1112 op->emitOpError("position out of bounds") << position; 1113 return {}; 1114 } 1115 llvmType = structType.getBody()[position]; 1116 } else { 1117 op->emitOpError("expected LLVM IR structure/array type, got: ") 1118 << llvmType; 1119 return {}; 1120 } 1121 } 1122 return llvmType; 1123 } 1124 1125 // <operation> ::= `llvm.extractvalue` ssa-use 1126 // `[` integer-literal (`,` integer-literal)* `]` 1127 // attribute-dict? `:` type 1128 static ParseResult parseExtractValueOp(OpAsmParser &parser, 1129 OperationState &result) { 1130 OpAsmParser::OperandType container; 1131 Type containerType; 1132 ArrayAttr positionAttr; 1133 llvm::SMLoc attributeLoc, trailingTypeLoc; 1134 1135 if (parser.parseOperand(container) || 1136 parser.getCurrentLocation(&attributeLoc) || 1137 parser.parseAttribute(positionAttr, "position", result.attributes) || 1138 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 1139 parser.getCurrentLocation(&trailingTypeLoc) || 1140 parser.parseType(containerType) || 1141 parser.resolveOperand(container, containerType, result.operands)) 1142 return failure(); 1143 1144 auto elementType = getInsertExtractValueElementType( 1145 parser, containerType, positionAttr, attributeLoc, trailingTypeLoc); 1146 if (!elementType) 1147 return failure(); 1148 1149 result.addTypes(elementType); 1150 return success(); 1151 } 1152 1153 OpFoldResult LLVM::ExtractValueOp::fold(ArrayRef<Attribute> operands) { 1154 auto insertValueOp = getContainer().getDefiningOp<InsertValueOp>(); 1155 while (insertValueOp) { 1156 if (getPosition() == insertValueOp.getPosition()) 1157 return insertValueOp.getValue(); 1158 insertValueOp = insertValueOp.getContainer().getDefiningOp<InsertValueOp>(); 1159 } 1160 return {}; 1161 } 1162 1163 static LogicalResult verify(ExtractValueOp op) { 1164 Type valueType = getInsertExtractValueElementType(op.getContainer().getType(), 1165 op.getPositionAttr(), op); 1166 if (!valueType) 1167 return failure(); 1168 1169 if (op.getRes().getType() != valueType) 1170 return op.emitOpError() 1171 << "Type mismatch: extracting from " << op.getContainer().getType() 1172 << " should produce " << valueType << " but this op returns " 1173 << op.getRes().getType(); 1174 return success(); 1175 } 1176 1177 //===----------------------------------------------------------------------===// 1178 // Printing/parsing for LLVM::InsertElementOp. 1179 //===----------------------------------------------------------------------===// 1180 1181 static void printInsertElementOp(OpAsmPrinter &p, InsertElementOp &op) { 1182 p << ' ' << op.getValue() << ", " << op.getVector() << "[" << op.getPosition() 1183 << " : " << op.getPosition().getType() << "]"; 1184 p.printOptionalAttrDict(op->getAttrs()); 1185 p << " : " << op.getVector().getType(); 1186 } 1187 1188 // <operation> ::= `llvm.insertelement` ssa-use `,` ssa-use `,` ssa-use 1189 // attribute-dict? `:` type 1190 static ParseResult parseInsertElementOp(OpAsmParser &parser, 1191 OperationState &result) { 1192 llvm::SMLoc loc; 1193 OpAsmParser::OperandType vector, value, position; 1194 Type vectorType, positionType; 1195 if (parser.getCurrentLocation(&loc) || parser.parseOperand(value) || 1196 parser.parseComma() || parser.parseOperand(vector) || 1197 parser.parseLSquare() || parser.parseOperand(position) || 1198 parser.parseColonType(positionType) || parser.parseRSquare() || 1199 parser.parseOptionalAttrDict(result.attributes) || 1200 parser.parseColonType(vectorType)) 1201 return failure(); 1202 1203 if (!LLVM::isCompatibleVectorType(vectorType)) 1204 return parser.emitError( 1205 loc, "expected LLVM dialect-compatible vector type for operand #1"); 1206 Type valueType = LLVM::getVectorElementType(vectorType); 1207 if (!valueType) 1208 return failure(); 1209 1210 if (parser.resolveOperand(vector, vectorType, result.operands) || 1211 parser.resolveOperand(value, valueType, result.operands) || 1212 parser.resolveOperand(position, positionType, result.operands)) 1213 return failure(); 1214 1215 result.addTypes(vectorType); 1216 return success(); 1217 } 1218 1219 static LogicalResult verify(InsertElementOp op) { 1220 Type valueType = LLVM::getVectorElementType(op.getVector().getType()); 1221 if (valueType != op.getValue().getType()) 1222 return op.emitOpError() 1223 << "Type mismatch: cannot insert " << op.getValue().getType() 1224 << " into " << op.getVector().getType(); 1225 return success(); 1226 } 1227 //===----------------------------------------------------------------------===// 1228 // Printing/parsing for LLVM::InsertValueOp. 1229 //===----------------------------------------------------------------------===// 1230 1231 static void printInsertValueOp(OpAsmPrinter &p, InsertValueOp &op) { 1232 p << ' ' << op.getValue() << ", " << op.getContainer() << op.getPosition(); 1233 p.printOptionalAttrDict(op->getAttrs(), {"position"}); 1234 p << " : " << op.getContainer().getType(); 1235 } 1236 1237 // <operation> ::= `llvm.insertvaluevalue` ssa-use `,` ssa-use 1238 // `[` integer-literal (`,` integer-literal)* `]` 1239 // attribute-dict? `:` type 1240 static ParseResult parseInsertValueOp(OpAsmParser &parser, 1241 OperationState &result) { 1242 OpAsmParser::OperandType container, value; 1243 Type containerType; 1244 ArrayAttr positionAttr; 1245 llvm::SMLoc attributeLoc, trailingTypeLoc; 1246 1247 if (parser.parseOperand(value) || parser.parseComma() || 1248 parser.parseOperand(container) || 1249 parser.getCurrentLocation(&attributeLoc) || 1250 parser.parseAttribute(positionAttr, "position", result.attributes) || 1251 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 1252 parser.getCurrentLocation(&trailingTypeLoc) || 1253 parser.parseType(containerType)) 1254 return failure(); 1255 1256 auto valueType = getInsertExtractValueElementType( 1257 parser, containerType, positionAttr, attributeLoc, trailingTypeLoc); 1258 if (!valueType) 1259 return failure(); 1260 1261 if (parser.resolveOperand(container, containerType, result.operands) || 1262 parser.resolveOperand(value, valueType, result.operands)) 1263 return failure(); 1264 1265 result.addTypes(containerType); 1266 return success(); 1267 } 1268 1269 static LogicalResult verify(InsertValueOp op) { 1270 Type valueType = getInsertExtractValueElementType(op.getContainer().getType(), 1271 op.getPositionAttr(), op); 1272 if (!valueType) 1273 return failure(); 1274 1275 if (op.getValue().getType() != valueType) 1276 return op.emitOpError() 1277 << "Type mismatch: cannot insert " << op.getValue().getType() 1278 << " into " << op.getContainer().getType(); 1279 1280 return success(); 1281 } 1282 1283 //===----------------------------------------------------------------------===// 1284 // Printing, parsing and verification for LLVM::ReturnOp. 1285 //===----------------------------------------------------------------------===// 1286 1287 static void printReturnOp(OpAsmPrinter &p, ReturnOp op) { 1288 p.printOptionalAttrDict(op->getAttrs()); 1289 assert(op.getNumOperands() <= 1); 1290 1291 if (op.getNumOperands() == 0) 1292 return; 1293 1294 p << ' ' << op.getOperand(0) << " : " << op.getOperand(0).getType(); 1295 } 1296 1297 // <operation> ::= `llvm.return` ssa-use-list attribute-dict? `:` 1298 // type-list-no-parens 1299 static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &result) { 1300 SmallVector<OpAsmParser::OperandType, 1> operands; 1301 Type type; 1302 1303 if (parser.parseOperandList(operands) || 1304 parser.parseOptionalAttrDict(result.attributes)) 1305 return failure(); 1306 if (operands.empty()) 1307 return success(); 1308 1309 if (parser.parseColonType(type) || 1310 parser.resolveOperand(operands[0], type, result.operands)) 1311 return failure(); 1312 return success(); 1313 } 1314 1315 static LogicalResult verify(ReturnOp op) { 1316 if (op->getNumOperands() > 1) 1317 return op->emitOpError("expected at most 1 operand"); 1318 1319 if (auto parent = op->getParentOfType<LLVMFuncOp>()) { 1320 Type expectedType = parent.getType().getReturnType(); 1321 if (expectedType.isa<LLVMVoidType>()) { 1322 if (op->getNumOperands() == 0) 1323 return success(); 1324 InFlightDiagnostic diag = op->emitOpError("expected no operands"); 1325 diag.attachNote(parent->getLoc()) << "when returning from function"; 1326 return diag; 1327 } 1328 if (op->getNumOperands() == 0) { 1329 if (expectedType.isa<LLVMVoidType>()) 1330 return success(); 1331 InFlightDiagnostic diag = op->emitOpError("expected 1 operand"); 1332 diag.attachNote(parent->getLoc()) << "when returning from function"; 1333 return diag; 1334 } 1335 if (expectedType != op->getOperand(0).getType()) { 1336 InFlightDiagnostic diag = op->emitOpError("mismatching result types"); 1337 diag.attachNote(parent->getLoc()) << "when returning from function"; 1338 return diag; 1339 } 1340 } 1341 return success(); 1342 } 1343 1344 //===----------------------------------------------------------------------===// 1345 // Verifier for LLVM::AddressOfOp. 1346 //===----------------------------------------------------------------------===// 1347 1348 template <typename OpTy> 1349 static OpTy lookupSymbolInModule(Operation *parent, StringRef name) { 1350 Operation *module = parent; 1351 while (module && !satisfiesLLVMModule(module)) 1352 module = module->getParentOp(); 1353 assert(module && "unexpected operation outside of a module"); 1354 return dyn_cast_or_null<OpTy>( 1355 mlir::SymbolTable::lookupSymbolIn(module, name)); 1356 } 1357 1358 GlobalOp AddressOfOp::getGlobal() { 1359 return lookupSymbolInModule<LLVM::GlobalOp>((*this)->getParentOp(), 1360 getGlobalName()); 1361 } 1362 1363 LLVMFuncOp AddressOfOp::getFunction() { 1364 return lookupSymbolInModule<LLVM::LLVMFuncOp>((*this)->getParentOp(), 1365 getGlobalName()); 1366 } 1367 1368 static LogicalResult verify(AddressOfOp op) { 1369 auto global = op.getGlobal(); 1370 auto function = op.getFunction(); 1371 if (!global && !function) 1372 return op.emitOpError( 1373 "must reference a global defined by 'llvm.mlir.global' or 'llvm.func'"); 1374 1375 if (global && 1376 LLVM::LLVMPointerType::get(global.getType(), global.getAddrSpace()) != 1377 op.getResult().getType()) 1378 return op.emitOpError( 1379 "the type must be a pointer to the type of the referenced global"); 1380 1381 if (function && LLVM::LLVMPointerType::get(function.getType()) != 1382 op.getResult().getType()) 1383 return op.emitOpError( 1384 "the type must be a pointer to the type of the referenced function"); 1385 1386 return success(); 1387 } 1388 1389 //===----------------------------------------------------------------------===// 1390 // Builder, printer and verifier for LLVM::GlobalOp. 1391 //===----------------------------------------------------------------------===// 1392 1393 /// Returns the name used for the linkage attribute. This *must* correspond to 1394 /// the name of the attribute in ODS. 1395 static StringRef getLinkageAttrName() { return "linkage"; } 1396 1397 /// Returns the name used for the unnamed_addr attribute. This *must* correspond 1398 /// to the name of the attribute in ODS. 1399 static StringRef getUnnamedAddrAttrName() { return "unnamed_addr"; } 1400 1401 void GlobalOp::build(OpBuilder &builder, OperationState &result, Type type, 1402 bool isConstant, Linkage linkage, StringRef name, 1403 Attribute value, uint64_t alignment, unsigned addrSpace, 1404 bool dsoLocal, ArrayRef<NamedAttribute> attrs) { 1405 result.addAttribute(SymbolTable::getSymbolAttrName(), 1406 builder.getStringAttr(name)); 1407 result.addAttribute("global_type", TypeAttr::get(type)); 1408 if (isConstant) 1409 result.addAttribute("constant", builder.getUnitAttr()); 1410 if (value) 1411 result.addAttribute("value", value); 1412 if (dsoLocal) 1413 result.addAttribute("dso_local", builder.getUnitAttr()); 1414 1415 // Only add an alignment attribute if the "alignment" input 1416 // is different from 0. The value must also be a power of two, but 1417 // this is tested in GlobalOp::verify, not here. 1418 if (alignment != 0) 1419 result.addAttribute("alignment", builder.getI64IntegerAttr(alignment)); 1420 1421 result.addAttribute(::getLinkageAttrName(), 1422 LinkageAttr::get(builder.getContext(), linkage)); 1423 if (addrSpace != 0) 1424 result.addAttribute("addr_space", builder.getI32IntegerAttr(addrSpace)); 1425 result.attributes.append(attrs.begin(), attrs.end()); 1426 result.addRegion(); 1427 } 1428 1429 static void printGlobalOp(OpAsmPrinter &p, GlobalOp op) { 1430 p << ' ' << stringifyLinkage(op.getLinkage()) << ' '; 1431 if (auto unnamedAddr = op.getUnnamedAddr()) { 1432 StringRef str = stringifyUnnamedAddr(*unnamedAddr); 1433 if (!str.empty()) 1434 p << str << ' '; 1435 } 1436 if (op.getConstant()) 1437 p << "constant "; 1438 p.printSymbolName(op.getSymName()); 1439 p << '('; 1440 if (auto value = op.getValueOrNull()) 1441 p.printAttribute(value); 1442 p << ')'; 1443 // Note that the alignment attribute is printed using the 1444 // default syntax here, even though it is an inherent attribute 1445 // (as defined in https://mlir.llvm.org/docs/LangRef/#attributes) 1446 p.printOptionalAttrDict(op->getAttrs(), 1447 {SymbolTable::getSymbolAttrName(), "global_type", 1448 "constant", "value", getLinkageAttrName(), 1449 getUnnamedAddrAttrName()}); 1450 1451 // Print the trailing type unless it's a string global. 1452 if (op.getValueOrNull().dyn_cast_or_null<StringAttr>()) 1453 return; 1454 p << " : " << op.getType(); 1455 1456 Region &initializer = op.getInitializerRegion(); 1457 if (!initializer.empty()) 1458 p.printRegion(initializer, /*printEntryBlockArgs=*/false); 1459 } 1460 1461 // Parses one of the keywords provided in the list `keywords` and returns the 1462 // position of the parsed keyword in the list. If none of the keywords from the 1463 // list is parsed, returns -1. 1464 static int parseOptionalKeywordAlternative(OpAsmParser &parser, 1465 ArrayRef<StringRef> keywords) { 1466 for (auto en : llvm::enumerate(keywords)) { 1467 if (succeeded(parser.parseOptionalKeyword(en.value()))) 1468 return en.index(); 1469 } 1470 return -1; 1471 } 1472 1473 namespace { 1474 template <typename Ty> 1475 struct EnumTraits {}; 1476 1477 #define REGISTER_ENUM_TYPE(Ty) \ 1478 template <> \ 1479 struct EnumTraits<Ty> { \ 1480 static StringRef stringify(Ty value) { return stringify##Ty(value); } \ 1481 static unsigned getMaxEnumVal() { return getMaxEnumValFor##Ty(); } \ 1482 } 1483 1484 REGISTER_ENUM_TYPE(Linkage); 1485 REGISTER_ENUM_TYPE(UnnamedAddr); 1486 } // end namespace 1487 1488 /// Parse an enum from the keyword, or default to the provided default value. 1489 /// The return type is the enum type by default, unless overriden with the 1490 /// second template argument. 1491 template <typename EnumTy, typename RetTy = EnumTy> 1492 static RetTy parseOptionalLLVMKeyword(OpAsmParser &parser, 1493 OperationState &result, 1494 EnumTy defaultValue) { 1495 SmallVector<StringRef, 10> names; 1496 for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i) 1497 names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i))); 1498 1499 int index = parseOptionalKeywordAlternative(parser, names); 1500 if (index == -1) 1501 return static_cast<RetTy>(defaultValue); 1502 return static_cast<RetTy>(index); 1503 } 1504 1505 // operation ::= `llvm.mlir.global` linkage? `constant`? `@` identifier 1506 // `(` attribute? `)` align? attribute-list? (`:` type)? region? 1507 // align ::= `align` `=` UINT64 1508 // 1509 // The type can be omitted for string attributes, in which case it will be 1510 // inferred from the value of the string as [strlen(value) x i8]. 1511 static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) { 1512 MLIRContext *ctx = parser.getContext(); 1513 // Parse optional linkage, default to External. 1514 result.addAttribute(getLinkageAttrName(), 1515 LLVM::LinkageAttr::get( 1516 ctx, parseOptionalLLVMKeyword<Linkage>( 1517 parser, result, LLVM::Linkage::External))); 1518 // Parse optional UnnamedAddr, default to None. 1519 result.addAttribute(getUnnamedAddrAttrName(), 1520 parser.getBuilder().getI64IntegerAttr( 1521 parseOptionalLLVMKeyword<UnnamedAddr, int64_t>( 1522 parser, result, LLVM::UnnamedAddr::None))); 1523 1524 if (succeeded(parser.parseOptionalKeyword("constant"))) 1525 result.addAttribute("constant", parser.getBuilder().getUnitAttr()); 1526 1527 StringAttr name; 1528 if (parser.parseSymbolName(name, SymbolTable::getSymbolAttrName(), 1529 result.attributes) || 1530 parser.parseLParen()) 1531 return failure(); 1532 1533 Attribute value; 1534 if (parser.parseOptionalRParen()) { 1535 if (parser.parseAttribute(value, "value", result.attributes) || 1536 parser.parseRParen()) 1537 return failure(); 1538 } 1539 1540 SmallVector<Type, 1> types; 1541 if (parser.parseOptionalAttrDict(result.attributes) || 1542 parser.parseOptionalColonTypeList(types)) 1543 return failure(); 1544 1545 if (types.size() > 1) 1546 return parser.emitError(parser.getNameLoc(), "expected zero or one type"); 1547 1548 Region &initRegion = *result.addRegion(); 1549 if (types.empty()) { 1550 if (auto strAttr = value.dyn_cast_or_null<StringAttr>()) { 1551 MLIRContext *context = parser.getContext(); 1552 auto arrayType = LLVM::LLVMArrayType::get(IntegerType::get(context, 8), 1553 strAttr.getValue().size()); 1554 types.push_back(arrayType); 1555 } else { 1556 return parser.emitError(parser.getNameLoc(), 1557 "type can only be omitted for string globals"); 1558 } 1559 } else { 1560 OptionalParseResult parseResult = 1561 parser.parseOptionalRegion(initRegion, /*arguments=*/{}, 1562 /*argTypes=*/{}); 1563 if (parseResult.hasValue() && failed(*parseResult)) 1564 return failure(); 1565 } 1566 1567 result.addAttribute("global_type", TypeAttr::get(types[0])); 1568 return success(); 1569 } 1570 1571 static bool isZeroAttribute(Attribute value) { 1572 if (auto intValue = value.dyn_cast<IntegerAttr>()) 1573 return intValue.getValue().isNullValue(); 1574 if (auto fpValue = value.dyn_cast<FloatAttr>()) 1575 return fpValue.getValue().isZero(); 1576 if (auto splatValue = value.dyn_cast<SplatElementsAttr>()) 1577 return isZeroAttribute(splatValue.getSplatValue<Attribute>()); 1578 if (auto elementsValue = value.dyn_cast<ElementsAttr>()) 1579 return llvm::all_of(elementsValue.getValues<Attribute>(), isZeroAttribute); 1580 if (auto arrayValue = value.dyn_cast<ArrayAttr>()) 1581 return llvm::all_of(arrayValue.getValue(), isZeroAttribute); 1582 return false; 1583 } 1584 1585 static LogicalResult verify(GlobalOp op) { 1586 if (!LLVMPointerType::isValidElementType(op.getType())) 1587 return op.emitOpError( 1588 "expects type to be a valid element type for an LLVM pointer"); 1589 if (op->getParentOp() && !satisfiesLLVMModule(op->getParentOp())) 1590 return op.emitOpError("must appear at the module level"); 1591 1592 if (auto strAttr = op.getValueOrNull().dyn_cast_or_null<StringAttr>()) { 1593 auto type = op.getType().dyn_cast<LLVMArrayType>(); 1594 IntegerType elementType = 1595 type ? type.getElementType().dyn_cast<IntegerType>() : nullptr; 1596 if (!elementType || elementType.getWidth() != 8 || 1597 type.getNumElements() != strAttr.getValue().size()) 1598 return op.emitOpError( 1599 "requires an i8 array type of the length equal to that of the string " 1600 "attribute"); 1601 } 1602 1603 if (Block *b = op.getInitializerBlock()) { 1604 ReturnOp ret = cast<ReturnOp>(b->getTerminator()); 1605 if (ret.operand_type_begin() == ret.operand_type_end()) 1606 return op.emitOpError("initializer region cannot return void"); 1607 if (*ret.operand_type_begin() != op.getType()) 1608 return op.emitOpError("initializer region type ") 1609 << *ret.operand_type_begin() << " does not match global type " 1610 << op.getType(); 1611 1612 if (op.getValueOrNull()) 1613 return op.emitOpError("cannot have both initializer value and region"); 1614 } 1615 1616 if (op.getLinkage() == Linkage::Common) { 1617 if (Attribute value = op.getValueOrNull()) { 1618 if (!isZeroAttribute(value)) { 1619 return op.emitOpError() 1620 << "expected zero value for '" 1621 << stringifyLinkage(Linkage::Common) << "' linkage"; 1622 } 1623 } 1624 } 1625 1626 if (op.getLinkage() == Linkage::Appending) { 1627 if (!op.getType().isa<LLVMArrayType>()) { 1628 return op.emitOpError() 1629 << "expected array type for '" 1630 << stringifyLinkage(Linkage::Appending) << "' linkage"; 1631 } 1632 } 1633 1634 Optional<uint64_t> alignAttr = op.getAlignment(); 1635 if (alignAttr.hasValue()) { 1636 uint64_t value = alignAttr.getValue(); 1637 if (!llvm::isPowerOf2_64(value)) 1638 return op->emitError() << "alignment attribute is not a power of 2"; 1639 } 1640 1641 return success(); 1642 } 1643 1644 //===----------------------------------------------------------------------===// 1645 // LLVM::GlobalCtorsOp 1646 //===----------------------------------------------------------------------===// 1647 1648 LogicalResult 1649 GlobalCtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 1650 for (Attribute ctor : ctors()) { 1651 if (failed(verifySymbolAttrUse(ctor.cast<FlatSymbolRefAttr>(), *this, 1652 symbolTable))) 1653 return failure(); 1654 } 1655 return success(); 1656 } 1657 1658 static LogicalResult verify(GlobalCtorsOp op) { 1659 if (op.ctors().size() != op.priorities().size()) 1660 return op.emitError( 1661 "mismatch between the number of ctors and the number of priorities"); 1662 return success(); 1663 } 1664 1665 //===----------------------------------------------------------------------===// 1666 // LLVM::GlobalDtorsOp 1667 //===----------------------------------------------------------------------===// 1668 1669 LogicalResult 1670 GlobalDtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 1671 for (Attribute dtor : dtors()) { 1672 if (failed(verifySymbolAttrUse(dtor.cast<FlatSymbolRefAttr>(), *this, 1673 symbolTable))) 1674 return failure(); 1675 } 1676 return success(); 1677 } 1678 1679 static LogicalResult verify(GlobalDtorsOp op) { 1680 if (op.dtors().size() != op.priorities().size()) 1681 return op.emitError( 1682 "mismatch between the number of dtors and the number of priorities"); 1683 return success(); 1684 } 1685 1686 //===----------------------------------------------------------------------===// 1687 // Printing/parsing for LLVM::ShuffleVectorOp. 1688 //===----------------------------------------------------------------------===// 1689 // Expects vector to be of wrapped LLVM vector type and position to be of 1690 // wrapped LLVM i32 type. 1691 void LLVM::ShuffleVectorOp::build(OpBuilder &b, OperationState &result, 1692 Value v1, Value v2, ArrayAttr mask, 1693 ArrayRef<NamedAttribute> attrs) { 1694 auto containerType = v1.getType(); 1695 auto vType = LLVM::getFixedVectorType( 1696 LLVM::getVectorElementType(containerType), mask.size()); 1697 build(b, result, vType, v1, v2, mask); 1698 result.addAttributes(attrs); 1699 } 1700 1701 static void printShuffleVectorOp(OpAsmPrinter &p, ShuffleVectorOp &op) { 1702 p << ' ' << op.getV1() << ", " << op.getV2() << " " << op.getMask(); 1703 p.printOptionalAttrDict(op->getAttrs(), {"mask"}); 1704 p << " : " << op.getV1().getType() << ", " << op.getV2().getType(); 1705 } 1706 1707 // <operation> ::= `llvm.shufflevector` ssa-use `, ` ssa-use 1708 // `[` integer-literal (`,` integer-literal)* `]` 1709 // attribute-dict? `:` type 1710 static ParseResult parseShuffleVectorOp(OpAsmParser &parser, 1711 OperationState &result) { 1712 llvm::SMLoc loc; 1713 OpAsmParser::OperandType v1, v2; 1714 ArrayAttr maskAttr; 1715 Type typeV1, typeV2; 1716 if (parser.getCurrentLocation(&loc) || parser.parseOperand(v1) || 1717 parser.parseComma() || parser.parseOperand(v2) || 1718 parser.parseAttribute(maskAttr, "mask", result.attributes) || 1719 parser.parseOptionalAttrDict(result.attributes) || 1720 parser.parseColonType(typeV1) || parser.parseComma() || 1721 parser.parseType(typeV2) || 1722 parser.resolveOperand(v1, typeV1, result.operands) || 1723 parser.resolveOperand(v2, typeV2, result.operands)) 1724 return failure(); 1725 if (!LLVM::isCompatibleVectorType(typeV1)) 1726 return parser.emitError( 1727 loc, "expected LLVM IR dialect vector type for operand #1"); 1728 auto vType = LLVM::getFixedVectorType(LLVM::getVectorElementType(typeV1), 1729 maskAttr.size()); 1730 result.addTypes(vType); 1731 return success(); 1732 } 1733 1734 //===----------------------------------------------------------------------===// 1735 // Implementations for LLVM::LLVMFuncOp. 1736 //===----------------------------------------------------------------------===// 1737 1738 // Add the entry block to the function. 1739 Block *LLVMFuncOp::addEntryBlock() { 1740 assert(empty() && "function already has an entry block"); 1741 assert(!isVarArg() && "unimplemented: non-external variadic functions"); 1742 1743 auto *entry = new Block; 1744 push_back(entry); 1745 1746 LLVMFunctionType type = getType(); 1747 for (unsigned i = 0, e = type.getNumParams(); i < e; ++i) 1748 entry->addArgument(type.getParamType(i)); 1749 return entry; 1750 } 1751 1752 void LLVMFuncOp::build(OpBuilder &builder, OperationState &result, 1753 StringRef name, Type type, LLVM::Linkage linkage, 1754 bool dsoLocal, ArrayRef<NamedAttribute> attrs, 1755 ArrayRef<DictionaryAttr> argAttrs) { 1756 result.addRegion(); 1757 result.addAttribute(SymbolTable::getSymbolAttrName(), 1758 builder.getStringAttr(name)); 1759 result.addAttribute("type", TypeAttr::get(type)); 1760 result.addAttribute(::getLinkageAttrName(), 1761 LinkageAttr::get(builder.getContext(), linkage)); 1762 result.attributes.append(attrs.begin(), attrs.end()); 1763 if (dsoLocal) 1764 result.addAttribute("dso_local", builder.getUnitAttr()); 1765 if (argAttrs.empty()) 1766 return; 1767 1768 assert(type.cast<LLVMFunctionType>().getNumParams() == argAttrs.size() && 1769 "expected as many argument attribute lists as arguments"); 1770 function_like_impl::addArgAndResultAttrs(builder, result, argAttrs, 1771 /*resultAttrs=*/llvm::None); 1772 } 1773 1774 // Builds an LLVM function type from the given lists of input and output types. 1775 // Returns a null type if any of the types provided are non-LLVM types, or if 1776 // there is more than one output type. 1777 static Type 1778 buildLLVMFunctionType(OpAsmParser &parser, llvm::SMLoc loc, 1779 ArrayRef<Type> inputs, ArrayRef<Type> outputs, 1780 function_like_impl::VariadicFlag variadicFlag) { 1781 Builder &b = parser.getBuilder(); 1782 if (outputs.size() > 1) { 1783 parser.emitError(loc, "failed to construct function type: expected zero or " 1784 "one function result"); 1785 return {}; 1786 } 1787 1788 // Convert inputs to LLVM types, exit early on error. 1789 SmallVector<Type, 4> llvmInputs; 1790 for (auto t : inputs) { 1791 if (!isCompatibleType(t)) { 1792 parser.emitError(loc, "failed to construct function type: expected LLVM " 1793 "type for function arguments"); 1794 return {}; 1795 } 1796 llvmInputs.push_back(t); 1797 } 1798 1799 // No output is denoted as "void" in LLVM type system. 1800 Type llvmOutput = 1801 outputs.empty() ? LLVMVoidType::get(b.getContext()) : outputs.front(); 1802 if (!isCompatibleType(llvmOutput)) { 1803 parser.emitError(loc, "failed to construct function type: expected LLVM " 1804 "type for function results") 1805 << llvmOutput; 1806 return {}; 1807 } 1808 return LLVMFunctionType::get(llvmOutput, llvmInputs, 1809 variadicFlag.isVariadic()); 1810 } 1811 1812 // Parses an LLVM function. 1813 // 1814 // operation ::= `llvm.func` linkage? function-signature function-attributes? 1815 // function-body 1816 // 1817 static ParseResult parseLLVMFuncOp(OpAsmParser &parser, 1818 OperationState &result) { 1819 // Default to external linkage if no keyword is provided. 1820 result.addAttribute( 1821 getLinkageAttrName(), 1822 LinkageAttr::get(parser.getContext(), 1823 parseOptionalLLVMKeyword<Linkage>( 1824 parser, result, LLVM::Linkage::External))); 1825 1826 StringAttr nameAttr; 1827 SmallVector<OpAsmParser::OperandType, 8> entryArgs; 1828 SmallVector<NamedAttrList, 1> argAttrs; 1829 SmallVector<NamedAttrList, 1> resultAttrs; 1830 SmallVector<Type, 8> argTypes; 1831 SmallVector<Type, 4> resultTypes; 1832 bool isVariadic; 1833 1834 auto signatureLocation = parser.getCurrentLocation(); 1835 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), 1836 result.attributes) || 1837 function_like_impl::parseFunctionSignature( 1838 parser, /*allowVariadic=*/true, entryArgs, argTypes, argAttrs, 1839 isVariadic, resultTypes, resultAttrs)) 1840 return failure(); 1841 1842 auto type = 1843 buildLLVMFunctionType(parser, signatureLocation, argTypes, resultTypes, 1844 function_like_impl::VariadicFlag(isVariadic)); 1845 if (!type) 1846 return failure(); 1847 result.addAttribute(function_like_impl::getTypeAttrName(), 1848 TypeAttr::get(type)); 1849 1850 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) 1851 return failure(); 1852 function_like_impl::addArgAndResultAttrs(parser.getBuilder(), result, 1853 argAttrs, resultAttrs); 1854 1855 auto *body = result.addRegion(); 1856 OptionalParseResult parseResult = parser.parseOptionalRegion( 1857 *body, entryArgs, entryArgs.empty() ? ArrayRef<Type>() : argTypes); 1858 return failure(parseResult.hasValue() && failed(*parseResult)); 1859 } 1860 1861 // Print the LLVMFuncOp. Collects argument and result types and passes them to 1862 // helper functions. Drops "void" result since it cannot be parsed back. Skips 1863 // the external linkage since it is the default value. 1864 static void printLLVMFuncOp(OpAsmPrinter &p, LLVMFuncOp op) { 1865 p << ' '; 1866 if (op.getLinkage() != LLVM::Linkage::External) 1867 p << stringifyLinkage(op.getLinkage()) << ' '; 1868 p.printSymbolName(op.getName()); 1869 1870 LLVMFunctionType fnType = op.getType(); 1871 SmallVector<Type, 8> argTypes; 1872 SmallVector<Type, 1> resTypes; 1873 argTypes.reserve(fnType.getNumParams()); 1874 for (unsigned i = 0, e = fnType.getNumParams(); i < e; ++i) 1875 argTypes.push_back(fnType.getParamType(i)); 1876 1877 Type returnType = fnType.getReturnType(); 1878 if (!returnType.isa<LLVMVoidType>()) 1879 resTypes.push_back(returnType); 1880 1881 function_like_impl::printFunctionSignature(p, op, argTypes, op.isVarArg(), 1882 resTypes); 1883 function_like_impl::printFunctionAttributes( 1884 p, op, argTypes.size(), resTypes.size(), {getLinkageAttrName()}); 1885 1886 // Print the body if this is not an external function. 1887 Region &body = op.getBody(); 1888 if (!body.empty()) 1889 p.printRegion(body, /*printEntryBlockArgs=*/false, 1890 /*printBlockTerminators=*/true); 1891 } 1892 1893 // Hook for OpTrait::FunctionLike, called after verifying that the 'type' 1894 // attribute is present. This can check for preconditions of the 1895 // getNumArguments hook not failing. 1896 LogicalResult LLVMFuncOp::verifyType() { 1897 auto llvmType = getTypeAttr().getValue().dyn_cast_or_null<LLVMFunctionType>(); 1898 if (!llvmType) 1899 return emitOpError("requires '" + getTypeAttrName() + 1900 "' attribute of wrapped LLVM function type"); 1901 1902 return success(); 1903 } 1904 1905 // Hook for OpTrait::FunctionLike, returns the number of function arguments. 1906 // Depends on the type attribute being correct as checked by verifyType 1907 unsigned LLVMFuncOp::getNumFuncArguments() { return getType().getNumParams(); } 1908 1909 // Hook for OpTrait::FunctionLike, returns the number of function results. 1910 // Depends on the type attribute being correct as checked by verifyType 1911 unsigned LLVMFuncOp::getNumFuncResults() { 1912 // We model LLVM functions that return void as having zero results, 1913 // and all others as having one result. 1914 // If we modeled a void return as one result, then it would be possible to 1915 // attach an MLIR result attribute to it, and it isn't clear what semantics we 1916 // would assign to that. 1917 if (getType().getReturnType().isa<LLVMVoidType>()) 1918 return 0; 1919 return 1; 1920 } 1921 1922 // Verifies LLVM- and implementation-specific properties of the LLVM func Op: 1923 // - functions don't have 'common' linkage 1924 // - external functions have 'external' or 'extern_weak' linkage; 1925 // - vararg is (currently) only supported for external functions; 1926 // - entry block arguments are of LLVM types and match the function signature. 1927 static LogicalResult verify(LLVMFuncOp op) { 1928 if (op.getLinkage() == LLVM::Linkage::Common) 1929 return op.emitOpError() 1930 << "functions cannot have '" 1931 << stringifyLinkage(LLVM::Linkage::Common) << "' linkage"; 1932 1933 if (op.isExternal()) { 1934 if (op.getLinkage() != LLVM::Linkage::External && 1935 op.getLinkage() != LLVM::Linkage::ExternWeak) 1936 return op.emitOpError() 1937 << "external functions must have '" 1938 << stringifyLinkage(LLVM::Linkage::External) << "' or '" 1939 << stringifyLinkage(LLVM::Linkage::ExternWeak) << "' linkage"; 1940 return success(); 1941 } 1942 1943 if (op.isVarArg()) 1944 return op.emitOpError("only external functions can be variadic"); 1945 1946 unsigned numArguments = op.getType().getNumParams(); 1947 Block &entryBlock = op.front(); 1948 for (unsigned i = 0; i < numArguments; ++i) { 1949 Type argType = entryBlock.getArgument(i).getType(); 1950 if (!isCompatibleType(argType)) 1951 return op.emitOpError("entry block argument #") 1952 << i << " is not of LLVM type"; 1953 if (op.getType().getParamType(i) != argType) 1954 return op.emitOpError("the type of entry block argument #") 1955 << i << " does not match the function signature"; 1956 } 1957 1958 return success(); 1959 } 1960 1961 //===----------------------------------------------------------------------===// 1962 // Verification for LLVM::ConstantOp. 1963 //===----------------------------------------------------------------------===// 1964 1965 static LogicalResult verify(LLVM::ConstantOp op) { 1966 if (StringAttr sAttr = op.getValue().dyn_cast<StringAttr>()) { 1967 auto arrayType = op.getType().dyn_cast<LLVMArrayType>(); 1968 if (!arrayType || arrayType.getNumElements() != sAttr.getValue().size() || 1969 !arrayType.getElementType().isInteger(8)) { 1970 return op->emitOpError() 1971 << "expected array type of " << sAttr.getValue().size() 1972 << " i8 elements for the string constant"; 1973 } 1974 return success(); 1975 } 1976 if (auto structType = op.getType().dyn_cast<LLVMStructType>()) { 1977 if (structType.getBody().size() != 2 || 1978 structType.getBody()[0] != structType.getBody()[1]) { 1979 return op.emitError() << "expected struct type with two elements of the " 1980 "same type, the type of a complex constant"; 1981 } 1982 1983 auto arrayAttr = op.getValue().dyn_cast<ArrayAttr>(); 1984 if (!arrayAttr || arrayAttr.size() != 2 || 1985 arrayAttr[0].getType() != arrayAttr[1].getType()) { 1986 return op.emitOpError() << "expected array attribute with two elements, " 1987 "representing a complex constant"; 1988 } 1989 1990 Type elementType = structType.getBody()[0]; 1991 if (!elementType 1992 .isa<IntegerType, Float16Type, Float32Type, Float64Type>()) { 1993 return op.emitError() 1994 << "expected struct element types to be floating point type or " 1995 "integer type"; 1996 } 1997 return success(); 1998 } 1999 if (!op.getValue().isa<IntegerAttr, ArrayAttr, FloatAttr, ElementsAttr>()) 2000 return op.emitOpError() 2001 << "only supports integer, float, string or elements attributes"; 2002 return success(); 2003 } 2004 2005 //===----------------------------------------------------------------------===// 2006 // Utility functions for parsing atomic ops 2007 //===----------------------------------------------------------------------===// 2008 2009 // Helper function to parse a keyword into the specified attribute named by 2010 // `attrName`. The keyword must match one of the string values defined by the 2011 // AtomicBinOp enum. The resulting I64 attribute is added to the `result` 2012 // state. 2013 static ParseResult parseAtomicBinOp(OpAsmParser &parser, OperationState &result, 2014 StringRef attrName) { 2015 llvm::SMLoc loc; 2016 StringRef keyword; 2017 if (parser.getCurrentLocation(&loc) || parser.parseKeyword(&keyword)) 2018 return failure(); 2019 2020 // Replace the keyword `keyword` with an integer attribute. 2021 auto kind = symbolizeAtomicBinOp(keyword); 2022 if (!kind) { 2023 return parser.emitError(loc) 2024 << "'" << keyword << "' is an incorrect value of the '" << attrName 2025 << "' attribute"; 2026 } 2027 2028 auto value = static_cast<int64_t>(kind.getValue()); 2029 auto attr = parser.getBuilder().getI64IntegerAttr(value); 2030 result.addAttribute(attrName, attr); 2031 2032 return success(); 2033 } 2034 2035 // Helper function to parse a keyword into the specified attribute named by 2036 // `attrName`. The keyword must match one of the string values defined by the 2037 // AtomicOrdering enum. The resulting I64 attribute is added to the `result` 2038 // state. 2039 static ParseResult parseAtomicOrdering(OpAsmParser &parser, 2040 OperationState &result, 2041 StringRef attrName) { 2042 llvm::SMLoc loc; 2043 StringRef ordering; 2044 if (parser.getCurrentLocation(&loc) || parser.parseKeyword(&ordering)) 2045 return failure(); 2046 2047 // Replace the keyword `ordering` with an integer attribute. 2048 auto kind = symbolizeAtomicOrdering(ordering); 2049 if (!kind) { 2050 return parser.emitError(loc) 2051 << "'" << ordering << "' is an incorrect value of the '" << attrName 2052 << "' attribute"; 2053 } 2054 2055 auto value = static_cast<int64_t>(kind.getValue()); 2056 auto attr = parser.getBuilder().getI64IntegerAttr(value); 2057 result.addAttribute(attrName, attr); 2058 2059 return success(); 2060 } 2061 2062 //===----------------------------------------------------------------------===// 2063 // Printer, parser and verifier for LLVM::AtomicRMWOp. 2064 //===----------------------------------------------------------------------===// 2065 2066 static void printAtomicRMWOp(OpAsmPrinter &p, AtomicRMWOp &op) { 2067 p << ' ' << stringifyAtomicBinOp(op.getBinOp()) << ' ' << op.getPtr() << ", " 2068 << op.getVal() << ' ' << stringifyAtomicOrdering(op.getOrdering()) << ' '; 2069 p.printOptionalAttrDict(op->getAttrs(), {"bin_op", "ordering"}); 2070 p << " : " << op.getRes().getType(); 2071 } 2072 2073 // <operation> ::= `llvm.atomicrmw` keyword ssa-use `,` ssa-use keyword 2074 // attribute-dict? `:` type 2075 static ParseResult parseAtomicRMWOp(OpAsmParser &parser, 2076 OperationState &result) { 2077 Type type; 2078 OpAsmParser::OperandType ptr, val; 2079 if (parseAtomicBinOp(parser, result, "bin_op") || parser.parseOperand(ptr) || 2080 parser.parseComma() || parser.parseOperand(val) || 2081 parseAtomicOrdering(parser, result, "ordering") || 2082 parser.parseOptionalAttrDict(result.attributes) || 2083 parser.parseColonType(type) || 2084 parser.resolveOperand(ptr, LLVM::LLVMPointerType::get(type), 2085 result.operands) || 2086 parser.resolveOperand(val, type, result.operands)) 2087 return failure(); 2088 2089 result.addTypes(type); 2090 return success(); 2091 } 2092 2093 static LogicalResult verify(AtomicRMWOp op) { 2094 auto ptrType = op.getPtr().getType().cast<LLVM::LLVMPointerType>(); 2095 auto valType = op.getVal().getType(); 2096 if (valType != ptrType.getElementType()) 2097 return op.emitOpError("expected LLVM IR element type for operand #0 to " 2098 "match type for operand #1"); 2099 auto resType = op.getRes().getType(); 2100 if (resType != valType) 2101 return op.emitOpError( 2102 "expected LLVM IR result type to match type for operand #1"); 2103 if (op.getBinOp() == AtomicBinOp::fadd || 2104 op.getBinOp() == AtomicBinOp::fsub) { 2105 if (!mlir::LLVM::isCompatibleFloatingPointType(valType)) 2106 return op.emitOpError("expected LLVM IR floating point type"); 2107 } else if (op.getBinOp() == AtomicBinOp::xchg) { 2108 auto intType = valType.dyn_cast<IntegerType>(); 2109 unsigned intBitWidth = intType ? intType.getWidth() : 0; 2110 if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 && 2111 intBitWidth != 64 && !valType.isa<BFloat16Type>() && 2112 !valType.isa<Float16Type>() && !valType.isa<Float32Type>() && 2113 !valType.isa<Float64Type>()) 2114 return op.emitOpError("unexpected LLVM IR type for 'xchg' bin_op"); 2115 } else { 2116 auto intType = valType.dyn_cast<IntegerType>(); 2117 unsigned intBitWidth = intType ? intType.getWidth() : 0; 2118 if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 && 2119 intBitWidth != 64) 2120 return op.emitOpError("expected LLVM IR integer type"); 2121 } 2122 2123 if (static_cast<unsigned>(op.getOrdering()) < 2124 static_cast<unsigned>(AtomicOrdering::monotonic)) 2125 return op.emitOpError() 2126 << "expected at least '" 2127 << stringifyAtomicOrdering(AtomicOrdering::monotonic) 2128 << "' ordering"; 2129 2130 return success(); 2131 } 2132 2133 //===----------------------------------------------------------------------===// 2134 // Printer, parser and verifier for LLVM::AtomicCmpXchgOp. 2135 //===----------------------------------------------------------------------===// 2136 2137 static void printAtomicCmpXchgOp(OpAsmPrinter &p, AtomicCmpXchgOp &op) { 2138 p << ' ' << op.getPtr() << ", " << op.getCmp() << ", " << op.getVal() << ' ' 2139 << stringifyAtomicOrdering(op.getSuccessOrdering()) << ' ' 2140 << stringifyAtomicOrdering(op.getFailureOrdering()); 2141 p.printOptionalAttrDict(op->getAttrs(), 2142 {"success_ordering", "failure_ordering"}); 2143 p << " : " << op.getVal().getType(); 2144 } 2145 2146 // <operation> ::= `llvm.cmpxchg` ssa-use `,` ssa-use `,` ssa-use 2147 // keyword keyword attribute-dict? `:` type 2148 static ParseResult parseAtomicCmpXchgOp(OpAsmParser &parser, 2149 OperationState &result) { 2150 auto &builder = parser.getBuilder(); 2151 Type type; 2152 OpAsmParser::OperandType ptr, cmp, val; 2153 if (parser.parseOperand(ptr) || parser.parseComma() || 2154 parser.parseOperand(cmp) || parser.parseComma() || 2155 parser.parseOperand(val) || 2156 parseAtomicOrdering(parser, result, "success_ordering") || 2157 parseAtomicOrdering(parser, result, "failure_ordering") || 2158 parser.parseOptionalAttrDict(result.attributes) || 2159 parser.parseColonType(type) || 2160 parser.resolveOperand(ptr, LLVM::LLVMPointerType::get(type), 2161 result.operands) || 2162 parser.resolveOperand(cmp, type, result.operands) || 2163 parser.resolveOperand(val, type, result.operands)) 2164 return failure(); 2165 2166 auto boolType = IntegerType::get(builder.getContext(), 1); 2167 auto resultType = 2168 LLVMStructType::getLiteral(builder.getContext(), {type, boolType}); 2169 result.addTypes(resultType); 2170 2171 return success(); 2172 } 2173 2174 static LogicalResult verify(AtomicCmpXchgOp op) { 2175 auto ptrType = op.getPtr().getType().cast<LLVM::LLVMPointerType>(); 2176 if (!ptrType) 2177 return op.emitOpError("expected LLVM IR pointer type for operand #0"); 2178 auto cmpType = op.getCmp().getType(); 2179 auto valType = op.getVal().getType(); 2180 if (cmpType != ptrType.getElementType() || cmpType != valType) 2181 return op.emitOpError("expected LLVM IR element type for operand #0 to " 2182 "match type for all other operands"); 2183 auto intType = valType.dyn_cast<IntegerType>(); 2184 unsigned intBitWidth = intType ? intType.getWidth() : 0; 2185 if (!valType.isa<LLVMPointerType>() && intBitWidth != 8 && 2186 intBitWidth != 16 && intBitWidth != 32 && intBitWidth != 64 && 2187 !valType.isa<BFloat16Type>() && !valType.isa<Float16Type>() && 2188 !valType.isa<Float32Type>() && !valType.isa<Float64Type>()) 2189 return op.emitOpError("unexpected LLVM IR type"); 2190 if (op.getSuccessOrdering() < AtomicOrdering::monotonic || 2191 op.getFailureOrdering() < AtomicOrdering::monotonic) 2192 return op.emitOpError("ordering must be at least 'monotonic'"); 2193 if (op.getFailureOrdering() == AtomicOrdering::release || 2194 op.getFailureOrdering() == AtomicOrdering::acq_rel) 2195 return op.emitOpError("failure ordering cannot be 'release' or 'acq_rel'"); 2196 return success(); 2197 } 2198 2199 //===----------------------------------------------------------------------===// 2200 // Printer, parser and verifier for LLVM::FenceOp. 2201 //===----------------------------------------------------------------------===// 2202 2203 // <operation> ::= `llvm.fence` (`syncscope(`strAttr`)`)? keyword 2204 // attribute-dict? 2205 static ParseResult parseFenceOp(OpAsmParser &parser, OperationState &result) { 2206 StringAttr sScope; 2207 StringRef syncscopeKeyword = "syncscope"; 2208 if (!failed(parser.parseOptionalKeyword(syncscopeKeyword))) { 2209 if (parser.parseLParen() || 2210 parser.parseAttribute(sScope, syncscopeKeyword, result.attributes) || 2211 parser.parseRParen()) 2212 return failure(); 2213 } else { 2214 result.addAttribute(syncscopeKeyword, 2215 parser.getBuilder().getStringAttr("")); 2216 } 2217 if (parseAtomicOrdering(parser, result, "ordering") || 2218 parser.parseOptionalAttrDict(result.attributes)) 2219 return failure(); 2220 return success(); 2221 } 2222 2223 static void printFenceOp(OpAsmPrinter &p, FenceOp &op) { 2224 StringRef syncscopeKeyword = "syncscope"; 2225 p << ' '; 2226 if (!op->getAttr(syncscopeKeyword).cast<StringAttr>().getValue().empty()) 2227 p << "syncscope(" << op->getAttr(syncscopeKeyword) << ") "; 2228 p << stringifyAtomicOrdering(op.getOrdering()); 2229 } 2230 2231 static LogicalResult verify(FenceOp &op) { 2232 if (op.getOrdering() == AtomicOrdering::not_atomic || 2233 op.getOrdering() == AtomicOrdering::unordered || 2234 op.getOrdering() == AtomicOrdering::monotonic) 2235 return op.emitOpError("can be given only acquire, release, acq_rel, " 2236 "and seq_cst orderings"); 2237 return success(); 2238 } 2239 2240 //===----------------------------------------------------------------------===// 2241 // LLVMDialect initialization, type parsing, and registration. 2242 //===----------------------------------------------------------------------===// 2243 2244 void LLVMDialect::initialize() { 2245 addAttributes<FMFAttr, LinkageAttr, LoopOptionsAttr>(); 2246 2247 // clang-format off 2248 addTypes<LLVMVoidType, 2249 LLVMPPCFP128Type, 2250 LLVMX86MMXType, 2251 LLVMTokenType, 2252 LLVMLabelType, 2253 LLVMMetadataType, 2254 LLVMFunctionType, 2255 LLVMPointerType, 2256 LLVMFixedVectorType, 2257 LLVMScalableVectorType, 2258 LLVMArrayType, 2259 LLVMStructType>(); 2260 // clang-format on 2261 addOperations< 2262 #define GET_OP_LIST 2263 #include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc" 2264 >(); 2265 2266 // Support unknown operations because not all LLVM operations are registered. 2267 allowUnknownOperations(); 2268 } 2269 2270 #define GET_OP_CLASSES 2271 #include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc" 2272 2273 /// Parse a type registered to this dialect. 2274 Type LLVMDialect::parseType(DialectAsmParser &parser) const { 2275 return detail::parseType(parser); 2276 } 2277 2278 /// Print a type registered to this dialect. 2279 void LLVMDialect::printType(Type type, DialectAsmPrinter &os) const { 2280 return detail::printType(type, os); 2281 } 2282 2283 LogicalResult LLVMDialect::verifyDataLayoutString( 2284 StringRef descr, llvm::function_ref<void(const Twine &)> reportError) { 2285 llvm::Expected<llvm::DataLayout> maybeDataLayout = 2286 llvm::DataLayout::parse(descr); 2287 if (maybeDataLayout) 2288 return success(); 2289 2290 std::string message; 2291 llvm::raw_string_ostream messageStream(message); 2292 llvm::logAllUnhandledErrors(maybeDataLayout.takeError(), messageStream); 2293 reportError("invalid data layout descriptor: " + messageStream.str()); 2294 return failure(); 2295 } 2296 2297 /// Verify LLVM dialect attributes. 2298 LogicalResult LLVMDialect::verifyOperationAttribute(Operation *op, 2299 NamedAttribute attr) { 2300 // If the `llvm.loop` attribute is present, enforce the following structure, 2301 // which the module translation can assume. 2302 if (attr.first.strref() == LLVMDialect::getLoopAttrName()) { 2303 auto loopAttr = attr.second.dyn_cast<DictionaryAttr>(); 2304 if (!loopAttr) 2305 return op->emitOpError() << "expected '" << LLVMDialect::getLoopAttrName() 2306 << "' to be a dictionary attribute"; 2307 Optional<NamedAttribute> parallelAccessGroup = 2308 loopAttr.getNamed(LLVMDialect::getParallelAccessAttrName()); 2309 if (parallelAccessGroup.hasValue()) { 2310 auto accessGroups = parallelAccessGroup->second.dyn_cast<ArrayAttr>(); 2311 if (!accessGroups) 2312 return op->emitOpError() 2313 << "expected '" << LLVMDialect::getParallelAccessAttrName() 2314 << "' to be an array attribute"; 2315 for (Attribute attr : accessGroups) { 2316 auto accessGroupRef = attr.dyn_cast<SymbolRefAttr>(); 2317 if (!accessGroupRef) 2318 return op->emitOpError() 2319 << "expected '" << attr << "' to be a symbol reference"; 2320 StringAttr metadataName = accessGroupRef.getRootReference(); 2321 auto metadataOp = 2322 SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>( 2323 op->getParentOp(), metadataName); 2324 if (!metadataOp) 2325 return op->emitOpError() 2326 << "expected '" << attr << "' to reference a metadata op"; 2327 StringAttr accessGroupName = accessGroupRef.getLeafReference(); 2328 Operation *accessGroupOp = 2329 SymbolTable::lookupNearestSymbolFrom(metadataOp, accessGroupName); 2330 if (!accessGroupOp) 2331 return op->emitOpError() 2332 << "expected '" << attr << "' to reference an access_group op"; 2333 } 2334 } 2335 2336 Optional<NamedAttribute> loopOptions = 2337 loopAttr.getNamed(LLVMDialect::getLoopOptionsAttrName()); 2338 if (loopOptions.hasValue() && !loopOptions->second.isa<LoopOptionsAttr>()) 2339 return op->emitOpError() 2340 << "expected '" << LLVMDialect::getLoopOptionsAttrName() 2341 << "' to be a `loopopts` attribute"; 2342 } 2343 2344 // If the data layout attribute is present, it must use the LLVM data layout 2345 // syntax. Try parsing it and report errors in case of failure. Users of this 2346 // attribute may assume it is well-formed and can pass it to the (asserting) 2347 // llvm::DataLayout constructor. 2348 if (attr.first.strref() != LLVM::LLVMDialect::getDataLayoutAttrName()) 2349 return success(); 2350 if (auto stringAttr = attr.second.dyn_cast<StringAttr>()) 2351 return verifyDataLayoutString( 2352 stringAttr.getValue(), 2353 [op](const Twine &message) { op->emitOpError() << message.str(); }); 2354 2355 return op->emitOpError() << "expected '" 2356 << LLVM::LLVMDialect::getDataLayoutAttrName() 2357 << "' to be a string attribute"; 2358 } 2359 2360 /// Verify LLVMIR function argument attributes. 2361 LogicalResult LLVMDialect::verifyRegionArgAttribute(Operation *op, 2362 unsigned regionIdx, 2363 unsigned argIdx, 2364 NamedAttribute argAttr) { 2365 // Check that llvm.noalias is a unit attribute. 2366 if (argAttr.first == LLVMDialect::getNoAliasAttrName() && 2367 !argAttr.second.isa<UnitAttr>()) 2368 return op->emitError() 2369 << "expected llvm.noalias argument attribute to be a unit attribute"; 2370 // Check that llvm.align is an integer attribute. 2371 if (argAttr.first == LLVMDialect::getAlignAttrName() && 2372 !argAttr.second.isa<IntegerAttr>()) 2373 return op->emitError() 2374 << "llvm.align argument attribute of non integer type"; 2375 return success(); 2376 } 2377 2378 //===----------------------------------------------------------------------===// 2379 // Utility functions. 2380 //===----------------------------------------------------------------------===// 2381 2382 Value mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder, 2383 StringRef name, StringRef value, 2384 LLVM::Linkage linkage) { 2385 assert(builder.getInsertionBlock() && 2386 builder.getInsertionBlock()->getParentOp() && 2387 "expected builder to point to a block constrained in an op"); 2388 auto module = 2389 builder.getInsertionBlock()->getParentOp()->getParentOfType<ModuleOp>(); 2390 assert(module && "builder points to an op outside of a module"); 2391 2392 // Create the global at the entry of the module. 2393 OpBuilder moduleBuilder(module.getBodyRegion(), builder.getListener()); 2394 MLIRContext *ctx = builder.getContext(); 2395 auto type = LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), value.size()); 2396 auto global = moduleBuilder.create<LLVM::GlobalOp>( 2397 loc, type, /*isConstant=*/true, linkage, name, 2398 builder.getStringAttr(value), /*alignment=*/0); 2399 2400 // Get the pointer to the first character in the global string. 2401 Value globalPtr = builder.create<LLVM::AddressOfOp>(loc, global); 2402 Value cst0 = builder.create<LLVM::ConstantOp>( 2403 loc, IntegerType::get(ctx, 64), 2404 builder.getIntegerAttr(builder.getIndexType(), 0)); 2405 return builder.create<LLVM::GEPOp>( 2406 loc, LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)), globalPtr, 2407 ValueRange{cst0, cst0}); 2408 } 2409 2410 bool mlir::LLVM::satisfiesLLVMModule(Operation *op) { 2411 return op->hasTrait<OpTrait::SymbolTable>() && 2412 op->hasTrait<OpTrait::IsIsolatedFromAbove>(); 2413 } 2414 2415 static constexpr const FastmathFlags fastmathFlagsList[] = { 2416 // clang-format off 2417 FastmathFlags::nnan, 2418 FastmathFlags::ninf, 2419 FastmathFlags::nsz, 2420 FastmathFlags::arcp, 2421 FastmathFlags::contract, 2422 FastmathFlags::afn, 2423 FastmathFlags::reassoc, 2424 FastmathFlags::fast, 2425 // clang-format on 2426 }; 2427 2428 void FMFAttr::print(AsmPrinter &printer) const { 2429 printer << "<"; 2430 auto flags = llvm::make_filter_range(fastmathFlagsList, [&](auto flag) { 2431 return bitEnumContains(this->getFlags(), flag); 2432 }); 2433 llvm::interleaveComma(flags, printer, 2434 [&](auto flag) { printer << stringifyEnum(flag); }); 2435 printer << ">"; 2436 } 2437 2438 Attribute FMFAttr::parse(AsmParser &parser, Type type) { 2439 if (failed(parser.parseLess())) 2440 return {}; 2441 2442 FastmathFlags flags = {}; 2443 if (failed(parser.parseOptionalGreater())) { 2444 do { 2445 StringRef elemName; 2446 if (failed(parser.parseKeyword(&elemName))) 2447 return {}; 2448 2449 auto elem = symbolizeFastmathFlags(elemName); 2450 if (!elem) { 2451 parser.emitError(parser.getNameLoc(), "Unknown fastmath flag: ") 2452 << elemName; 2453 return {}; 2454 } 2455 2456 flags = flags | *elem; 2457 } while (succeeded(parser.parseOptionalComma())); 2458 2459 if (failed(parser.parseGreater())) 2460 return {}; 2461 } 2462 2463 return FMFAttr::get(parser.getContext(), flags); 2464 } 2465 2466 void LinkageAttr::print(AsmPrinter &printer) const { 2467 printer << "<"; 2468 if (static_cast<uint64_t>(getLinkage()) <= getMaxEnumValForLinkage()) 2469 printer << stringifyEnum(getLinkage()); 2470 else 2471 printer << static_cast<uint64_t>(getLinkage()); 2472 printer << ">"; 2473 } 2474 2475 Attribute LinkageAttr::parse(AsmParser &parser, Type type) { 2476 StringRef elemName; 2477 if (parser.parseLess() || parser.parseKeyword(&elemName) || 2478 parser.parseGreater()) 2479 return {}; 2480 auto elem = linkage::symbolizeLinkage(elemName); 2481 if (!elem) { 2482 parser.emitError(parser.getNameLoc(), "Unknown linkage: ") << elemName; 2483 return {}; 2484 } 2485 Linkage linkage = *elem; 2486 return LinkageAttr::get(parser.getContext(), linkage); 2487 } 2488 2489 LoopOptionsAttrBuilder::LoopOptionsAttrBuilder(LoopOptionsAttr attr) 2490 : options(attr.getOptions().begin(), attr.getOptions().end()) {} 2491 2492 template <typename T> 2493 LoopOptionsAttrBuilder &LoopOptionsAttrBuilder::setOption(LoopOptionCase tag, 2494 Optional<T> value) { 2495 auto option = llvm::find_if( 2496 options, [tag](auto option) { return option.first == tag; }); 2497 if (option != options.end()) { 2498 if (value.hasValue()) 2499 option->second = *value; 2500 else 2501 options.erase(option); 2502 } else { 2503 options.push_back(LoopOptionsAttr::OptionValuePair(tag, *value)); 2504 } 2505 return *this; 2506 } 2507 2508 LoopOptionsAttrBuilder & 2509 LoopOptionsAttrBuilder::setDisableLICM(Optional<bool> value) { 2510 return setOption(LoopOptionCase::disable_licm, value); 2511 } 2512 2513 /// Set the `interleave_count` option to the provided value. If no value 2514 /// is provided the option is deleted. 2515 LoopOptionsAttrBuilder & 2516 LoopOptionsAttrBuilder::setInterleaveCount(Optional<uint64_t> count) { 2517 return setOption(LoopOptionCase::interleave_count, count); 2518 } 2519 2520 /// Set the `disable_unroll` option to the provided value. If no value 2521 /// is provided the option is deleted. 2522 LoopOptionsAttrBuilder & 2523 LoopOptionsAttrBuilder::setDisableUnroll(Optional<bool> value) { 2524 return setOption(LoopOptionCase::disable_unroll, value); 2525 } 2526 2527 /// Set the `disable_pipeline` option to the provided value. If no value 2528 /// is provided the option is deleted. 2529 LoopOptionsAttrBuilder & 2530 LoopOptionsAttrBuilder::setDisablePipeline(Optional<bool> value) { 2531 return setOption(LoopOptionCase::disable_pipeline, value); 2532 } 2533 2534 /// Set the `pipeline_initiation_interval` option to the provided value. 2535 /// If no value is provided the option is deleted. 2536 LoopOptionsAttrBuilder &LoopOptionsAttrBuilder::setPipelineInitiationInterval( 2537 Optional<uint64_t> count) { 2538 return setOption(LoopOptionCase::pipeline_initiation_interval, count); 2539 } 2540 2541 template <typename T> 2542 static Optional<T> 2543 getOption(ArrayRef<std::pair<LoopOptionCase, int64_t>> options, 2544 LoopOptionCase option) { 2545 auto it = 2546 lower_bound(options, option, [](auto optionPair, LoopOptionCase option) { 2547 return optionPair.first < option; 2548 }); 2549 if (it == options.end()) 2550 return {}; 2551 return static_cast<T>(it->second); 2552 } 2553 2554 Optional<bool> LoopOptionsAttr::disableUnroll() { 2555 return getOption<bool>(getOptions(), LoopOptionCase::disable_unroll); 2556 } 2557 2558 Optional<bool> LoopOptionsAttr::disableLICM() { 2559 return getOption<bool>(getOptions(), LoopOptionCase::disable_licm); 2560 } 2561 2562 Optional<int64_t> LoopOptionsAttr::interleaveCount() { 2563 return getOption<int64_t>(getOptions(), LoopOptionCase::interleave_count); 2564 } 2565 2566 /// Build the LoopOptions Attribute from a sorted array of individual options. 2567 LoopOptionsAttr LoopOptionsAttr::get( 2568 MLIRContext *context, 2569 ArrayRef<std::pair<LoopOptionCase, int64_t>> sortedOptions) { 2570 assert(llvm::is_sorted(sortedOptions, llvm::less_first()) && 2571 "LoopOptionsAttr ctor expects a sorted options array"); 2572 return Base::get(context, sortedOptions); 2573 } 2574 2575 /// Build the LoopOptions Attribute from a sorted array of individual options. 2576 LoopOptionsAttr LoopOptionsAttr::get(MLIRContext *context, 2577 LoopOptionsAttrBuilder &optionBuilders) { 2578 llvm::sort(optionBuilders.options, llvm::less_first()); 2579 return Base::get(context, optionBuilders.options); 2580 } 2581 2582 void LoopOptionsAttr::print(AsmPrinter &printer) const { 2583 printer << "<"; 2584 llvm::interleaveComma(getOptions(), printer, [&](auto option) { 2585 printer << stringifyEnum(option.first) << " = "; 2586 switch (option.first) { 2587 case LoopOptionCase::disable_licm: 2588 case LoopOptionCase::disable_unroll: 2589 case LoopOptionCase::disable_pipeline: 2590 printer << (option.second ? "true" : "false"); 2591 break; 2592 case LoopOptionCase::interleave_count: 2593 case LoopOptionCase::pipeline_initiation_interval: 2594 printer << option.second; 2595 break; 2596 } 2597 }); 2598 printer << ">"; 2599 } 2600 2601 Attribute LoopOptionsAttr::parse(AsmParser &parser, Type type) { 2602 if (failed(parser.parseLess())) 2603 return {}; 2604 2605 SmallVector<std::pair<LoopOptionCase, int64_t>> options; 2606 llvm::SmallDenseSet<LoopOptionCase> seenOptions; 2607 do { 2608 StringRef optionName; 2609 if (parser.parseKeyword(&optionName)) 2610 return {}; 2611 2612 auto option = symbolizeLoopOptionCase(optionName); 2613 if (!option) { 2614 parser.emitError(parser.getNameLoc(), "unknown loop option: ") 2615 << optionName; 2616 return {}; 2617 } 2618 if (!seenOptions.insert(*option).second) { 2619 parser.emitError(parser.getNameLoc(), "loop option present twice"); 2620 return {}; 2621 } 2622 if (failed(parser.parseEqual())) 2623 return {}; 2624 2625 int64_t value; 2626 switch (*option) { 2627 case LoopOptionCase::disable_licm: 2628 case LoopOptionCase::disable_unroll: 2629 case LoopOptionCase::disable_pipeline: 2630 if (succeeded(parser.parseOptionalKeyword("true"))) 2631 value = 1; 2632 else if (succeeded(parser.parseOptionalKeyword("false"))) 2633 value = 0; 2634 else { 2635 parser.emitError(parser.getNameLoc(), 2636 "expected boolean value 'true' or 'false'"); 2637 return {}; 2638 } 2639 break; 2640 case LoopOptionCase::interleave_count: 2641 case LoopOptionCase::pipeline_initiation_interval: 2642 if (failed(parser.parseInteger(value))) { 2643 parser.emitError(parser.getNameLoc(), "expected integer value"); 2644 return {}; 2645 } 2646 break; 2647 } 2648 options.push_back(std::make_pair(*option, value)); 2649 } while (succeeded(parser.parseOptionalComma())); 2650 if (failed(parser.parseGreater())) 2651 return {}; 2652 2653 llvm::sort(options, llvm::less_first()); 2654 return get(parser.getContext(), options); 2655 } 2656 2657 Attribute LLVMDialect::parseAttribute(DialectAsmParser &parser, 2658 Type type) const { 2659 if (type) { 2660 parser.emitError(parser.getNameLoc(), "unexpected type"); 2661 return {}; 2662 } 2663 StringRef attrKind; 2664 if (parser.parseKeyword(&attrKind)) 2665 return {}; 2666 { 2667 Attribute attr; 2668 auto parseResult = generatedAttributeParser(parser, attrKind, type, attr); 2669 if (parseResult.hasValue()) 2670 return attr; 2671 } 2672 parser.emitError(parser.getNameLoc(), "unknown attribute type: ") << attrKind; 2673 return {}; 2674 } 2675 2676 void LLVMDialect::printAttribute(Attribute attr, DialectAsmPrinter &os) const { 2677 if (succeeded(generatedAttributePrinter(attr, os))) 2678 return; 2679 llvm_unreachable("Unknown attribute type"); 2680 } 2681