1 //===- LLVMDialect.cpp - LLVM IR Ops and Dialect registration -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the types and operation details for the LLVM IR dialect in 10 // MLIR, and the LLVM IR dialect. It also registers the dialect. 11 // 12 //===----------------------------------------------------------------------===// 13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 14 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 15 #include "mlir/IR/Builders.h" 16 #include "mlir/IR/BuiltinOps.h" 17 #include "mlir/IR/BuiltinTypes.h" 18 #include "mlir/IR/DialectImplementation.h" 19 #include "mlir/IR/FunctionImplementation.h" 20 #include "mlir/IR/MLIRContext.h" 21 22 #include "llvm/ADT/StringSwitch.h" 23 #include "llvm/AsmParser/Parser.h" 24 #include "llvm/Bitcode/BitcodeReader.h" 25 #include "llvm/Bitcode/BitcodeWriter.h" 26 #include "llvm/IR/Attributes.h" 27 #include "llvm/IR/Function.h" 28 #include "llvm/IR/Type.h" 29 #include "llvm/Support/Mutex.h" 30 #include "llvm/Support/SourceMgr.h" 31 32 using namespace mlir; 33 using namespace mlir::LLVM; 34 35 static constexpr const char kVolatileAttrName[] = "volatile_"; 36 static constexpr const char kNonTemporalAttrName[] = "nontemporal"; 37 38 #include "mlir/Dialect/LLVMIR/LLVMOpsEnums.cpp.inc" 39 #include "mlir/Dialect/LLVMIR/LLVMOpsInterfaces.cpp.inc" 40 41 namespace mlir { 42 namespace LLVM { 43 namespace detail { 44 struct BitmaskEnumStorage : public AttributeStorage { 45 using KeyTy = uint64_t; 46 47 BitmaskEnumStorage(KeyTy val) : value(val) {} 48 49 bool operator==(const KeyTy &key) const { return value == key; } 50 51 static BitmaskEnumStorage *construct(AttributeStorageAllocator &allocator, 52 const KeyTy &key) { 53 return new 
(allocator.allocate<BitmaskEnumStorage>()) 54 BitmaskEnumStorage(key); 55 } 56 57 KeyTy value = 0; 58 }; 59 60 struct LoopOptionAttrStorage : public AttributeStorage { 61 using KeyTy = std::pair<uint64_t, int32_t>; 62 63 explicit LoopOptionAttrStorage(uint64_t option, int32_t value) 64 : option(option), value(value) {} 65 66 bool operator==(const KeyTy &key) const { 67 return key == KeyTy(option, value); 68 } 69 70 static LoopOptionAttrStorage * 71 construct(mlir::AttributeStorageAllocator &allocator, const KeyTy &key) { 72 return new (allocator.allocate<LoopOptionAttrStorage>()) 73 LoopOptionAttrStorage(key.first, key.second); 74 } 75 76 uint64_t option; 77 int32_t value; 78 }; 79 } // namespace detail 80 } // namespace LLVM 81 } // namespace mlir 82 83 static auto processFMFAttr(ArrayRef<NamedAttribute> attrs) { 84 SmallVector<NamedAttribute, 8> filteredAttrs( 85 llvm::make_filter_range(attrs, [&](NamedAttribute attr) { 86 if (attr.first == "fastmathFlags") { 87 auto defAttr = FMFAttr::get({}, attr.second.getContext()); 88 return defAttr != attr.second; 89 } 90 return true; 91 })); 92 return filteredAttrs; 93 } 94 95 static ParseResult parseLLVMOpAttrs(OpAsmParser &parser, 96 NamedAttrList &result) { 97 return parser.parseOptionalAttrDict(result); 98 } 99 100 static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op, 101 DictionaryAttr attrs) { 102 printer.printOptionalAttrDict(processFMFAttr(attrs.getValue())); 103 } 104 105 //===----------------------------------------------------------------------===// 106 // Printing/parsing for LLVM::CmpOp. 
107 //===----------------------------------------------------------------------===// 108 static void printICmpOp(OpAsmPrinter &p, ICmpOp &op) { 109 p << op.getOperationName() << " \"" << stringifyICmpPredicate(op.predicate()) 110 << "\" " << op.getOperand(0) << ", " << op.getOperand(1); 111 p.printOptionalAttrDict(op->getAttrs(), {"predicate"}); 112 p << " : " << op.lhs().getType(); 113 } 114 115 static void printFCmpOp(OpAsmPrinter &p, FCmpOp &op) { 116 p << op.getOperationName() << " \"" << stringifyFCmpPredicate(op.predicate()) 117 << "\" " << op.getOperand(0) << ", " << op.getOperand(1); 118 p.printOptionalAttrDict(processFMFAttr(op->getAttrs()), {"predicate"}); 119 p << " : " << op.lhs().getType(); 120 } 121 122 // <operation> ::= `llvm.icmp` string-literal ssa-use `,` ssa-use 123 // attribute-dict? `:` type 124 // <operation> ::= `llvm.fcmp` string-literal ssa-use `,` ssa-use 125 // attribute-dict? `:` type 126 template <typename CmpPredicateType> 127 static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) { 128 Builder &builder = parser.getBuilder(); 129 130 StringAttr predicateAttr; 131 OpAsmParser::OperandType lhs, rhs; 132 Type type; 133 llvm::SMLoc predicateLoc, trailingTypeLoc; 134 if (parser.getCurrentLocation(&predicateLoc) || 135 parser.parseAttribute(predicateAttr, "predicate", result.attributes) || 136 parser.parseOperand(lhs) || parser.parseComma() || 137 parser.parseOperand(rhs) || 138 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 139 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) || 140 parser.resolveOperand(lhs, type, result.operands) || 141 parser.resolveOperand(rhs, type, result.operands)) 142 return failure(); 143 144 // Replace the string attribute `predicate` with an integer attribute. 
145 int64_t predicateValue = 0; 146 if (std::is_same<CmpPredicateType, ICmpPredicate>()) { 147 Optional<ICmpPredicate> predicate = 148 symbolizeICmpPredicate(predicateAttr.getValue()); 149 if (!predicate) 150 return parser.emitError(predicateLoc) 151 << "'" << predicateAttr.getValue() 152 << "' is an incorrect value of the 'predicate' attribute"; 153 predicateValue = static_cast<int64_t>(predicate.getValue()); 154 } else { 155 Optional<FCmpPredicate> predicate = 156 symbolizeFCmpPredicate(predicateAttr.getValue()); 157 if (!predicate) 158 return parser.emitError(predicateLoc) 159 << "'" << predicateAttr.getValue() 160 << "' is an incorrect value of the 'predicate' attribute"; 161 predicateValue = static_cast<int64_t>(predicate.getValue()); 162 } 163 164 result.attributes.set("predicate", 165 parser.getBuilder().getI64IntegerAttr(predicateValue)); 166 167 // The result type is either i1 or a vector type <? x i1> if the inputs are 168 // vectors. 169 Type resultType = IntegerType::get(builder.getContext(), 1); 170 if (!isCompatibleType(type)) 171 return parser.emitError(trailingTypeLoc, 172 "expected LLVM dialect-compatible type"); 173 if (LLVM::isCompatibleVectorType(type)) 174 resultType = LLVM::getFixedVectorType( 175 resultType, LLVM::getVectorNumElements(type).getFixedValue()); 176 assert(!type.isa<LLVM::LLVMScalableVectorType>() && 177 "unhandled scalable vector"); 178 179 result.addTypes({resultType}); 180 return success(); 181 } 182 183 //===----------------------------------------------------------------------===// 184 // Printing/parsing for LLVM::AllocaOp. 
185 //===----------------------------------------------------------------------===// 186 187 static void printAllocaOp(OpAsmPrinter &p, AllocaOp &op) { 188 auto elemTy = op.getType().cast<LLVM::LLVMPointerType>().getElementType(); 189 190 auto funcTy = FunctionType::get(op.getContext(), {op.arraySize().getType()}, 191 {op.getType()}); 192 193 p << op.getOperationName() << ' ' << op.arraySize() << " x " << elemTy; 194 if (op.alignment().hasValue() && *op.alignment() != 0) 195 p.printOptionalAttrDict(op->getAttrs()); 196 else 197 p.printOptionalAttrDict(op->getAttrs(), {"alignment"}); 198 p << " : " << funcTy; 199 } 200 201 // <operation> ::= `llvm.alloca` ssa-use `x` type attribute-dict? 202 // `:` type `,` type 203 static ParseResult parseAllocaOp(OpAsmParser &parser, OperationState &result) { 204 OpAsmParser::OperandType arraySize; 205 Type type, elemType; 206 llvm::SMLoc trailingTypeLoc; 207 if (parser.parseOperand(arraySize) || parser.parseKeyword("x") || 208 parser.parseType(elemType) || 209 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 210 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type)) 211 return failure(); 212 213 Optional<NamedAttribute> alignmentAttr = 214 result.attributes.getNamed("alignment"); 215 if (alignmentAttr.hasValue()) { 216 auto alignmentInt = alignmentAttr.getValue().second.dyn_cast<IntegerAttr>(); 217 if (!alignmentInt) 218 return parser.emitError(parser.getNameLoc(), 219 "expected integer alignment"); 220 if (alignmentInt.getValue().isNullValue()) 221 result.attributes.erase("alignment"); 222 } 223 224 // Extract the result type from the trailing function type. 
225 auto funcType = type.dyn_cast<FunctionType>(); 226 if (!funcType || funcType.getNumInputs() != 1 || 227 funcType.getNumResults() != 1) 228 return parser.emitError( 229 trailingTypeLoc, 230 "expected trailing function type with one argument and one result"); 231 232 if (parser.resolveOperand(arraySize, funcType.getInput(0), result.operands)) 233 return failure(); 234 235 result.addTypes({funcType.getResult(0)}); 236 return success(); 237 } 238 239 //===----------------------------------------------------------------------===// 240 // LLVM::BrOp 241 //===----------------------------------------------------------------------===// 242 243 Optional<MutableOperandRange> 244 BrOp::getMutableSuccessorOperands(unsigned index) { 245 assert(index == 0 && "invalid successor index"); 246 return destOperandsMutable(); 247 } 248 249 //===----------------------------------------------------------------------===// 250 // LLVM::CondBrOp 251 //===----------------------------------------------------------------------===// 252 253 Optional<MutableOperandRange> 254 CondBrOp::getMutableSuccessorOperands(unsigned index) { 255 assert(index < getNumSuccessors() && "invalid successor index"); 256 return index == 0 ? 
trueDestOperandsMutable() : falseDestOperandsMutable(); 257 } 258 259 //===----------------------------------------------------------------------===// 260 // LLVM::SwitchOp 261 //===----------------------------------------------------------------------===// 262 263 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value, 264 Block *defaultDestination, ValueRange defaultOperands, 265 ArrayRef<int32_t> caseValues, BlockRange caseDestinations, 266 ArrayRef<ValueRange> caseOperands, 267 ArrayRef<int32_t> branchWeights) { 268 SmallVector<Value> flattenedCaseOperands; 269 SmallVector<int32_t> caseOperandOffsets; 270 int32_t offset = 0; 271 for (ValueRange operands : caseOperands) { 272 flattenedCaseOperands.append(operands.begin(), operands.end()); 273 caseOperandOffsets.push_back(offset); 274 offset += operands.size(); 275 } 276 ElementsAttr caseValuesAttr; 277 if (!caseValues.empty()) 278 caseValuesAttr = builder.getI32VectorAttr(caseValues); 279 ElementsAttr caseOperandOffsetsAttr; 280 if (!caseOperandOffsets.empty()) 281 caseOperandOffsetsAttr = builder.getI32VectorAttr(caseOperandOffsets); 282 283 ElementsAttr weightsAttr; 284 if (!branchWeights.empty()) 285 weightsAttr = builder.getI32VectorAttr(llvm::to_vector<4>(branchWeights)); 286 287 build(builder, result, value, defaultOperands, flattenedCaseOperands, 288 caseValuesAttr, caseOperandOffsetsAttr, weightsAttr, defaultDestination, 289 caseDestinations); 290 } 291 292 /// <cases> ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)? 293 /// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )? 
294 static ParseResult 295 parseSwitchOpCases(OpAsmParser &parser, ElementsAttr &caseValues, 296 SmallVectorImpl<Block *> &caseDestinations, 297 SmallVectorImpl<OpAsmParser::OperandType> &caseOperands, 298 SmallVectorImpl<Type> &caseOperandTypes, 299 ElementsAttr &caseOperandOffsets) { 300 SmallVector<int32_t> values; 301 SmallVector<int32_t> offsets; 302 int32_t value, offset = 0; 303 do { 304 OptionalParseResult integerParseResult = parser.parseOptionalInteger(value); 305 if (values.empty() && !integerParseResult.hasValue()) 306 return success(); 307 308 if (!integerParseResult.hasValue() || integerParseResult.getValue()) 309 return failure(); 310 values.push_back(value); 311 312 Block *destination; 313 SmallVector<OpAsmParser::OperandType> operands; 314 if (parser.parseColon() || parser.parseSuccessor(destination)) 315 return failure(); 316 if (!parser.parseOptionalLParen()) { 317 if (parser.parseRegionArgumentList(operands) || 318 parser.parseColonTypeList(caseOperandTypes) || parser.parseRParen()) 319 return failure(); 320 } 321 caseDestinations.push_back(destination); 322 caseOperands.append(operands.begin(), operands.end()); 323 offsets.push_back(offset); 324 offset += operands.size(); 325 } while (!parser.parseOptionalComma()); 326 327 Builder &builder = parser.getBuilder(); 328 caseValues = builder.getI32VectorAttr(values); 329 caseOperandOffsets = builder.getI32VectorAttr(offsets); 330 331 return success(); 332 } 333 334 static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, 335 ElementsAttr caseValues, 336 SuccessorRange caseDestinations, 337 OperandRange caseOperands, 338 TypeRange caseOperandTypes, 339 ElementsAttr caseOperandOffsets) { 340 if (!caseValues) 341 return; 342 343 size_t index = 0; 344 llvm::interleave( 345 llvm::zip(caseValues.cast<DenseIntElementsAttr>(), caseDestinations), 346 [&](auto i) { 347 p << " "; 348 p << std::get<0>(i).getLimitedValue(); 349 p << ": "; 350 p.printSuccessorAndUseList(std::get<1>(i), 
op.getCaseOperands(index++)); 351 }, 352 [&] { 353 p << ','; 354 p.printNewline(); 355 }); 356 p.printNewline(); 357 } 358 359 static LogicalResult verify(SwitchOp op) { 360 if ((!op.case_values() && !op.caseDestinations().empty()) || 361 (op.case_values() && 362 op.case_values()->size() != 363 static_cast<int64_t>(op.caseDestinations().size()))) 364 return op.emitOpError("expects number of case values to match number of " 365 "case destinations"); 366 if (op.branch_weights() && 367 op.branch_weights()->size() != op.getNumSuccessors()) 368 return op.emitError("expects number of branch weights to match number of " 369 "successors: ") 370 << op.branch_weights()->size() << " vs " << op.getNumSuccessors(); 371 return success(); 372 } 373 374 OperandRange SwitchOp::getCaseOperands(unsigned index) { 375 return getCaseOperandsMutable(index); 376 } 377 378 MutableOperandRange SwitchOp::getCaseOperandsMutable(unsigned index) { 379 MutableOperandRange caseOperands = caseOperandsMutable(); 380 if (!case_operand_offsets()) { 381 assert(caseOperands.size() == 0 && 382 "non-empty case operands must have offsets"); 383 return caseOperands; 384 } 385 386 ElementsAttr offsets = case_operand_offsets().getValue(); 387 assert(index < offsets.size() && "invalid case operand offset index"); 388 389 int64_t begin = offsets.getValue(index).cast<IntegerAttr>().getInt(); 390 int64_t end = index + 1 == offsets.size() 391 ? caseOperands.size() 392 : offsets.getValue(index + 1).cast<IntegerAttr>().getInt(); 393 return caseOperandsMutable().slice(begin, end - begin); 394 } 395 396 Optional<MutableOperandRange> 397 SwitchOp::getMutableSuccessorOperands(unsigned index) { 398 assert(index < getNumSuccessors() && "invalid successor index"); 399 return index == 0 ? defaultOperandsMutable() 400 : getCaseOperandsMutable(index - 1); 401 } 402 403 //===----------------------------------------------------------------------===// 404 // Builder, printer and parser for for LLVM::LoadOp. 
405 //===----------------------------------------------------------------------===// 406 407 static LogicalResult verifyAccessGroups(Operation *op) { 408 if (Attribute attribute = 409 op->getAttr(LLVMDialect::getAccessGroupsAttrName())) { 410 // The attribute is already verified to be a symbol ref array attribute via 411 // a constraint in the operation definition. 412 for (SymbolRefAttr accessGroupRef : 413 attribute.cast<ArrayAttr>().getAsRange<SymbolRefAttr>()) { 414 StringRef metadataName = accessGroupRef.getRootReference(); 415 auto metadataOp = SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>( 416 op->getParentOp(), metadataName); 417 if (!metadataOp) 418 return op->emitOpError() << "expected '" << accessGroupRef 419 << "' to reference a metadata op"; 420 StringRef accessGroupName = accessGroupRef.getLeafReference(); 421 Operation *accessGroupOp = 422 SymbolTable::lookupNearestSymbolFrom(metadataOp, accessGroupName); 423 if (!accessGroupOp) 424 return op->emitOpError() << "expected '" << accessGroupRef 425 << "' to reference an access_group op"; 426 } 427 } 428 return success(); 429 } 430 431 static LogicalResult verify(LoadOp op) { 432 return verifyAccessGroups(op.getOperation()); 433 } 434 435 void LoadOp::build(OpBuilder &builder, OperationState &result, Type t, 436 Value addr, unsigned alignment, bool isVolatile, 437 bool isNonTemporal) { 438 result.addOperands(addr); 439 result.addTypes(t); 440 if (isVolatile) 441 result.addAttribute(kVolatileAttrName, builder.getUnitAttr()); 442 if (isNonTemporal) 443 result.addAttribute(kNonTemporalAttrName, builder.getUnitAttr()); 444 if (alignment != 0) 445 result.addAttribute("alignment", builder.getI64IntegerAttr(alignment)); 446 } 447 448 static void printLoadOp(OpAsmPrinter &p, LoadOp &op) { 449 p << op.getOperationName() << ' '; 450 if (op.volatile_()) 451 p << "volatile "; 452 p << op.addr(); 453 p.printOptionalAttrDict(op->getAttrs(), {kVolatileAttrName}); 454 p << " : " << op.addr().getType(); 455 } 456 
457 // Extract the pointee type from the LLVM pointer type wrapped in MLIR. Return 458 // the resulting type wrapped in MLIR, or nullptr on error. 459 static Type getLoadStoreElementType(OpAsmParser &parser, Type type, 460 llvm::SMLoc trailingTypeLoc) { 461 auto llvmTy = type.dyn_cast<LLVM::LLVMPointerType>(); 462 if (!llvmTy) 463 return parser.emitError(trailingTypeLoc, "expected LLVM pointer type"), 464 nullptr; 465 return llvmTy.getElementType(); 466 } 467 468 // <operation> ::= `llvm.load` `volatile` ssa-use attribute-dict? `:` type 469 static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) { 470 OpAsmParser::OperandType addr; 471 Type type; 472 llvm::SMLoc trailingTypeLoc; 473 474 if (succeeded(parser.parseOptionalKeyword("volatile"))) 475 result.addAttribute(kVolatileAttrName, parser.getBuilder().getUnitAttr()); 476 477 if (parser.parseOperand(addr) || 478 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 479 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) || 480 parser.resolveOperand(addr, type, result.operands)) 481 return failure(); 482 483 Type elemTy = getLoadStoreElementType(parser, type, trailingTypeLoc); 484 485 result.addTypes(elemTy); 486 return success(); 487 } 488 489 //===----------------------------------------------------------------------===// 490 // Builder, printer and parser for LLVM::StoreOp. 
491 //===----------------------------------------------------------------------===// 492 493 static LogicalResult verify(StoreOp op) { 494 return verifyAccessGroups(op.getOperation()); 495 } 496 497 void StoreOp::build(OpBuilder &builder, OperationState &result, Value value, 498 Value addr, unsigned alignment, bool isVolatile, 499 bool isNonTemporal) { 500 result.addOperands({value, addr}); 501 result.addTypes({}); 502 if (isVolatile) 503 result.addAttribute(kVolatileAttrName, builder.getUnitAttr()); 504 if (isNonTemporal) 505 result.addAttribute(kNonTemporalAttrName, builder.getUnitAttr()); 506 if (alignment != 0) 507 result.addAttribute("alignment", builder.getI64IntegerAttr(alignment)); 508 } 509 510 static void printStoreOp(OpAsmPrinter &p, StoreOp &op) { 511 p << op.getOperationName() << ' '; 512 if (op.volatile_()) 513 p << "volatile "; 514 p << op.value() << ", " << op.addr(); 515 p.printOptionalAttrDict(op->getAttrs(), {kVolatileAttrName}); 516 p << " : " << op.addr().getType(); 517 } 518 519 // <operation> ::= `llvm.store` `volatile` ssa-use `,` ssa-use 520 // attribute-dict? 
`:` type 521 static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) { 522 OpAsmParser::OperandType addr, value; 523 Type type; 524 llvm::SMLoc trailingTypeLoc; 525 526 if (succeeded(parser.parseOptionalKeyword("volatile"))) 527 result.addAttribute(kVolatileAttrName, parser.getBuilder().getUnitAttr()); 528 529 if (parser.parseOperand(value) || parser.parseComma() || 530 parser.parseOperand(addr) || 531 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 532 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type)) 533 return failure(); 534 535 Type elemTy = getLoadStoreElementType(parser, type, trailingTypeLoc); 536 if (!elemTy) 537 return failure(); 538 539 if (parser.resolveOperand(value, elemTy, result.operands) || 540 parser.resolveOperand(addr, type, result.operands)) 541 return failure(); 542 543 return success(); 544 } 545 546 ///===---------------------------------------------------------------------===// 547 /// LLVM::InvokeOp 548 ///===---------------------------------------------------------------------===// 549 550 Optional<MutableOperandRange> 551 InvokeOp::getMutableSuccessorOperands(unsigned index) { 552 assert(index < getNumSuccessors() && "invalid successor index"); 553 return index == 0 ? 
normalDestOperandsMutable() : unwindDestOperandsMutable(); 554 } 555 556 static LogicalResult verify(InvokeOp op) { 557 if (op.getNumResults() > 1) 558 return op.emitOpError("must have 0 or 1 result"); 559 560 Block *unwindDest = op.unwindDest(); 561 if (unwindDest->empty()) 562 return op.emitError( 563 "must have at least one operation in unwind destination"); 564 565 // In unwind destination, first operation must be LandingpadOp 566 if (!isa<LandingpadOp>(unwindDest->front())) 567 return op.emitError("first operation in unwind destination should be a " 568 "llvm.landingpad operation"); 569 570 return success(); 571 } 572 573 static void printInvokeOp(OpAsmPrinter &p, InvokeOp op) { 574 auto callee = op.callee(); 575 bool isDirect = callee.hasValue(); 576 577 p << op.getOperationName() << ' '; 578 579 // Either function name or pointer 580 if (isDirect) 581 p.printSymbolName(callee.getValue()); 582 else 583 p << op.getOperand(0); 584 585 p << '(' << op.getOperands().drop_front(isDirect ? 0 : 1) << ')'; 586 p << " to "; 587 p.printSuccessorAndUseList(op.normalDest(), op.normalDestOperands()); 588 p << " unwind "; 589 p.printSuccessorAndUseList(op.unwindDest(), op.unwindDestOperands()); 590 591 p.printOptionalAttrDict(op->getAttrs(), 592 {InvokeOp::getOperandSegmentSizeAttr(), "callee"}); 593 p << " : "; 594 p.printFunctionalType( 595 llvm::drop_begin(op.getOperandTypes(), isDirect ? 0 : 1), 596 op.getResultTypes()); 597 } 598 599 /// <operation> ::= `llvm.invoke` (function-id | ssa-use) `(` ssa-use-list `)` 600 /// `to` bb-id (`[` ssa-use-and-type-list `]`)? 601 /// `unwind` bb-id (`[` ssa-use-and-type-list `]`)? 602 /// attribute-dict? 
`:` function-type 603 static ParseResult parseInvokeOp(OpAsmParser &parser, OperationState &result) { 604 SmallVector<OpAsmParser::OperandType, 8> operands; 605 FunctionType funcType; 606 SymbolRefAttr funcAttr; 607 llvm::SMLoc trailingTypeLoc; 608 Block *normalDest, *unwindDest; 609 SmallVector<Value, 4> normalOperands, unwindOperands; 610 Builder &builder = parser.getBuilder(); 611 612 // Parse an operand list that will, in practice, contain 0 or 1 operand. In 613 // case of an indirect call, there will be 1 operand before `(`. In case of a 614 // direct call, there will be no operands and the parser will stop at the 615 // function identifier without complaining. 616 if (parser.parseOperandList(operands)) 617 return failure(); 618 bool isDirect = operands.empty(); 619 620 // Optionally parse a function identifier. 621 if (isDirect && parser.parseAttribute(funcAttr, "callee", result.attributes)) 622 return failure(); 623 624 if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || 625 parser.parseKeyword("to") || 626 parser.parseSuccessorAndUseList(normalDest, normalOperands) || 627 parser.parseKeyword("unwind") || 628 parser.parseSuccessorAndUseList(unwindDest, unwindOperands) || 629 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 630 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(funcType)) 631 return failure(); 632 633 if (isDirect) { 634 // Make sure types match. 635 if (parser.resolveOperands(operands, funcType.getInputs(), 636 parser.getNameLoc(), result.operands)) 637 return failure(); 638 result.addTypes(funcType.getResults()); 639 } else { 640 // Construct the LLVM IR Dialect function type that the first operand 641 // should match. 
642 if (funcType.getNumResults() > 1) 643 return parser.emitError(trailingTypeLoc, 644 "expected function with 0 or 1 result"); 645 646 Type llvmResultType; 647 if (funcType.getNumResults() == 0) { 648 llvmResultType = LLVM::LLVMVoidType::get(builder.getContext()); 649 } else { 650 llvmResultType = funcType.getResult(0); 651 if (!isCompatibleType(llvmResultType)) 652 return parser.emitError(trailingTypeLoc, 653 "expected result to have LLVM type"); 654 } 655 656 SmallVector<Type, 8> argTypes; 657 argTypes.reserve(funcType.getNumInputs()); 658 for (Type ty : funcType.getInputs()) { 659 if (isCompatibleType(ty)) 660 argTypes.push_back(ty); 661 else 662 return parser.emitError(trailingTypeLoc, 663 "expected LLVM types as inputs"); 664 } 665 666 auto llvmFuncType = LLVM::LLVMFunctionType::get(llvmResultType, argTypes); 667 auto wrappedFuncType = LLVM::LLVMPointerType::get(llvmFuncType); 668 669 auto funcArguments = llvm::makeArrayRef(operands).drop_front(); 670 671 // Make sure that the first operand (indirect callee) matches the wrapped 672 // LLVM IR function type, and that the types of the other call operands 673 // match the types of the function arguments. 674 if (parser.resolveOperand(operands[0], wrappedFuncType, result.operands) || 675 parser.resolveOperands(funcArguments, funcType.getInputs(), 676 parser.getNameLoc(), result.operands)) 677 return failure(); 678 679 result.addTypes(llvmResultType); 680 } 681 result.addSuccessors({normalDest, unwindDest}); 682 result.addOperands(normalOperands); 683 result.addOperands(unwindOperands); 684 685 result.addAttribute( 686 InvokeOp::getOperandSegmentSizeAttr(), 687 builder.getI32VectorAttr({static_cast<int32_t>(operands.size()), 688 static_cast<int32_t>(normalOperands.size()), 689 static_cast<int32_t>(unwindOperands.size())})); 690 return success(); 691 } 692 693 ///===----------------------------------------------------------------------===// 694 /// Verifying/Printing/Parsing for LLVM::LandingpadOp. 
695 ///===----------------------------------------------------------------------===// 696 697 static LogicalResult verify(LandingpadOp op) { 698 Value value; 699 if (LLVMFuncOp func = op->getParentOfType<LLVMFuncOp>()) { 700 if (!func.personality().hasValue()) 701 return op.emitError( 702 "llvm.landingpad needs to be in a function with a personality"); 703 } 704 705 if (!op.cleanup() && op.getOperands().empty()) 706 return op.emitError("landingpad instruction expects at least one clause or " 707 "cleanup attribute"); 708 709 for (unsigned idx = 0, ie = op.getNumOperands(); idx < ie; idx++) { 710 value = op.getOperand(idx); 711 bool isFilter = value.getType().isa<LLVMArrayType>(); 712 if (isFilter) { 713 // FIXME: Verify filter clauses when arrays are appropriately handled 714 } else { 715 // catch - global addresses only. 716 // Bitcast ops should have global addresses as their args. 717 if (auto bcOp = value.getDefiningOp<BitcastOp>()) { 718 if (auto addrOp = bcOp.arg().getDefiningOp<AddressOfOp>()) 719 continue; 720 return op.emitError("constant clauses expected") 721 .attachNote(bcOp.getLoc()) 722 << "global addresses expected as operand to " 723 "bitcast used in clauses for landingpad"; 724 } 725 // NullOp and AddressOfOp allowed 726 if (value.getDefiningOp<NullOp>()) 727 continue; 728 if (value.getDefiningOp<AddressOfOp>()) 729 continue; 730 return op.emitError("clause #") 731 << idx << " is not a known constant - null, addressof, bitcast"; 732 } 733 } 734 return success(); 735 } 736 737 static void printLandingpadOp(OpAsmPrinter &p, LandingpadOp &op) { 738 p << op.getOperationName() << (op.cleanup() ? " cleanup " : " "); 739 740 // Clauses 741 for (auto value : op.getOperands()) { 742 // Similar to llvm - if clause is an array type then it is filter 743 // clause else catch clause 744 bool isArrayTy = value.getType().isa<LLVMArrayType>(); 745 p << '(' << (isArrayTy ? 
"filter " : "catch ") << value << " : " 746 << value.getType() << ") "; 747 } 748 749 p.printOptionalAttrDict(op->getAttrs(), {"cleanup"}); 750 751 p << ": " << op.getType(); 752 } 753 754 /// <operation> ::= `llvm.landingpad` `cleanup`? 755 /// ((`catch` | `filter`) operand-type ssa-use)* attribute-dict? 756 static ParseResult parseLandingpadOp(OpAsmParser &parser, 757 OperationState &result) { 758 // Check for cleanup 759 if (succeeded(parser.parseOptionalKeyword("cleanup"))) 760 result.addAttribute("cleanup", parser.getBuilder().getUnitAttr()); 761 762 // Parse clauses with types 763 while (succeeded(parser.parseOptionalLParen()) && 764 (succeeded(parser.parseOptionalKeyword("filter")) || 765 succeeded(parser.parseOptionalKeyword("catch")))) { 766 OpAsmParser::OperandType operand; 767 Type ty; 768 if (parser.parseOperand(operand) || parser.parseColon() || 769 parser.parseType(ty) || 770 parser.resolveOperand(operand, ty, result.operands) || 771 parser.parseRParen()) 772 return failure(); 773 } 774 775 Type type; 776 if (parser.parseColon() || parser.parseType(type)) 777 return failure(); 778 779 result.addTypes(type); 780 return success(); 781 } 782 783 //===----------------------------------------------------------------------===// 784 // Verifying/Printing/parsing for LLVM::CallOp. 785 //===----------------------------------------------------------------------===// 786 787 static LogicalResult verify(CallOp &op) { 788 if (op.getNumResults() > 1) 789 return op.emitOpError("must have 0 or 1 result"); 790 791 // Type for the callee, we'll get it differently depending if it is a direct 792 // or indirect call. 793 Type fnType; 794 795 bool isIndirect = false; 796 797 // If this is an indirect call, the callee attribute is missing. 
798 Optional<StringRef> calleeName = op.callee(); 799 if (!calleeName) { 800 isIndirect = true; 801 if (!op.getNumOperands()) 802 return op.emitOpError( 803 "must have either a `callee` attribute or at least an operand"); 804 auto ptrType = op.getOperand(0).getType().dyn_cast<LLVMPointerType>(); 805 if (!ptrType) 806 return op.emitOpError("indirect call expects a pointer as callee: ") 807 << ptrType; 808 fnType = ptrType.getElementType(); 809 } else { 810 Operation *callee = SymbolTable::lookupNearestSymbolFrom(op, *calleeName); 811 if (!callee) 812 return op.emitOpError() 813 << "'" << *calleeName 814 << "' does not reference a symbol in the current scope"; 815 auto fn = dyn_cast<LLVMFuncOp>(callee); 816 if (!fn) 817 return op.emitOpError() << "'" << *calleeName 818 << "' does not reference a valid LLVM function"; 819 820 fnType = fn.getType(); 821 } 822 823 LLVMFunctionType funcType = fnType.dyn_cast<LLVMFunctionType>(); 824 if (!funcType) 825 return op.emitOpError("callee does not have a functional type: ") << fnType; 826 827 // Verify that the operand and result types match the callee. 
828 829 if (!funcType.isVarArg() && 830 funcType.getNumParams() != (op.getNumOperands() - isIndirect)) 831 return op.emitOpError() 832 << "incorrect number of operands (" 833 << (op.getNumOperands() - isIndirect) 834 << ") for callee (expecting: " << funcType.getNumParams() << ")"; 835 836 if (funcType.getNumParams() > (op.getNumOperands() - isIndirect)) 837 return op.emitOpError() << "incorrect number of operands (" 838 << (op.getNumOperands() - isIndirect) 839 << ") for varargs callee (expecting at least: " 840 << funcType.getNumParams() << ")"; 841 842 for (unsigned i = 0, e = funcType.getNumParams(); i != e; ++i) 843 if (op.getOperand(i + isIndirect).getType() != funcType.getParamType(i)) 844 return op.emitOpError() << "operand type mismatch for operand " << i 845 << ": " << op.getOperand(i + isIndirect).getType() 846 << " != " << funcType.getParamType(i); 847 848 if (op.getNumResults() && 849 op.getResult(0).getType() != funcType.getReturnType()) 850 return op.emitOpError() 851 << "result type mismatch: " << op.getResult(0).getType() 852 << " != " << funcType.getReturnType(); 853 854 return success(); 855 } 856 857 static void printCallOp(OpAsmPrinter &p, CallOp &op) { 858 auto callee = op.callee(); 859 bool isDirect = callee.hasValue(); 860 861 // Print the direct callee if present as a function attribute, or an indirect 862 // callee (first operand) otherwise. 863 p << op.getOperationName() << ' '; 864 if (isDirect) 865 p.printSymbolName(callee.getValue()); 866 else 867 p << op.getOperand(0); 868 869 auto args = op.getOperands().drop_front(isDirect ? 0 : 1); 870 p << '(' << args << ')'; 871 p.printOptionalAttrDict(processFMFAttr(op->getAttrs()), {"callee"}); 872 873 // Reconstruct the function MLIR function type from operand and result types. 874 p << " : " 875 << FunctionType::get(op.getContext(), args.getTypes(), op.getResultTypes()); 876 } 877 878 // <operation> ::= `llvm.call` (function-id | ssa-use) `(` ssa-use-list `)` 879 // attribute-dict? 
`:` function-type 880 static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) { 881 SmallVector<OpAsmParser::OperandType, 8> operands; 882 Type type; 883 SymbolRefAttr funcAttr; 884 llvm::SMLoc trailingTypeLoc; 885 886 // Parse an operand list that will, in practice, contain 0 or 1 operand. In 887 // case of an indirect call, there will be 1 operand before `(`. In case of a 888 // direct call, there will be no operands and the parser will stop at the 889 // function identifier without complaining. 890 if (parser.parseOperandList(operands)) 891 return failure(); 892 bool isDirect = operands.empty(); 893 894 // Optionally parse a function identifier. 895 if (isDirect) 896 if (parser.parseAttribute(funcAttr, "callee", result.attributes)) 897 return failure(); 898 899 if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || 900 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 901 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type)) 902 return failure(); 903 904 auto funcType = type.dyn_cast<FunctionType>(); 905 if (!funcType) 906 return parser.emitError(trailingTypeLoc, "expected function type"); 907 if (isDirect) { 908 // Make sure types match. 909 if (parser.resolveOperands(operands, funcType.getInputs(), 910 parser.getNameLoc(), result.operands)) 911 return failure(); 912 result.addTypes(funcType.getResults()); 913 } else { 914 // Construct the LLVM IR Dialect function type that the first operand 915 // should match. 
916 if (funcType.getNumResults() > 1) 917 return parser.emitError(trailingTypeLoc, 918 "expected function with 0 or 1 result"); 919 920 Builder &builder = parser.getBuilder(); 921 Type llvmResultType; 922 if (funcType.getNumResults() == 0) { 923 llvmResultType = LLVM::LLVMVoidType::get(builder.getContext()); 924 } else { 925 llvmResultType = funcType.getResult(0); 926 if (!isCompatibleType(llvmResultType)) 927 return parser.emitError(trailingTypeLoc, 928 "expected result to have LLVM type"); 929 } 930 931 SmallVector<Type, 8> argTypes; 932 argTypes.reserve(funcType.getNumInputs()); 933 for (int i = 0, e = funcType.getNumInputs(); i < e; ++i) { 934 auto argType = funcType.getInput(i); 935 if (!isCompatibleType(argType)) 936 return parser.emitError(trailingTypeLoc, 937 "expected LLVM types as inputs"); 938 argTypes.push_back(argType); 939 } 940 auto llvmFuncType = LLVM::LLVMFunctionType::get(llvmResultType, argTypes); 941 auto wrappedFuncType = LLVM::LLVMPointerType::get(llvmFuncType); 942 943 auto funcArguments = 944 ArrayRef<OpAsmParser::OperandType>(operands).drop_front(); 945 946 // Make sure that the first operand (indirect callee) matches the wrapped 947 // LLVM IR function type, and that the types of the other call operands 948 // match the types of the function arguments. 949 if (parser.resolveOperand(operands[0], wrappedFuncType, result.operands) || 950 parser.resolveOperands(funcArguments, funcType.getInputs(), 951 parser.getNameLoc(), result.operands)) 952 return failure(); 953 954 result.addTypes(llvmResultType); 955 } 956 957 return success(); 958 } 959 960 //===----------------------------------------------------------------------===// 961 // Printing/parsing for LLVM::ExtractElementOp. 962 //===----------------------------------------------------------------------===// 963 // Expects vector to be of wrapped LLVM vector type and position to be of 964 // wrapped LLVM i32 type. 
965 void LLVM::ExtractElementOp::build(OpBuilder &b, OperationState &result, 966 Value vector, Value position, 967 ArrayRef<NamedAttribute> attrs) { 968 auto vectorType = vector.getType(); 969 auto llvmType = LLVM::getVectorElementType(vectorType); 970 build(b, result, llvmType, vector, position); 971 result.addAttributes(attrs); 972 } 973 974 static void printExtractElementOp(OpAsmPrinter &p, ExtractElementOp &op) { 975 p << op.getOperationName() << ' ' << op.vector() << "[" << op.position() 976 << " : " << op.position().getType() << "]"; 977 p.printOptionalAttrDict(op->getAttrs()); 978 p << " : " << op.vector().getType(); 979 } 980 981 // <operation> ::= `llvm.extractelement` ssa-use `, ` ssa-use 982 // attribute-dict? `:` type 983 static ParseResult parseExtractElementOp(OpAsmParser &parser, 984 OperationState &result) { 985 llvm::SMLoc loc; 986 OpAsmParser::OperandType vector, position; 987 Type type, positionType; 988 if (parser.getCurrentLocation(&loc) || parser.parseOperand(vector) || 989 parser.parseLSquare() || parser.parseOperand(position) || 990 parser.parseColonType(positionType) || parser.parseRSquare() || 991 parser.parseOptionalAttrDict(result.attributes) || 992 parser.parseColonType(type) || 993 parser.resolveOperand(vector, type, result.operands) || 994 parser.resolveOperand(position, positionType, result.operands)) 995 return failure(); 996 if (!LLVM::isCompatibleVectorType(type)) 997 return parser.emitError( 998 loc, "expected LLVM dialect-compatible vector type for operand #1"); 999 result.addTypes(LLVM::getVectorElementType(type)); 1000 return success(); 1001 } 1002 1003 //===----------------------------------------------------------------------===// 1004 // Printing/parsing for LLVM::ExtractValueOp. 
1005 //===----------------------------------------------------------------------===// 1006 1007 static void printExtractValueOp(OpAsmPrinter &p, ExtractValueOp &op) { 1008 p << op.getOperationName() << ' ' << op.container() << op.position(); 1009 p.printOptionalAttrDict(op->getAttrs(), {"position"}); 1010 p << " : " << op.container().getType(); 1011 } 1012 1013 // Extract the type at `position` in the wrapped LLVM IR aggregate type 1014 // `containerType`. Position is an integer array attribute where each value 1015 // is a zero-based position of the element in the aggregate type. Return the 1016 // resulting type wrapped in MLIR, or nullptr on error. 1017 static Type getInsertExtractValueElementType(OpAsmParser &parser, 1018 Type containerType, 1019 ArrayAttr positionAttr, 1020 llvm::SMLoc attributeLoc, 1021 llvm::SMLoc typeLoc) { 1022 Type llvmType = containerType; 1023 if (!isCompatibleType(containerType)) 1024 return parser.emitError(typeLoc, "expected LLVM IR Dialect type"), nullptr; 1025 1026 // Infer the element type from the structure type: iteratively step inside the 1027 // type by taking the element type, indexed by the position attribute for 1028 // structures. Check the position index before accessing, it is supposed to 1029 // be in bounds. 
1030 for (Attribute subAttr : positionAttr) { 1031 auto positionElementAttr = subAttr.dyn_cast<IntegerAttr>(); 1032 if (!positionElementAttr) 1033 return parser.emitError(attributeLoc, 1034 "expected an array of integer literals"), 1035 nullptr; 1036 int position = positionElementAttr.getInt(); 1037 if (auto arrayType = llvmType.dyn_cast<LLVMArrayType>()) { 1038 if (position < 0 || 1039 static_cast<unsigned>(position) >= arrayType.getNumElements()) 1040 return parser.emitError(attributeLoc, "position out of bounds"), 1041 nullptr; 1042 llvmType = arrayType.getElementType(); 1043 } else if (auto structType = llvmType.dyn_cast<LLVMStructType>()) { 1044 if (position < 0 || 1045 static_cast<unsigned>(position) >= structType.getBody().size()) 1046 return parser.emitError(attributeLoc, "position out of bounds"), 1047 nullptr; 1048 llvmType = structType.getBody()[position]; 1049 } else { 1050 return parser.emitError(typeLoc, "expected LLVM IR structure/array type"), 1051 nullptr; 1052 } 1053 } 1054 return llvmType; 1055 } 1056 1057 // <operation> ::= `llvm.extractvalue` ssa-use 1058 // `[` integer-literal (`,` integer-literal)* `]` 1059 // attribute-dict? 
`:` type 1060 static ParseResult parseExtractValueOp(OpAsmParser &parser, 1061 OperationState &result) { 1062 OpAsmParser::OperandType container; 1063 Type containerType; 1064 ArrayAttr positionAttr; 1065 llvm::SMLoc attributeLoc, trailingTypeLoc; 1066 1067 if (parser.parseOperand(container) || 1068 parser.getCurrentLocation(&attributeLoc) || 1069 parser.parseAttribute(positionAttr, "position", result.attributes) || 1070 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 1071 parser.getCurrentLocation(&trailingTypeLoc) || 1072 parser.parseType(containerType) || 1073 parser.resolveOperand(container, containerType, result.operands)) 1074 return failure(); 1075 1076 auto elementType = getInsertExtractValueElementType( 1077 parser, containerType, positionAttr, attributeLoc, trailingTypeLoc); 1078 if (!elementType) 1079 return failure(); 1080 1081 result.addTypes(elementType); 1082 return success(); 1083 } 1084 1085 //===----------------------------------------------------------------------===// 1086 // Printing/parsing for LLVM::InsertElementOp. 1087 //===----------------------------------------------------------------------===// 1088 1089 static void printInsertElementOp(OpAsmPrinter &p, InsertElementOp &op) { 1090 p << op.getOperationName() << ' ' << op.value() << ", " << op.vector() << "[" 1091 << op.position() << " : " << op.position().getType() << "]"; 1092 p.printOptionalAttrDict(op->getAttrs()); 1093 p << " : " << op.vector().getType(); 1094 } 1095 1096 // <operation> ::= `llvm.insertelement` ssa-use `,` ssa-use `,` ssa-use 1097 // attribute-dict? 
`:` type 1098 static ParseResult parseInsertElementOp(OpAsmParser &parser, 1099 OperationState &result) { 1100 llvm::SMLoc loc; 1101 OpAsmParser::OperandType vector, value, position; 1102 Type vectorType, positionType; 1103 if (parser.getCurrentLocation(&loc) || parser.parseOperand(value) || 1104 parser.parseComma() || parser.parseOperand(vector) || 1105 parser.parseLSquare() || parser.parseOperand(position) || 1106 parser.parseColonType(positionType) || parser.parseRSquare() || 1107 parser.parseOptionalAttrDict(result.attributes) || 1108 parser.parseColonType(vectorType)) 1109 return failure(); 1110 1111 if (!LLVM::isCompatibleVectorType(vectorType)) 1112 return parser.emitError( 1113 loc, "expected LLVM dialect-compatible vector type for operand #1"); 1114 Type valueType = LLVM::getVectorElementType(vectorType); 1115 if (!valueType) 1116 return failure(); 1117 1118 if (parser.resolveOperand(vector, vectorType, result.operands) || 1119 parser.resolveOperand(value, valueType, result.operands) || 1120 parser.resolveOperand(position, positionType, result.operands)) 1121 return failure(); 1122 1123 result.addTypes(vectorType); 1124 return success(); 1125 } 1126 1127 //===----------------------------------------------------------------------===// 1128 // Printing/parsing for LLVM::InsertValueOp. 1129 //===----------------------------------------------------------------------===// 1130 1131 static void printInsertValueOp(OpAsmPrinter &p, InsertValueOp &op) { 1132 p << op.getOperationName() << ' ' << op.value() << ", " << op.container() 1133 << op.position(); 1134 p.printOptionalAttrDict(op->getAttrs(), {"position"}); 1135 p << " : " << op.container().getType(); 1136 } 1137 1138 // <operation> ::= `llvm.insertvaluevalue` ssa-use `,` ssa-use 1139 // `[` integer-literal (`,` integer-literal)* `]` 1140 // attribute-dict? 
`:` type 1141 static ParseResult parseInsertValueOp(OpAsmParser &parser, 1142 OperationState &result) { 1143 OpAsmParser::OperandType container, value; 1144 Type containerType; 1145 ArrayAttr positionAttr; 1146 llvm::SMLoc attributeLoc, trailingTypeLoc; 1147 1148 if (parser.parseOperand(value) || parser.parseComma() || 1149 parser.parseOperand(container) || 1150 parser.getCurrentLocation(&attributeLoc) || 1151 parser.parseAttribute(positionAttr, "position", result.attributes) || 1152 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 1153 parser.getCurrentLocation(&trailingTypeLoc) || 1154 parser.parseType(containerType)) 1155 return failure(); 1156 1157 auto valueType = getInsertExtractValueElementType( 1158 parser, containerType, positionAttr, attributeLoc, trailingTypeLoc); 1159 if (!valueType) 1160 return failure(); 1161 1162 if (parser.resolveOperand(container, containerType, result.operands) || 1163 parser.resolveOperand(value, valueType, result.operands)) 1164 return failure(); 1165 1166 result.addTypes(containerType); 1167 return success(); 1168 } 1169 1170 //===----------------------------------------------------------------------===// 1171 // Printing, parsing and verification for LLVM::ReturnOp. 1172 //===----------------------------------------------------------------------===// 1173 1174 static void printReturnOp(OpAsmPrinter &p, ReturnOp op) { 1175 p << op.getOperationName(); 1176 p.printOptionalAttrDict(op->getAttrs()); 1177 assert(op.getNumOperands() <= 1); 1178 1179 if (op.getNumOperands() == 0) 1180 return; 1181 1182 p << ' ' << op.getOperand(0) << " : " << op.getOperand(0).getType(); 1183 } 1184 1185 // <operation> ::= `llvm.return` ssa-use-list attribute-dict? 
`:` 1186 // type-list-no-parens 1187 static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &result) { 1188 SmallVector<OpAsmParser::OperandType, 1> operands; 1189 Type type; 1190 1191 if (parser.parseOperandList(operands) || 1192 parser.parseOptionalAttrDict(result.attributes)) 1193 return failure(); 1194 if (operands.empty()) 1195 return success(); 1196 1197 if (parser.parseColonType(type) || 1198 parser.resolveOperand(operands[0], type, result.operands)) 1199 return failure(); 1200 return success(); 1201 } 1202 1203 static LogicalResult verify(ReturnOp op) { 1204 if (op->getNumOperands() > 1) 1205 return op->emitOpError("expected at most 1 operand"); 1206 1207 if (auto parent = op->getParentOfType<LLVMFuncOp>()) { 1208 Type expectedType = parent.getType().getReturnType(); 1209 if (expectedType.isa<LLVMVoidType>()) { 1210 if (op->getNumOperands() == 0) 1211 return success(); 1212 InFlightDiagnostic diag = op->emitOpError("expected no operands"); 1213 diag.attachNote(parent->getLoc()) << "when returning from function"; 1214 return diag; 1215 } 1216 if (op->getNumOperands() == 0) { 1217 if (expectedType.isa<LLVMVoidType>()) 1218 return success(); 1219 InFlightDiagnostic diag = op->emitOpError("expected 1 operand"); 1220 diag.attachNote(parent->getLoc()) << "when returning from function"; 1221 return diag; 1222 } 1223 if (expectedType != op->getOperand(0).getType()) { 1224 InFlightDiagnostic diag = op->emitOpError("mismatching result types"); 1225 diag.attachNote(parent->getLoc()) << "when returning from function"; 1226 return diag; 1227 } 1228 } 1229 return success(); 1230 } 1231 1232 //===----------------------------------------------------------------------===// 1233 // Verifier for LLVM::AddressOfOp. 
1234 //===----------------------------------------------------------------------===// 1235 1236 template <typename OpTy> 1237 static OpTy lookupSymbolInModule(Operation *parent, StringRef name) { 1238 Operation *module = parent; 1239 while (module && !satisfiesLLVMModule(module)) 1240 module = module->getParentOp(); 1241 assert(module && "unexpected operation outside of a module"); 1242 return dyn_cast_or_null<OpTy>( 1243 mlir::SymbolTable::lookupSymbolIn(module, name)); 1244 } 1245 1246 GlobalOp AddressOfOp::getGlobal() { 1247 return lookupSymbolInModule<LLVM::GlobalOp>((*this)->getParentOp(), 1248 global_name()); 1249 } 1250 1251 LLVMFuncOp AddressOfOp::getFunction() { 1252 return lookupSymbolInModule<LLVM::LLVMFuncOp>((*this)->getParentOp(), 1253 global_name()); 1254 } 1255 1256 static LogicalResult verify(AddressOfOp op) { 1257 auto global = op.getGlobal(); 1258 auto function = op.getFunction(); 1259 if (!global && !function) 1260 return op.emitOpError( 1261 "must reference a global defined by 'llvm.mlir.global' or 'llvm.func'"); 1262 1263 if (global && 1264 LLVM::LLVMPointerType::get(global.getType(), global.addr_space()) != 1265 op.getResult().getType()) 1266 return op.emitOpError( 1267 "the type must be a pointer to the type of the referenced global"); 1268 1269 if (function && LLVM::LLVMPointerType::get(function.getType()) != 1270 op.getResult().getType()) 1271 return op.emitOpError( 1272 "the type must be a pointer to the type of the referenced function"); 1273 1274 return success(); 1275 } 1276 1277 //===----------------------------------------------------------------------===// 1278 // Builder, printer and verifier for LLVM::GlobalOp. 1279 //===----------------------------------------------------------------------===// 1280 1281 /// Returns the name used for the linkage attribute. This *must* correspond to 1282 /// the name of the attribute in ODS. 
1283 static StringRef getLinkageAttrName() { return "linkage"; } 1284 1285 void GlobalOp::build(OpBuilder &builder, OperationState &result, Type type, 1286 bool isConstant, Linkage linkage, StringRef name, 1287 Attribute value, unsigned addrSpace, 1288 ArrayRef<NamedAttribute> attrs) { 1289 result.addAttribute(SymbolTable::getSymbolAttrName(), 1290 builder.getStringAttr(name)); 1291 result.addAttribute("type", TypeAttr::get(type)); 1292 if (isConstant) 1293 result.addAttribute("constant", builder.getUnitAttr()); 1294 if (value) 1295 result.addAttribute("value", value); 1296 result.addAttribute(getLinkageAttrName(), 1297 builder.getI64IntegerAttr(static_cast<int64_t>(linkage))); 1298 if (addrSpace != 0) 1299 result.addAttribute("addr_space", builder.getI32IntegerAttr(addrSpace)); 1300 result.attributes.append(attrs.begin(), attrs.end()); 1301 result.addRegion(); 1302 } 1303 1304 static void printGlobalOp(OpAsmPrinter &p, GlobalOp op) { 1305 p << op.getOperationName() << ' ' << stringifyLinkage(op.linkage()) << ' '; 1306 if (op.constant()) 1307 p << "constant "; 1308 p.printSymbolName(op.sym_name()); 1309 p << '('; 1310 if (auto value = op.getValueOrNull()) 1311 p.printAttribute(value); 1312 p << ')'; 1313 p.printOptionalAttrDict(op->getAttrs(), 1314 {SymbolTable::getSymbolAttrName(), "type", "constant", 1315 "value", getLinkageAttrName()}); 1316 1317 // Print the trailing type unless it's a string global. 1318 if (op.getValueOrNull().dyn_cast_or_null<StringAttr>()) 1319 return; 1320 p << " : " << op.type(); 1321 1322 Region &initializer = op.getInitializerRegion(); 1323 if (!initializer.empty()) 1324 p.printRegion(initializer, /*printEntryBlockArgs=*/false); 1325 } 1326 1327 //===----------------------------------------------------------------------===// 1328 // Verifier for LLVM::DialectCastOp. 1329 //===----------------------------------------------------------------------===// 1330 1331 /// Checks if `llvmType` is dialect cast-compatible with `index` type. 
/// Does not report the error, the user is expected to produce an appropriate
/// message. Currently any `IntegerType` is accepted, regardless of width.
// TODO: make the size depend on data layout rather than on the conversion
// pass option, and pull that information here.
static LogicalResult verifyCastWithIndex(Type llvmType) {
  return success(llvmType.isa<IntegerType>());
}

/// Checks if `llvmType` is dialect cast-compatible with built-in `type` and
/// reports errors to the location of `op`. `isElement` indicates whether the
/// verification is performed for types that are element types inside a
/// container; we don't want casts from X to X at the top level, but c1<X> to
/// c2<X> may be fine.
///
/// The built-in `type` is matched, in order, against: index, integer, vector,
/// ranked memref, unranked memref and complex. Container types recurse into
/// their element types with `isElement` set. Any other built-in type is
/// rejected as unsupported.
static LogicalResult verifyCast(DialectCastOp op, Type llvmType, Type type,
                                bool isElement = false) {
  // Equal element types are directly compatible.
  if (isElement && llvmType == type)
    return success();

  // Index is compatible with any integer.
  if (type.isIndex()) {
    if (succeeded(verifyCastWithIndex(llvmType)))
      return success();

    return op.emitOpError("invalid cast between index and non-integer type");
  }

  // Integers are compatible with integers of the same bit width.
  if (type.isa<IntegerType>()) {
    auto llvmIntegerType = llvmType.dyn_cast<IntegerType>();
    if (!llvmIntegerType)
      return op->emitOpError("invalid cast between integer and non-integer");
    if (llvmIntegerType.getWidth() != type.getIntOrFloatBitWidth())
      return op.emitOpError("invalid cast changing integer width");
    return success();
  }

  // Vectors are compatible if they are 1D non-scalable, and their element types
  // are compatible. nD vectors are compatible with (n-1)D arrays containing 1D
  // vector.
  if (auto vectorType = type.dyn_cast<VectorType>()) {
    if (vectorType == llvmType && !isElement)
      return op.emitOpError("vector types should not be casted");

    if (vectorType.getRank() == 1) {
      auto llvmVectorType = llvmType.dyn_cast<VectorType>();
      if (!llvmVectorType || llvmVectorType.getRank() != 1)
        return op.emitOpError("invalid cast for vector types");

      return verifyCast(op, llvmVectorType.getElementType(),
                        vectorType.getElementType(), /*isElement=*/true);
    }

    // Peel off the outermost vector dimension: it must match the array size,
    // and the remaining (n-1)D vector is checked against the array element.
    auto arrayType = llvmType.dyn_cast<LLVM::LLVMArrayType>();
    if (!arrayType ||
        arrayType.getNumElements() != vectorType.getShape().front())
      return op.emitOpError("invalid cast for vector, expected array");
    return verifyCast(op, arrayType.getElementType(),
                      VectorType::get(vectorType.getShape().drop_front(),
                                      vectorType.getElementType()),
                      /*isElement=*/true);
  }

  if (auto memrefType = type.dyn_cast<MemRefType>()) {
    // Bare pointer convention: statically-shaped memref is compatible with an
    // LLVM pointer to the element type.
    if (auto ptrType = llvmType.dyn_cast<LLVMPointerType>()) {
      if (!memrefType.hasStaticShape())
        return op->emitOpError(
            "unexpected bare pointer for dynamically shaped memref");
      if (memrefType.getMemorySpaceAsInt() != ptrType.getAddressSpace())
        return op->emitError("invalid conversion between memref and pointer in "
                             "different memory spaces");

      return verifyCast(op, ptrType.getElementType(),
                        memrefType.getElementType(), /*isElement=*/true);
    }

    // Otherwise, memrefs are convertible to a descriptor, which is a structure
    // type.
    auto structType = llvmType.dyn_cast<LLVMStructType>();
    if (!structType)
      return op->emitOpError("invalid cast between a memref and a type other "
                             "than pointer or memref descriptor");

    // A descriptor holds two pointers and an offset; ranked (non-0D) memrefs
    // additionally carry the sizes and strides arrays.
    unsigned expectedNumElements = memrefType.getRank() == 0 ? 3 : 5;
    if (structType.getBody().size() != expectedNumElements) {
      return op->emitOpError() << "expected memref descriptor with "
                               << expectedNumElements << " elements";
    }

    // The first two elements are pointers to the element type.
    auto allocatedPtr = structType.getBody()[0].dyn_cast<LLVMPointerType>();
    if (!allocatedPtr ||
        allocatedPtr.getAddressSpace() != memrefType.getMemorySpaceAsInt())
      return op->emitOpError("expected first element of a memref descriptor to "
                             "be a pointer in the address space of the memref");
    if (failed(verifyCast(op, allocatedPtr.getElementType(),
                          memrefType.getElementType(), /*isElement=*/true)))
      return failure();

    auto alignedPtr = structType.getBody()[1].dyn_cast<LLVMPointerType>();
    if (!alignedPtr ||
        alignedPtr.getAddressSpace() != memrefType.getMemorySpaceAsInt())
      return op->emitOpError(
          "expected second element of a memref descriptor to "
          "be a pointer in the address space of the memref");
    if (failed(verifyCast(op, alignedPtr.getElementType(),
                          memrefType.getElementType(), /*isElement=*/true)))
      return failure();

    // The third element (offset) is an equivalent of index.
    if (failed(verifyCastWithIndex(structType.getBody()[2])))
      return op->emitOpError("expected third element of a memref descriptor to "
                             "be index-compatible integers");

    // 0D memrefs don't have sizes/strides.
    if (memrefType.getRank() == 0)
      return success();

    // Sizes and strides are rank-sized arrays of `index` equivalents.
    auto sizes = structType.getBody()[3].dyn_cast<LLVMArrayType>();
    if (!sizes || failed(verifyCastWithIndex(sizes.getElementType())) ||
        sizes.getNumElements() != memrefType.getRank())
      return op->emitOpError(
          "expected fourth element of a memref descriptor "
          "to be an array of <rank> index-compatible integers");

    auto strides = structType.getBody()[4].dyn_cast<LLVMArrayType>();
    if (!strides || failed(verifyCastWithIndex(strides.getElementType())) ||
        strides.getNumElements() != memrefType.getRank())
      return op->emitOpError(
          "expected fifth element of a memref descriptor "
          "to be an array of <rank> index-compatible integers");

    return success();
  }

  // Unranked memrefs are compatible with their descriptors.
  if (auto unrankedMemrefType = type.dyn_cast<UnrankedMemRefType>()) {
    auto structType = llvmType.dyn_cast<LLVMStructType>();
    if (!structType || structType.getBody().size() != 2)
      return op->emitOpError(
          "expected descriptor to be a struct with two elements");

    if (failed(verifyCastWithIndex(structType.getBody()[0])))
      return op->emitOpError("expected first element of a memref descriptor to "
                             "be an index-compatible integer");

    // The second element must be a pointer to an 8-bit integer.
    auto ptrType = structType.getBody()[1].dyn_cast<LLVMPointerType>();
    auto ptrElementType =
        ptrType ? ptrType.getElementType().dyn_cast<IntegerType>() : nullptr;
    if (!ptrElementType || ptrElementType.getWidth() != 8)
      return op->emitOpError("expected second element of a memref descriptor "
                             "to be an !llvm.ptr<i8>");

    return success();
  }

  // Complex types are compatible with the two-element structs.
  if (auto complexType = type.dyn_cast<ComplexType>()) {
    auto structType = llvmType.dyn_cast<LLVMStructType>();
    if (!structType || structType.getBody().size() != 2 ||
        structType.getBody()[0] != structType.getBody()[1] ||
        structType.getBody()[0] != complexType.getElementType())
      return op->emitOpError("expected 'complex' to map to two-element struct "
                             "with identical element types");
    return success();
  }

  // Everything else is not supported.
  return op->emitError("unsupported cast");
}

// Verifies a dialect cast in either direction. If the result type is
// LLVM-compatible it is treated as the LLVM side and the input as the built-in
// side; otherwise the input must be LLVM-compatible and the roles are swapped.
static LogicalResult verify(DialectCastOp op) {
  if (isCompatibleType(op.getType()))
    return verifyCast(op, op.getType(), op.in().getType());

  if (!isCompatibleType(op.in().getType()))
    return op->emitOpError("expected one LLVM type and one built-in type");

  return verifyCast(op, op.in().getType(), op.getType());
}

// Parses one of the keywords provided in the list `keywords` and returns the
// position of the parsed keyword in the list. If none of the keywords from the
// list is parsed, returns -1.
1517 static int parseOptionalKeywordAlternative(OpAsmParser &parser, 1518 ArrayRef<StringRef> keywords) { 1519 for (auto en : llvm::enumerate(keywords)) { 1520 if (succeeded(parser.parseOptionalKeyword(en.value()))) 1521 return en.index(); 1522 } 1523 return -1; 1524 } 1525 1526 namespace { 1527 template <typename Ty> struct EnumTraits {}; 1528 1529 #define REGISTER_ENUM_TYPE(Ty) \ 1530 template <> struct EnumTraits<Ty> { \ 1531 static StringRef stringify(Ty value) { return stringify##Ty(value); } \ 1532 static unsigned getMaxEnumVal() { return getMaxEnumValFor##Ty(); } \ 1533 } 1534 1535 REGISTER_ENUM_TYPE(Linkage); 1536 } // end namespace 1537 1538 template <typename EnumTy> 1539 static ParseResult parseOptionalLLVMKeyword(OpAsmParser &parser, 1540 OperationState &result, 1541 StringRef name) { 1542 SmallVector<StringRef, 10> names; 1543 for (unsigned i = 0, e = getMaxEnumValForLinkage(); i <= e; ++i) 1544 names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i))); 1545 1546 int index = parseOptionalKeywordAlternative(parser, names); 1547 if (index == -1) 1548 return failure(); 1549 result.addAttribute(name, parser.getBuilder().getI64IntegerAttr(index)); 1550 return success(); 1551 } 1552 1553 // operation ::= `llvm.mlir.global` linkage? `constant`? `@` identifier 1554 // `(` attribute? `)` attribute-list? (`:` type)? region? 1555 // 1556 // The type can be omitted for string attributes, in which case it will be 1557 // inferred from the value of the string as [strlen(value) x i8]. 
1558 static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) { 1559 if (failed(parseOptionalLLVMKeyword<Linkage>(parser, result, 1560 getLinkageAttrName()))) 1561 result.addAttribute(getLinkageAttrName(), 1562 parser.getBuilder().getI64IntegerAttr( 1563 static_cast<int64_t>(LLVM::Linkage::External))); 1564 1565 if (succeeded(parser.parseOptionalKeyword("constant"))) 1566 result.addAttribute("constant", parser.getBuilder().getUnitAttr()); 1567 1568 StringAttr name; 1569 if (parser.parseSymbolName(name, SymbolTable::getSymbolAttrName(), 1570 result.attributes) || 1571 parser.parseLParen()) 1572 return failure(); 1573 1574 Attribute value; 1575 if (parser.parseOptionalRParen()) { 1576 if (parser.parseAttribute(value, "value", result.attributes) || 1577 parser.parseRParen()) 1578 return failure(); 1579 } 1580 1581 SmallVector<Type, 1> types; 1582 if (parser.parseOptionalAttrDict(result.attributes) || 1583 parser.parseOptionalColonTypeList(types)) 1584 return failure(); 1585 1586 if (types.size() > 1) 1587 return parser.emitError(parser.getNameLoc(), "expected zero or one type"); 1588 1589 Region &initRegion = *result.addRegion(); 1590 if (types.empty()) { 1591 if (auto strAttr = value.dyn_cast_or_null<StringAttr>()) { 1592 MLIRContext *context = parser.getBuilder().getContext(); 1593 auto arrayType = LLVM::LLVMArrayType::get(IntegerType::get(context, 8), 1594 strAttr.getValue().size()); 1595 types.push_back(arrayType); 1596 } else { 1597 return parser.emitError(parser.getNameLoc(), 1598 "type can only be omitted for string globals"); 1599 } 1600 } else { 1601 OptionalParseResult parseResult = 1602 parser.parseOptionalRegion(initRegion, /*arguments=*/{}, 1603 /*argTypes=*/{}); 1604 if (parseResult.hasValue() && failed(*parseResult)) 1605 return failure(); 1606 } 1607 1608 result.addAttribute("type", TypeAttr::get(types[0])); 1609 return success(); 1610 } 1611 1612 static bool isZeroAttribute(Attribute value) { 1613 if (auto intValue = 
value.dyn_cast<IntegerAttr>()) 1614 return intValue.getValue().isNullValue(); 1615 if (auto fpValue = value.dyn_cast<FloatAttr>()) 1616 return fpValue.getValue().isZero(); 1617 if (auto splatValue = value.dyn_cast<SplatElementsAttr>()) 1618 return isZeroAttribute(splatValue.getSplatValue()); 1619 if (auto elementsValue = value.dyn_cast<ElementsAttr>()) 1620 return llvm::all_of(elementsValue.getValues<Attribute>(), isZeroAttribute); 1621 if (auto arrayValue = value.dyn_cast<ArrayAttr>()) 1622 return llvm::all_of(arrayValue.getValue(), isZeroAttribute); 1623 return false; 1624 } 1625 1626 static LogicalResult verify(GlobalOp op) { 1627 if (!LLVMPointerType::isValidElementType(op.getType())) 1628 return op.emitOpError( 1629 "expects type to be a valid element type for an LLVM pointer"); 1630 if (op->getParentOp() && !satisfiesLLVMModule(op->getParentOp())) 1631 return op.emitOpError("must appear at the module level"); 1632 1633 if (auto strAttr = op.getValueOrNull().dyn_cast_or_null<StringAttr>()) { 1634 auto type = op.getType().dyn_cast<LLVMArrayType>(); 1635 IntegerType elementType = 1636 type ? 
type.getElementType().dyn_cast<IntegerType>() : nullptr; 1637 if (!elementType || elementType.getWidth() != 8 || 1638 type.getNumElements() != strAttr.getValue().size()) 1639 return op.emitOpError( 1640 "requires an i8 array type of the length equal to that of the string " 1641 "attribute"); 1642 } 1643 1644 if (Block *b = op.getInitializerBlock()) { 1645 ReturnOp ret = cast<ReturnOp>(b->getTerminator()); 1646 if (ret.operand_type_begin() == ret.operand_type_end()) 1647 return op.emitOpError("initializer region cannot return void"); 1648 if (*ret.operand_type_begin() != op.getType()) 1649 return op.emitOpError("initializer region type ") 1650 << *ret.operand_type_begin() << " does not match global type " 1651 << op.getType(); 1652 1653 if (op.getValueOrNull()) 1654 return op.emitOpError("cannot have both initializer value and region"); 1655 } 1656 1657 if (op.linkage() == Linkage::Common) { 1658 if (Attribute value = op.getValueOrNull()) { 1659 if (!isZeroAttribute(value)) { 1660 return op.emitOpError() 1661 << "expected zero value for '" 1662 << stringifyLinkage(Linkage::Common) << "' linkage"; 1663 } 1664 } 1665 } 1666 1667 if (op.linkage() == Linkage::Appending) { 1668 if (!op.getType().isa<LLVMArrayType>()) { 1669 return op.emitOpError() 1670 << "expected array type for '" 1671 << stringifyLinkage(Linkage::Appending) << "' linkage"; 1672 } 1673 } 1674 1675 return success(); 1676 } 1677 1678 //===----------------------------------------------------------------------===// 1679 // Printing/parsing for LLVM::ShuffleVectorOp. 1680 //===----------------------------------------------------------------------===// 1681 // Expects vector to be of wrapped LLVM vector type and position to be of 1682 // wrapped LLVM i32 type. 
1683 void LLVM::ShuffleVectorOp::build(OpBuilder &b, OperationState &result, 1684 Value v1, Value v2, ArrayAttr mask, 1685 ArrayRef<NamedAttribute> attrs) { 1686 auto containerType = v1.getType(); 1687 auto vType = LLVM::getFixedVectorType( 1688 LLVM::getVectorElementType(containerType), mask.size()); 1689 build(b, result, vType, v1, v2, mask); 1690 result.addAttributes(attrs); 1691 } 1692 1693 static void printShuffleVectorOp(OpAsmPrinter &p, ShuffleVectorOp &op) { 1694 p << op.getOperationName() << ' ' << op.v1() << ", " << op.v2() << " " 1695 << op.mask(); 1696 p.printOptionalAttrDict(op->getAttrs(), {"mask"}); 1697 p << " : " << op.v1().getType() << ", " << op.v2().getType(); 1698 } 1699 1700 // <operation> ::= `llvm.shufflevector` ssa-use `, ` ssa-use 1701 // `[` integer-literal (`,` integer-literal)* `]` 1702 // attribute-dict? `:` type 1703 static ParseResult parseShuffleVectorOp(OpAsmParser &parser, 1704 OperationState &result) { 1705 llvm::SMLoc loc; 1706 OpAsmParser::OperandType v1, v2; 1707 ArrayAttr maskAttr; 1708 Type typeV1, typeV2; 1709 if (parser.getCurrentLocation(&loc) || parser.parseOperand(v1) || 1710 parser.parseComma() || parser.parseOperand(v2) || 1711 parser.parseAttribute(maskAttr, "mask", result.attributes) || 1712 parser.parseOptionalAttrDict(result.attributes) || 1713 parser.parseColonType(typeV1) || parser.parseComma() || 1714 parser.parseType(typeV2) || 1715 parser.resolveOperand(v1, typeV1, result.operands) || 1716 parser.resolveOperand(v2, typeV2, result.operands)) 1717 return failure(); 1718 if (!LLVM::isCompatibleVectorType(typeV1)) 1719 return parser.emitError( 1720 loc, "expected LLVM IR dialect vector type for operand #1"); 1721 auto vType = LLVM::getFixedVectorType(LLVM::getVectorElementType(typeV1), 1722 maskAttr.size()); 1723 result.addTypes(vType); 1724 return success(); 1725 } 1726 1727 //===----------------------------------------------------------------------===// 1728 // Implementations for LLVM::LLVMFuncOp. 
1729 //===----------------------------------------------------------------------===// 1730 1731 // Add the entry block to the function. 1732 Block *LLVMFuncOp::addEntryBlock() { 1733 assert(empty() && "function already has an entry block"); 1734 assert(!isVarArg() && "unimplemented: non-external variadic functions"); 1735 1736 auto *entry = new Block; 1737 push_back(entry); 1738 1739 LLVMFunctionType type = getType(); 1740 for (unsigned i = 0, e = type.getNumParams(); i < e; ++i) 1741 entry->addArgument(type.getParamType(i)); 1742 return entry; 1743 } 1744 1745 void LLVMFuncOp::build(OpBuilder &builder, OperationState &result, 1746 StringRef name, Type type, LLVM::Linkage linkage, 1747 ArrayRef<NamedAttribute> attrs, 1748 ArrayRef<DictionaryAttr> argAttrs) { 1749 result.addRegion(); 1750 result.addAttribute(SymbolTable::getSymbolAttrName(), 1751 builder.getStringAttr(name)); 1752 result.addAttribute("type", TypeAttr::get(type)); 1753 result.addAttribute(getLinkageAttrName(), 1754 builder.getI64IntegerAttr(static_cast<int64_t>(linkage))); 1755 result.attributes.append(attrs.begin(), attrs.end()); 1756 if (argAttrs.empty()) 1757 return; 1758 1759 unsigned numInputs = type.cast<LLVMFunctionType>().getNumParams(); 1760 assert(numInputs == argAttrs.size() && 1761 "expected as many argument attribute lists as arguments"); 1762 SmallString<8> argAttrName; 1763 for (unsigned i = 0; i < numInputs; ++i) 1764 if (DictionaryAttr argDict = argAttrs[i]) 1765 result.addAttribute(getArgAttrName(i, argAttrName), argDict); 1766 } 1767 1768 // Builds an LLVM function type from the given lists of input and output types. 1769 // Returns a null type if any of the types provided are non-LLVM types, or if 1770 // there is more than one output type. 
1771 static Type buildLLVMFunctionType(OpAsmParser &parser, llvm::SMLoc loc, 1772 ArrayRef<Type> inputs, ArrayRef<Type> outputs, 1773 impl::VariadicFlag variadicFlag) { 1774 Builder &b = parser.getBuilder(); 1775 if (outputs.size() > 1) { 1776 parser.emitError(loc, "failed to construct function type: expected zero or " 1777 "one function result"); 1778 return {}; 1779 } 1780 1781 // Convert inputs to LLVM types, exit early on error. 1782 SmallVector<Type, 4> llvmInputs; 1783 for (auto t : inputs) { 1784 if (!isCompatibleType(t)) { 1785 parser.emitError(loc, "failed to construct function type: expected LLVM " 1786 "type for function arguments"); 1787 return {}; 1788 } 1789 llvmInputs.push_back(t); 1790 } 1791 1792 // No output is denoted as "void" in LLVM type system. 1793 Type llvmOutput = 1794 outputs.empty() ? LLVMVoidType::get(b.getContext()) : outputs.front(); 1795 if (!isCompatibleType(llvmOutput)) { 1796 parser.emitError(loc, "failed to construct function type: expected LLVM " 1797 "type for function results") 1798 << llvmOutput; 1799 return {}; 1800 } 1801 return LLVMFunctionType::get(llvmOutput, llvmInputs, 1802 variadicFlag.isVariadic()); 1803 } 1804 1805 // Parses an LLVM function. 1806 // 1807 // operation ::= `llvm.func` linkage? function-signature function-attributes? 1808 // function-body 1809 // 1810 static ParseResult parseLLVMFuncOp(OpAsmParser &parser, 1811 OperationState &result) { 1812 // Default to external linkage if no keyword is provided. 
1813 if (failed(parseOptionalLLVMKeyword<Linkage>(parser, result, 1814 getLinkageAttrName()))) 1815 result.addAttribute(getLinkageAttrName(), 1816 parser.getBuilder().getI64IntegerAttr( 1817 static_cast<int64_t>(LLVM::Linkage::External))); 1818 1819 StringAttr nameAttr; 1820 SmallVector<OpAsmParser::OperandType, 8> entryArgs; 1821 SmallVector<NamedAttrList, 1> argAttrs; 1822 SmallVector<NamedAttrList, 1> resultAttrs; 1823 SmallVector<Type, 8> argTypes; 1824 SmallVector<Type, 4> resultTypes; 1825 bool isVariadic; 1826 1827 auto signatureLocation = parser.getCurrentLocation(); 1828 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), 1829 result.attributes) || 1830 impl::parseFunctionSignature(parser, /*allowVariadic=*/true, entryArgs, 1831 argTypes, argAttrs, isVariadic, resultTypes, 1832 resultAttrs)) 1833 return failure(); 1834 1835 auto type = 1836 buildLLVMFunctionType(parser, signatureLocation, argTypes, resultTypes, 1837 impl::VariadicFlag(isVariadic)); 1838 if (!type) 1839 return failure(); 1840 result.addAttribute(impl::getTypeAttrName(), TypeAttr::get(type)); 1841 1842 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) 1843 return failure(); 1844 impl::addArgAndResultAttrs(parser.getBuilder(), result, argAttrs, 1845 resultAttrs); 1846 1847 auto *body = result.addRegion(); 1848 OptionalParseResult parseResult = parser.parseOptionalRegion( 1849 *body, entryArgs, entryArgs.empty() ? ArrayRef<Type>() : argTypes); 1850 return failure(parseResult.hasValue() && failed(*parseResult)); 1851 } 1852 1853 // Print the LLVMFuncOp. Collects argument and result types and passes them to 1854 // helper functions. Drops "void" result since it cannot be parsed back. Skips 1855 // the external linkage since it is the default value. 
1856 static void printLLVMFuncOp(OpAsmPrinter &p, LLVMFuncOp op) { 1857 p << op.getOperationName() << ' '; 1858 if (op.linkage() != LLVM::Linkage::External) 1859 p << stringifyLinkage(op.linkage()) << ' '; 1860 p.printSymbolName(op.getName()); 1861 1862 LLVMFunctionType fnType = op.getType(); 1863 SmallVector<Type, 8> argTypes; 1864 SmallVector<Type, 1> resTypes; 1865 argTypes.reserve(fnType.getNumParams()); 1866 for (unsigned i = 0, e = fnType.getNumParams(); i < e; ++i) 1867 argTypes.push_back(fnType.getParamType(i)); 1868 1869 Type returnType = fnType.getReturnType(); 1870 if (!returnType.isa<LLVMVoidType>()) 1871 resTypes.push_back(returnType); 1872 1873 impl::printFunctionSignature(p, op, argTypes, op.isVarArg(), resTypes); 1874 impl::printFunctionAttributes(p, op, argTypes.size(), resTypes.size(), 1875 {getLinkageAttrName()}); 1876 1877 // Print the body if this is not an external function. 1878 Region &body = op.body(); 1879 if (!body.empty()) 1880 p.printRegion(body, /*printEntryBlockArgs=*/false, 1881 /*printBlockTerminators=*/true); 1882 } 1883 1884 // Hook for OpTrait::FunctionLike, called after verifying that the 'type' 1885 // attribute is present. This can check for preconditions of the 1886 // getNumArguments hook not failing. 1887 LogicalResult LLVMFuncOp::verifyType() { 1888 auto llvmType = getTypeAttr().getValue().dyn_cast_or_null<LLVMFunctionType>(); 1889 if (!llvmType) 1890 return emitOpError("requires '" + getTypeAttrName() + 1891 "' attribute of wrapped LLVM function type"); 1892 1893 return success(); 1894 } 1895 1896 // Hook for OpTrait::FunctionLike, returns the number of function arguments. 1897 // Depends on the type attribute being correct as checked by verifyType 1898 unsigned LLVMFuncOp::getNumFuncArguments() { return getType().getNumParams(); } 1899 1900 // Hook for OpTrait::FunctionLike, returns the number of function results. 
1901 // Depends on the type attribute being correct as checked by verifyType 1902 unsigned LLVMFuncOp::getNumFuncResults() { 1903 // We model LLVM functions that return void as having zero results, 1904 // and all others as having one result. 1905 // If we modeled a void return as one result, then it would be possible to 1906 // attach an MLIR result attribute to it, and it isn't clear what semantics we 1907 // would assign to that. 1908 if (getType().getReturnType().isa<LLVMVoidType>()) 1909 return 0; 1910 return 1; 1911 } 1912 1913 // Verifies LLVM- and implementation-specific properties of the LLVM func Op: 1914 // - functions don't have 'common' linkage 1915 // - external functions have 'external' or 'extern_weak' linkage; 1916 // - vararg is (currently) only supported for external functions; 1917 // - entry block arguments are of LLVM types and match the function signature. 1918 static LogicalResult verify(LLVMFuncOp op) { 1919 if (op.linkage() == LLVM::Linkage::Common) 1920 return op.emitOpError() 1921 << "functions cannot have '" 1922 << stringifyLinkage(LLVM::Linkage::Common) << "' linkage"; 1923 1924 if (op.isExternal()) { 1925 if (op.linkage() != LLVM::Linkage::External && 1926 op.linkage() != LLVM::Linkage::ExternWeak) 1927 return op.emitOpError() 1928 << "external functions must have '" 1929 << stringifyLinkage(LLVM::Linkage::External) << "' or '" 1930 << stringifyLinkage(LLVM::Linkage::ExternWeak) << "' linkage"; 1931 return success(); 1932 } 1933 1934 if (op.isVarArg()) 1935 return op.emitOpError("only external functions can be variadic"); 1936 1937 unsigned numArguments = op.getType().getNumParams(); 1938 Block &entryBlock = op.front(); 1939 for (unsigned i = 0; i < numArguments; ++i) { 1940 Type argType = entryBlock.getArgument(i).getType(); 1941 if (!isCompatibleType(argType)) 1942 return op.emitOpError("entry block argument #") 1943 << i << " is not of LLVM type"; 1944 if (op.getType().getParamType(i) != argType) 1945 return op.emitOpError("the 
type of entry block argument #") 1946 << i << " does not match the function signature"; 1947 } 1948 1949 return success(); 1950 } 1951 1952 //===----------------------------------------------------------------------===// 1953 // Verification for LLVM::ConstantOp. 1954 //===----------------------------------------------------------------------===// 1955 1956 static LogicalResult verify(LLVM::ConstantOp op) { 1957 if (StringAttr sAttr = op.value().dyn_cast<StringAttr>()) { 1958 auto arrayType = op.getType().dyn_cast<LLVMArrayType>(); 1959 if (!arrayType || arrayType.getNumElements() != sAttr.getValue().size() || 1960 !arrayType.getElementType().isInteger(8)) { 1961 return op->emitOpError() 1962 << "expected array type of " << sAttr.getValue().size() 1963 << " i8 elements for the string constant"; 1964 } 1965 return success(); 1966 } 1967 if (!op.value().isa<IntegerAttr, FloatAttr, ElementsAttr>()) 1968 return op.emitOpError() 1969 << "only supports integer, float, string or elements attributes"; 1970 return success(); 1971 } 1972 1973 //===----------------------------------------------------------------------===// 1974 // Utility functions for parsing atomic ops 1975 //===----------------------------------------------------------------------===// 1976 1977 // Helper function to parse a keyword into the specified attribute named by 1978 // `attrName`. The keyword must match one of the string values defined by the 1979 // AtomicBinOp enum. The resulting I64 attribute is added to the `result` 1980 // state. 1981 static ParseResult parseAtomicBinOp(OpAsmParser &parser, OperationState &result, 1982 StringRef attrName) { 1983 llvm::SMLoc loc; 1984 StringRef keyword; 1985 if (parser.getCurrentLocation(&loc) || parser.parseKeyword(&keyword)) 1986 return failure(); 1987 1988 // Replace the keyword `keyword` with an integer attribute. 
1989 auto kind = symbolizeAtomicBinOp(keyword); 1990 if (!kind) { 1991 return parser.emitError(loc) 1992 << "'" << keyword << "' is an incorrect value of the '" << attrName 1993 << "' attribute"; 1994 } 1995 1996 auto value = static_cast<int64_t>(kind.getValue()); 1997 auto attr = parser.getBuilder().getI64IntegerAttr(value); 1998 result.addAttribute(attrName, attr); 1999 2000 return success(); 2001 } 2002 2003 // Helper function to parse a keyword into the specified attribute named by 2004 // `attrName`. The keyword must match one of the string values defined by the 2005 // AtomicOrdering enum. The resulting I64 attribute is added to the `result` 2006 // state. 2007 static ParseResult parseAtomicOrdering(OpAsmParser &parser, 2008 OperationState &result, 2009 StringRef attrName) { 2010 llvm::SMLoc loc; 2011 StringRef ordering; 2012 if (parser.getCurrentLocation(&loc) || parser.parseKeyword(&ordering)) 2013 return failure(); 2014 2015 // Replace the keyword `ordering` with an integer attribute. 2016 auto kind = symbolizeAtomicOrdering(ordering); 2017 if (!kind) { 2018 return parser.emitError(loc) 2019 << "'" << ordering << "' is an incorrect value of the '" << attrName 2020 << "' attribute"; 2021 } 2022 2023 auto value = static_cast<int64_t>(kind.getValue()); 2024 auto attr = parser.getBuilder().getI64IntegerAttr(value); 2025 result.addAttribute(attrName, attr); 2026 2027 return success(); 2028 } 2029 2030 //===----------------------------------------------------------------------===// 2031 // Printer, parser and verifier for LLVM::AtomicRMWOp. 
2032 //===----------------------------------------------------------------------===// 2033 2034 static void printAtomicRMWOp(OpAsmPrinter &p, AtomicRMWOp &op) { 2035 p << op.getOperationName() << ' ' << stringifyAtomicBinOp(op.bin_op()) << ' ' 2036 << op.ptr() << ", " << op.val() << ' ' 2037 << stringifyAtomicOrdering(op.ordering()) << ' '; 2038 p.printOptionalAttrDict(op->getAttrs(), {"bin_op", "ordering"}); 2039 p << " : " << op.res().getType(); 2040 } 2041 2042 // <operation> ::= `llvm.atomicrmw` keyword ssa-use `,` ssa-use keyword 2043 // attribute-dict? `:` type 2044 static ParseResult parseAtomicRMWOp(OpAsmParser &parser, 2045 OperationState &result) { 2046 Type type; 2047 OpAsmParser::OperandType ptr, val; 2048 if (parseAtomicBinOp(parser, result, "bin_op") || parser.parseOperand(ptr) || 2049 parser.parseComma() || parser.parseOperand(val) || 2050 parseAtomicOrdering(parser, result, "ordering") || 2051 parser.parseOptionalAttrDict(result.attributes) || 2052 parser.parseColonType(type) || 2053 parser.resolveOperand(ptr, LLVM::LLVMPointerType::get(type), 2054 result.operands) || 2055 parser.resolveOperand(val, type, result.operands)) 2056 return failure(); 2057 2058 result.addTypes(type); 2059 return success(); 2060 } 2061 2062 static LogicalResult verify(AtomicRMWOp op) { 2063 auto ptrType = op.ptr().getType().cast<LLVM::LLVMPointerType>(); 2064 auto valType = op.val().getType(); 2065 if (valType != ptrType.getElementType()) 2066 return op.emitOpError("expected LLVM IR element type for operand #0 to " 2067 "match type for operand #1"); 2068 auto resType = op.res().getType(); 2069 if (resType != valType) 2070 return op.emitOpError( 2071 "expected LLVM IR result type to match type for operand #1"); 2072 if (op.bin_op() == AtomicBinOp::fadd || op.bin_op() == AtomicBinOp::fsub) { 2073 if (!mlir::LLVM::isCompatibleFloatingPointType(valType)) 2074 return op.emitOpError("expected LLVM IR floating point type"); 2075 } else if (op.bin_op() == AtomicBinOp::xchg) { 2076 
auto intType = valType.dyn_cast<IntegerType>(); 2077 unsigned intBitWidth = intType ? intType.getWidth() : 0; 2078 if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 && 2079 intBitWidth != 64 && !valType.isa<BFloat16Type>() && 2080 !valType.isa<Float16Type>() && !valType.isa<Float32Type>() && 2081 !valType.isa<Float64Type>()) 2082 return op.emitOpError("unexpected LLVM IR type for 'xchg' bin_op"); 2083 } else { 2084 auto intType = valType.dyn_cast<IntegerType>(); 2085 unsigned intBitWidth = intType ? intType.getWidth() : 0; 2086 if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 && 2087 intBitWidth != 64) 2088 return op.emitOpError("expected LLVM IR integer type"); 2089 } 2090 2091 if (static_cast<unsigned>(op.ordering()) < 2092 static_cast<unsigned>(AtomicOrdering::monotonic)) 2093 return op.emitOpError() 2094 << "expected at least '" 2095 << stringifyAtomicOrdering(AtomicOrdering::monotonic) 2096 << "' ordering"; 2097 2098 return success(); 2099 } 2100 2101 //===----------------------------------------------------------------------===// 2102 // Printer, parser and verifier for LLVM::AtomicCmpXchgOp. 2103 //===----------------------------------------------------------------------===// 2104 2105 static void printAtomicCmpXchgOp(OpAsmPrinter &p, AtomicCmpXchgOp &op) { 2106 p << op.getOperationName() << ' ' << op.ptr() << ", " << op.cmp() << ", " 2107 << op.val() << ' ' << stringifyAtomicOrdering(op.success_ordering()) << ' ' 2108 << stringifyAtomicOrdering(op.failure_ordering()); 2109 p.printOptionalAttrDict(op->getAttrs(), 2110 {"success_ordering", "failure_ordering"}); 2111 p << " : " << op.val().getType(); 2112 } 2113 2114 // <operation> ::= `llvm.cmpxchg` ssa-use `,` ssa-use `,` ssa-use 2115 // keyword keyword attribute-dict? 
`:` type 2116 static ParseResult parseAtomicCmpXchgOp(OpAsmParser &parser, 2117 OperationState &result) { 2118 auto &builder = parser.getBuilder(); 2119 Type type; 2120 OpAsmParser::OperandType ptr, cmp, val; 2121 if (parser.parseOperand(ptr) || parser.parseComma() || 2122 parser.parseOperand(cmp) || parser.parseComma() || 2123 parser.parseOperand(val) || 2124 parseAtomicOrdering(parser, result, "success_ordering") || 2125 parseAtomicOrdering(parser, result, "failure_ordering") || 2126 parser.parseOptionalAttrDict(result.attributes) || 2127 parser.parseColonType(type) || 2128 parser.resolveOperand(ptr, LLVM::LLVMPointerType::get(type), 2129 result.operands) || 2130 parser.resolveOperand(cmp, type, result.operands) || 2131 parser.resolveOperand(val, type, result.operands)) 2132 return failure(); 2133 2134 auto boolType = IntegerType::get(builder.getContext(), 1); 2135 auto resultType = 2136 LLVMStructType::getLiteral(builder.getContext(), {type, boolType}); 2137 result.addTypes(resultType); 2138 2139 return success(); 2140 } 2141 2142 static LogicalResult verify(AtomicCmpXchgOp op) { 2143 auto ptrType = op.ptr().getType().cast<LLVM::LLVMPointerType>(); 2144 if (!ptrType) 2145 return op.emitOpError("expected LLVM IR pointer type for operand #0"); 2146 auto cmpType = op.cmp().getType(); 2147 auto valType = op.val().getType(); 2148 if (cmpType != ptrType.getElementType() || cmpType != valType) 2149 return op.emitOpError("expected LLVM IR element type for operand #0 to " 2150 "match type for all other operands"); 2151 auto intType = valType.dyn_cast<IntegerType>(); 2152 unsigned intBitWidth = intType ? 
intType.getWidth() : 0; 2153 if (!valType.isa<LLVMPointerType>() && intBitWidth != 8 && 2154 intBitWidth != 16 && intBitWidth != 32 && intBitWidth != 64 && 2155 !valType.isa<BFloat16Type>() && !valType.isa<Float16Type>() && 2156 !valType.isa<Float32Type>() && !valType.isa<Float64Type>()) 2157 return op.emitOpError("unexpected LLVM IR type"); 2158 if (op.success_ordering() < AtomicOrdering::monotonic || 2159 op.failure_ordering() < AtomicOrdering::monotonic) 2160 return op.emitOpError("ordering must be at least 'monotonic'"); 2161 if (op.failure_ordering() == AtomicOrdering::release || 2162 op.failure_ordering() == AtomicOrdering::acq_rel) 2163 return op.emitOpError("failure ordering cannot be 'release' or 'acq_rel'"); 2164 return success(); 2165 } 2166 2167 //===----------------------------------------------------------------------===// 2168 // Printer, parser and verifier for LLVM::FenceOp. 2169 //===----------------------------------------------------------------------===// 2170 2171 // <operation> ::= `llvm.fence` (`syncscope(`strAttr`)`)? keyword 2172 // attribute-dict? 
2173 static ParseResult parseFenceOp(OpAsmParser &parser, OperationState &result) { 2174 StringAttr sScope; 2175 StringRef syncscopeKeyword = "syncscope"; 2176 if (!failed(parser.parseOptionalKeyword(syncscopeKeyword))) { 2177 if (parser.parseLParen() || 2178 parser.parseAttribute(sScope, syncscopeKeyword, result.attributes) || 2179 parser.parseRParen()) 2180 return failure(); 2181 } else { 2182 result.addAttribute(syncscopeKeyword, 2183 parser.getBuilder().getStringAttr("")); 2184 } 2185 if (parseAtomicOrdering(parser, result, "ordering") || 2186 parser.parseOptionalAttrDict(result.attributes)) 2187 return failure(); 2188 return success(); 2189 } 2190 2191 static void printFenceOp(OpAsmPrinter &p, FenceOp &op) { 2192 StringRef syncscopeKeyword = "syncscope"; 2193 p << op.getOperationName() << ' '; 2194 if (!op->getAttr(syncscopeKeyword).cast<StringAttr>().getValue().empty()) 2195 p << "syncscope(" << op->getAttr(syncscopeKeyword) << ") "; 2196 p << stringifyAtomicOrdering(op.ordering()); 2197 } 2198 2199 static LogicalResult verify(FenceOp &op) { 2200 if (op.ordering() == AtomicOrdering::not_atomic || 2201 op.ordering() == AtomicOrdering::unordered || 2202 op.ordering() == AtomicOrdering::monotonic) 2203 return op.emitOpError("can be given only acquire, release, acq_rel, " 2204 "and seq_cst orderings"); 2205 return success(); 2206 } 2207 2208 //===----------------------------------------------------------------------===// 2209 // LLVMDialect initialization, type parsing, and registration. 
2210 //===----------------------------------------------------------------------===// 2211 2212 void LLVMDialect::initialize() { 2213 addAttributes<FMFAttr, LoopOptionAttr>(); 2214 2215 // clang-format off 2216 addTypes<LLVMVoidType, 2217 LLVMPPCFP128Type, 2218 LLVMX86MMXType, 2219 LLVMTokenType, 2220 LLVMLabelType, 2221 LLVMMetadataType, 2222 LLVMFunctionType, 2223 LLVMPointerType, 2224 LLVMFixedVectorType, 2225 LLVMScalableVectorType, 2226 LLVMArrayType, 2227 LLVMStructType>(); 2228 // clang-format on 2229 addOperations< 2230 #define GET_OP_LIST 2231 #include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc" 2232 >(); 2233 2234 // Support unknown operations because not all LLVM operations are registered. 2235 allowUnknownOperations(); 2236 } 2237 2238 #define GET_OP_CLASSES 2239 #include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc" 2240 2241 /// Parse a type registered to this dialect. 2242 Type LLVMDialect::parseType(DialectAsmParser &parser) const { 2243 return detail::parseType(parser); 2244 } 2245 2246 /// Print a type registered to this dialect. 2247 void LLVMDialect::printType(Type type, DialectAsmPrinter &os) const { 2248 return detail::printType(type, os); 2249 } 2250 2251 LogicalResult LLVMDialect::verifyDataLayoutString( 2252 StringRef descr, llvm::function_ref<void(const Twine &)> reportError) { 2253 llvm::Expected<llvm::DataLayout> maybeDataLayout = 2254 llvm::DataLayout::parse(descr); 2255 if (maybeDataLayout) 2256 return success(); 2257 2258 std::string message; 2259 llvm::raw_string_ostream messageStream(message); 2260 llvm::logAllUnhandledErrors(maybeDataLayout.takeError(), messageStream); 2261 reportError("invalid data layout descriptor: " + messageStream.str()); 2262 return failure(); 2263 } 2264 2265 /// Verify LLVM dialect attributes. 2266 LogicalResult LLVMDialect::verifyOperationAttribute(Operation *op, 2267 NamedAttribute attr) { 2268 // If the `llvm.loop` attribute is present, enforce the following structure, 2269 // which the module translation can assume. 
2270 if (attr.first.strref() == LLVMDialect::getLoopAttrName()) { 2271 auto loopAttr = attr.second.dyn_cast<DictionaryAttr>(); 2272 if (!loopAttr) 2273 return op->emitOpError() << "expected '" << LLVMDialect::getLoopAttrName() 2274 << "' to be a dictionary attribute"; 2275 Optional<NamedAttribute> parallelAccessGroup = 2276 loopAttr.getNamed(LLVMDialect::getParallelAccessAttrName()); 2277 if (parallelAccessGroup.hasValue()) { 2278 auto accessGroups = parallelAccessGroup->second.dyn_cast<ArrayAttr>(); 2279 if (!accessGroups) 2280 return op->emitOpError() 2281 << "expected '" << LLVMDialect::getParallelAccessAttrName() 2282 << "' to be an array attribute"; 2283 for (Attribute attr : accessGroups) { 2284 auto accessGroupRef = attr.dyn_cast<SymbolRefAttr>(); 2285 if (!accessGroupRef) 2286 return op->emitOpError() 2287 << "expected '" << attr << "' to be a symbol reference"; 2288 StringRef metadataName = accessGroupRef.getRootReference(); 2289 auto metadataOp = 2290 SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>( 2291 op->getParentOp(), metadataName); 2292 if (!metadataOp) 2293 return op->emitOpError() 2294 << "expected '" << attr << "' to reference a metadata op"; 2295 StringRef accessGroupName = accessGroupRef.getLeafReference(); 2296 Operation *accessGroupOp = 2297 SymbolTable::lookupNearestSymbolFrom(metadataOp, accessGroupName); 2298 if (!accessGroupOp) 2299 return op->emitOpError() 2300 << "expected '" << attr << "' to reference an access_group op"; 2301 } 2302 } 2303 2304 Optional<NamedAttribute> loopOptions = 2305 loopAttr.getNamed(LLVMDialect::getLoopOptionsAttrName()); 2306 if (loopOptions.hasValue()) { 2307 auto options = loopOptions->second.dyn_cast<ArrayAttr>(); 2308 if (!options) 2309 return op->emitOpError() 2310 << "expected '" << LLVMDialect::getLoopOptionsAttrName() 2311 << "' to be an array attribute"; 2312 if (!llvm::all_of(options, [](Attribute option) { 2313 return option.isa<LoopOptionAttr>(); 2314 })) 2315 return op->emitOpError() << 
"invalid loop options list " << options; 2316 } 2317 } 2318 2319 // If the data layout attribute is present, it must use the LLVM data layout 2320 // syntax. Try parsing it and report errors in case of failure. Users of this 2321 // attribute may assume it is well-formed and can pass it to the (asserting) 2322 // llvm::DataLayout constructor. 2323 if (attr.first.strref() != LLVM::LLVMDialect::getDataLayoutAttrName()) 2324 return success(); 2325 if (auto stringAttr = attr.second.dyn_cast<StringAttr>()) 2326 return verifyDataLayoutString( 2327 stringAttr.getValue(), 2328 [op](const Twine &message) { op->emitOpError() << message.str(); }); 2329 2330 return op->emitOpError() << "expected '" 2331 << LLVM::LLVMDialect::getDataLayoutAttrName() 2332 << "' to be a string attribute"; 2333 } 2334 2335 /// Verify LLVMIR function argument attributes. 2336 LogicalResult LLVMDialect::verifyRegionArgAttribute(Operation *op, 2337 unsigned regionIdx, 2338 unsigned argIdx, 2339 NamedAttribute argAttr) { 2340 // Check that llvm.noalias is a boolean attribute. 2341 if (argAttr.first == LLVMDialect::getNoAliasAttrName() && 2342 !argAttr.second.isa<BoolAttr>()) 2343 return op->emitError() 2344 << "llvm.noalias argument attribute of non boolean type"; 2345 // Check that llvm.align is an integer attribute. 2346 if (argAttr.first == LLVMDialect::getAlignAttrName() && 2347 !argAttr.second.isa<IntegerAttr>()) 2348 return op->emitError() 2349 << "llvm.align argument attribute of non integer type"; 2350 return success(); 2351 } 2352 2353 //===----------------------------------------------------------------------===// 2354 // Utility functions. 
2355 //===----------------------------------------------------------------------===// 2356 2357 Value mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder, 2358 StringRef name, StringRef value, 2359 LLVM::Linkage linkage) { 2360 assert(builder.getInsertionBlock() && 2361 builder.getInsertionBlock()->getParentOp() && 2362 "expected builder to point to a block constrained in an op"); 2363 auto module = 2364 builder.getInsertionBlock()->getParentOp()->getParentOfType<ModuleOp>(); 2365 assert(module && "builder points to an op outside of a module"); 2366 2367 // Create the global at the entry of the module. 2368 OpBuilder moduleBuilder(module.getBodyRegion(), builder.getListener()); 2369 MLIRContext *ctx = builder.getContext(); 2370 auto type = LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), value.size()); 2371 auto global = moduleBuilder.create<LLVM::GlobalOp>( 2372 loc, type, /*isConstant=*/true, linkage, name, 2373 builder.getStringAttr(value)); 2374 2375 // Get the pointer to the first character in the global string. 
2376 Value globalPtr = builder.create<LLVM::AddressOfOp>(loc, global); 2377 Value cst0 = builder.create<LLVM::ConstantOp>( 2378 loc, IntegerType::get(ctx, 64), 2379 builder.getIntegerAttr(builder.getIndexType(), 0)); 2380 return builder.create<LLVM::GEPOp>( 2381 loc, LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)), globalPtr, 2382 ValueRange{cst0, cst0}); 2383 } 2384 2385 bool mlir::LLVM::satisfiesLLVMModule(Operation *op) { 2386 return op->hasTrait<OpTrait::SymbolTable>() && 2387 op->hasTrait<OpTrait::IsIsolatedFromAbove>(); 2388 } 2389 2390 FMFAttr FMFAttr::get(FastmathFlags flags, MLIRContext *context) { 2391 return Base::get(context, static_cast<uint64_t>(flags)); 2392 } 2393 2394 FastmathFlags FMFAttr::getFlags() const { 2395 return static_cast<FastmathFlags>(getImpl()->value); 2396 } 2397 2398 static constexpr const FastmathFlags FastmathFlagsList[] = { 2399 // clang-format off 2400 FastmathFlags::nnan, 2401 FastmathFlags::ninf, 2402 FastmathFlags::nsz, 2403 FastmathFlags::arcp, 2404 FastmathFlags::contract, 2405 FastmathFlags::afn, 2406 FastmathFlags::reassoc, 2407 FastmathFlags::fast, 2408 // clang-format on 2409 }; 2410 2411 void FMFAttr::print(DialectAsmPrinter &printer) const { 2412 printer << "fastmath<"; 2413 auto flags = llvm::make_filter_range(FastmathFlagsList, [&](auto flag) { 2414 return bitEnumContains(this->getFlags(), flag); 2415 }); 2416 llvm::interleaveComma(flags, printer, 2417 [&](auto flag) { printer << stringifyEnum(flag); }); 2418 printer << ">"; 2419 } 2420 2421 Attribute FMFAttr::parse(DialectAsmParser &parser) { 2422 if (failed(parser.parseLess())) 2423 return {}; 2424 2425 FastmathFlags flags = {}; 2426 if (failed(parser.parseOptionalGreater())) { 2427 do { 2428 StringRef elemName; 2429 if (failed(parser.parseKeyword(&elemName))) 2430 return {}; 2431 2432 auto elem = symbolizeFastmathFlags(elemName); 2433 if (!elem) { 2434 parser.emitError(parser.getNameLoc(), "Unknown fastmath flag: ") 2435 << elemName; 2436 return {}; 2437 } 
2438 2439 flags = flags | *elem; 2440 } while (succeeded(parser.parseOptionalComma())); 2441 2442 if (failed(parser.parseGreater())) 2443 return {}; 2444 } 2445 2446 return FMFAttr::get(flags, parser.getBuilder().getContext()); 2447 } 2448 2449 LoopOptionAttr LoopOptionAttr::getDisableUnroll(MLIRContext *context, 2450 bool disable) { 2451 auto option = LoopOptionCase::disable_unroll; 2452 return Base::get(context, static_cast<uint64_t>(option), 2453 static_cast<int32_t>(disable)); 2454 } 2455 2456 LoopOptionAttr LoopOptionAttr::getDisableLICM(MLIRContext *context, 2457 bool disable) { 2458 auto option = LoopOptionCase::disable_licm; 2459 return Base::get(context, static_cast<uint64_t>(option), 2460 static_cast<int32_t>(disable)); 2461 } 2462 2463 LoopOptionAttr LoopOptionAttr::getInterleaveCount(MLIRContext *context, 2464 int32_t count) { 2465 auto option = LoopOptionCase::interleave_count; 2466 return Base::get(context, static_cast<uint64_t>(option), 2467 static_cast<int32_t>(count)); 2468 } 2469 2470 LoopOptionCase LoopOptionAttr::getCase() const { 2471 return static_cast<LoopOptionCase>(getImpl()->option); 2472 } 2473 2474 bool LoopOptionAttr::getBool() const { 2475 LoopOptionCase option = getCase(); 2476 (void)option; 2477 assert(option == LoopOptionCase::disable_licm || 2478 option == LoopOptionCase::disable_unroll && 2479 "expected a boolean loop option"); 2480 return static_cast<bool>(getImpl()->value); 2481 } 2482 2483 int32_t LoopOptionAttr::getInt() const { 2484 LoopOptionCase option = getCase(); 2485 (void)option; 2486 assert(option == LoopOptionCase::interleave_count && 2487 "expected an integer loop option"); 2488 return getImpl()->value; 2489 } 2490 2491 void LoopOptionAttr::print(DialectAsmPrinter &printer) const { 2492 printer << "loopopt<" << stringifyEnum(getCase()) << " = "; 2493 switch (getCase()) { 2494 case LoopOptionCase::disable_licm: 2495 case LoopOptionCase::disable_unroll: 2496 printer << (getBool() ? 
"true" : "false"); 2497 break; 2498 case LoopOptionCase::interleave_count: 2499 printer << getInt(); 2500 break; 2501 } 2502 printer << ">"; 2503 } 2504 2505 Attribute LoopOptionAttr::parse(DialectAsmParser &parser) { 2506 if (failed(parser.parseLess())) 2507 return {}; 2508 2509 StringRef optionName; 2510 if (failed(parser.parseKeyword(&optionName))) 2511 return {}; 2512 2513 auto option = symbolizeLoopOptionCase(optionName); 2514 if (!option) { 2515 parser.emitError(parser.getNameLoc(), "unknown loop option: ") 2516 << optionName; 2517 return {}; 2518 } 2519 2520 if (failed(parser.parseEqual())) 2521 return {}; 2522 2523 int32_t value; 2524 switch (*option) { 2525 case LoopOptionCase::disable_licm: 2526 case LoopOptionCase::disable_unroll: 2527 if (succeeded(parser.parseOptionalKeyword("true"))) 2528 value = 1; 2529 else if (succeeded(parser.parseOptionalKeyword("false"))) 2530 value = 0; 2531 else { 2532 parser.emitError(parser.getNameLoc(), 2533 "expected boolean value 'true' or 'false'"); 2534 return {}; 2535 } 2536 break; 2537 case LoopOptionCase::interleave_count: 2538 if (failed(parser.parseInteger(value))) { 2539 parser.emitError(parser.getNameLoc(), "expected integer value"); 2540 return {}; 2541 } 2542 break; 2543 } 2544 2545 if (failed(parser.parseGreater())) 2546 return {}; 2547 2548 return Base::get(parser.getBuilder().getContext(), 2549 static_cast<uint64_t>(*option), value); 2550 } 2551 2552 Attribute LLVMDialect::parseAttribute(DialectAsmParser &parser, 2553 Type type) const { 2554 if (type) { 2555 parser.emitError(parser.getNameLoc(), "unexpected type"); 2556 return {}; 2557 } 2558 StringRef attrKind; 2559 if (parser.parseKeyword(&attrKind)) 2560 return {}; 2561 2562 if (attrKind == "fastmath") 2563 return FMFAttr::parse(parser); 2564 2565 if (attrKind == "loopopt") 2566 return LoopOptionAttr::parse(parser); 2567 2568 parser.emitError(parser.getNameLoc(), "Unknown attribute type: ") << attrKind; 2569 return {}; 2570 } 2571 2572 void 
LLVMDialect::printAttribute(Attribute attr, DialectAsmPrinter &os) const { 2573 if (auto fmf = attr.dyn_cast<FMFAttr>()) 2574 fmf.print(os); 2575 else if (auto lopt = attr.dyn_cast<LoopOptionAttr>()) 2576 lopt.print(os); 2577 else 2578 llvm_unreachable("Unknown attribute type"); 2579 } 2580