1 //===- LLVMDialect.cpp - LLVM IR Ops and Dialect registration -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the types and operation details for the LLVM IR dialect in 10 // MLIR, and the LLVM IR dialect. It also registers the dialect. 11 // 12 //===----------------------------------------------------------------------===// 13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 14 #include "TypeDetail.h" 15 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 16 #include "mlir/IR/Builders.h" 17 #include "mlir/IR/BuiltinOps.h" 18 #include "mlir/IR/BuiltinTypes.h" 19 #include "mlir/IR/DialectImplementation.h" 20 #include "mlir/IR/FunctionImplementation.h" 21 #include "mlir/IR/MLIRContext.h" 22 23 #include "llvm/ADT/StringSwitch.h" 24 #include "llvm/ADT/TypeSwitch.h" 25 #include "llvm/AsmParser/Parser.h" 26 #include "llvm/Bitcode/BitcodeReader.h" 27 #include "llvm/Bitcode/BitcodeWriter.h" 28 #include "llvm/IR/Attributes.h" 29 #include "llvm/IR/Function.h" 30 #include "llvm/IR/Type.h" 31 #include "llvm/Support/Mutex.h" 32 #include "llvm/Support/SourceMgr.h" 33 34 #include <iostream> 35 36 using namespace mlir; 37 using namespace mlir::LLVM; 38 39 static constexpr const char kVolatileAttrName[] = "volatile_"; 40 static constexpr const char kNonTemporalAttrName[] = "nontemporal"; 41 42 #include "mlir/Dialect/LLVMIR/LLVMOpsEnums.cpp.inc" 43 #include "mlir/Dialect/LLVMIR/LLVMOpsInterfaces.cpp.inc" 44 #define GET_ATTRDEF_CLASSES 45 #include "mlir/Dialect/LLVMIR/LLVMOpsAttrDefs.cpp.inc" 46 47 static auto processFMFAttr(ArrayRef<NamedAttribute> attrs) { 48 SmallVector<NamedAttribute, 8> filteredAttrs( 49 llvm::make_filter_range(attrs, [&](NamedAttribute attr) { 50 if (attr.first == 
"fastmathFlags") { 51 auto defAttr = FMFAttr::get(attr.second.getContext(), {}); 52 return defAttr != attr.second; 53 } 54 return true; 55 })); 56 return filteredAttrs; 57 } 58 59 static ParseResult parseLLVMOpAttrs(OpAsmParser &parser, 60 NamedAttrList &result) { 61 return parser.parseOptionalAttrDict(result); 62 } 63 64 static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op, 65 DictionaryAttr attrs) { 66 printer.printOptionalAttrDict(processFMFAttr(attrs.getValue())); 67 } 68 69 //===----------------------------------------------------------------------===// 70 // Printing/parsing for LLVM::CmpOp. 71 //===----------------------------------------------------------------------===// 72 static void printICmpOp(OpAsmPrinter &p, ICmpOp &op) { 73 p << op.getOperationName() << " \"" << stringifyICmpPredicate(op.predicate()) 74 << "\" " << op.getOperand(0) << ", " << op.getOperand(1); 75 p.printOptionalAttrDict(op->getAttrs(), {"predicate"}); 76 p << " : " << op.lhs().getType(); 77 } 78 79 static void printFCmpOp(OpAsmPrinter &p, FCmpOp &op) { 80 p << op.getOperationName() << " \"" << stringifyFCmpPredicate(op.predicate()) 81 << "\" " << op.getOperand(0) << ", " << op.getOperand(1); 82 p.printOptionalAttrDict(processFMFAttr(op->getAttrs()), {"predicate"}); 83 p << " : " << op.lhs().getType(); 84 } 85 86 // <operation> ::= `llvm.icmp` string-literal ssa-use `,` ssa-use 87 // attribute-dict? `:` type 88 // <operation> ::= `llvm.fcmp` string-literal ssa-use `,` ssa-use 89 // attribute-dict? 
`:` type 90 template <typename CmpPredicateType> 91 static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) { 92 Builder &builder = parser.getBuilder(); 93 94 StringAttr predicateAttr; 95 OpAsmParser::OperandType lhs, rhs; 96 Type type; 97 llvm::SMLoc predicateLoc, trailingTypeLoc; 98 if (parser.getCurrentLocation(&predicateLoc) || 99 parser.parseAttribute(predicateAttr, "predicate", result.attributes) || 100 parser.parseOperand(lhs) || parser.parseComma() || 101 parser.parseOperand(rhs) || 102 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 103 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) || 104 parser.resolveOperand(lhs, type, result.operands) || 105 parser.resolveOperand(rhs, type, result.operands)) 106 return failure(); 107 108 // Replace the string attribute `predicate` with an integer attribute. 109 int64_t predicateValue = 0; 110 if (std::is_same<CmpPredicateType, ICmpPredicate>()) { 111 Optional<ICmpPredicate> predicate = 112 symbolizeICmpPredicate(predicateAttr.getValue()); 113 if (!predicate) 114 return parser.emitError(predicateLoc) 115 << "'" << predicateAttr.getValue() 116 << "' is an incorrect value of the 'predicate' attribute"; 117 predicateValue = static_cast<int64_t>(predicate.getValue()); 118 } else { 119 Optional<FCmpPredicate> predicate = 120 symbolizeFCmpPredicate(predicateAttr.getValue()); 121 if (!predicate) 122 return parser.emitError(predicateLoc) 123 << "'" << predicateAttr.getValue() 124 << "' is an incorrect value of the 'predicate' attribute"; 125 predicateValue = static_cast<int64_t>(predicate.getValue()); 126 } 127 128 result.attributes.set("predicate", 129 parser.getBuilder().getI64IntegerAttr(predicateValue)); 130 131 // The result type is either i1 or a vector type <? x i1> if the inputs are 132 // vectors. 
133 Type resultType = IntegerType::get(builder.getContext(), 1); 134 if (!isCompatibleType(type)) 135 return parser.emitError(trailingTypeLoc, 136 "expected LLVM dialect-compatible type"); 137 if (LLVM::isCompatibleVectorType(type)) { 138 if (type.isa<LLVM::LLVMScalableVectorType>()) { 139 resultType = LLVM::LLVMScalableVectorType::get( 140 resultType, LLVM::getVectorNumElements(type).getKnownMinValue()); 141 } else { 142 resultType = LLVM::getFixedVectorType( 143 resultType, LLVM::getVectorNumElements(type).getFixedValue()); 144 } 145 } 146 147 result.addTypes({resultType}); 148 return success(); 149 } 150 151 //===----------------------------------------------------------------------===// 152 // Printing/parsing for LLVM::AllocaOp. 153 //===----------------------------------------------------------------------===// 154 155 static void printAllocaOp(OpAsmPrinter &p, AllocaOp &op) { 156 auto elemTy = op.getType().cast<LLVM::LLVMPointerType>().getElementType(); 157 158 auto funcTy = FunctionType::get(op.getContext(), {op.arraySize().getType()}, 159 {op.getType()}); 160 161 p << op.getOperationName() << ' ' << op.arraySize() << " x " << elemTy; 162 if (op.alignment().hasValue() && *op.alignment() != 0) 163 p.printOptionalAttrDict(op->getAttrs()); 164 else 165 p.printOptionalAttrDict(op->getAttrs(), {"alignment"}); 166 p << " : " << funcTy; 167 } 168 169 // <operation> ::= `llvm.alloca` ssa-use `x` type attribute-dict? 
170 // `:` type `,` type 171 static ParseResult parseAllocaOp(OpAsmParser &parser, OperationState &result) { 172 OpAsmParser::OperandType arraySize; 173 Type type, elemType; 174 llvm::SMLoc trailingTypeLoc; 175 if (parser.parseOperand(arraySize) || parser.parseKeyword("x") || 176 parser.parseType(elemType) || 177 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 178 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type)) 179 return failure(); 180 181 Optional<NamedAttribute> alignmentAttr = 182 result.attributes.getNamed("alignment"); 183 if (alignmentAttr.hasValue()) { 184 auto alignmentInt = alignmentAttr.getValue().second.dyn_cast<IntegerAttr>(); 185 if (!alignmentInt) 186 return parser.emitError(parser.getNameLoc(), 187 "expected integer alignment"); 188 if (alignmentInt.getValue().isNullValue()) 189 result.attributes.erase("alignment"); 190 } 191 192 // Extract the result type from the trailing function type. 193 auto funcType = type.dyn_cast<FunctionType>(); 194 if (!funcType || funcType.getNumInputs() != 1 || 195 funcType.getNumResults() != 1) 196 return parser.emitError( 197 trailingTypeLoc, 198 "expected trailing function type with one argument and one result"); 199 200 if (parser.resolveOperand(arraySize, funcType.getInput(0), result.operands)) 201 return failure(); 202 203 result.addTypes({funcType.getResult(0)}); 204 return success(); 205 } 206 207 //===----------------------------------------------------------------------===// 208 // LLVM::BrOp 209 //===----------------------------------------------------------------------===// 210 211 Optional<MutableOperandRange> 212 BrOp::getMutableSuccessorOperands(unsigned index) { 213 assert(index == 0 && "invalid successor index"); 214 return destOperandsMutable(); 215 } 216 217 //===----------------------------------------------------------------------===// 218 // LLVM::CondBrOp 219 //===----------------------------------------------------------------------===// 220 221 
Optional<MutableOperandRange> 222 CondBrOp::getMutableSuccessorOperands(unsigned index) { 223 assert(index < getNumSuccessors() && "invalid successor index"); 224 return index == 0 ? trueDestOperandsMutable() : falseDestOperandsMutable(); 225 } 226 227 //===----------------------------------------------------------------------===// 228 // LLVM::SwitchOp 229 //===----------------------------------------------------------------------===// 230 231 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value, 232 Block *defaultDestination, ValueRange defaultOperands, 233 ArrayRef<int32_t> caseValues, BlockRange caseDestinations, 234 ArrayRef<ValueRange> caseOperands, 235 ArrayRef<int32_t> branchWeights) { 236 SmallVector<Value> flattenedCaseOperands; 237 SmallVector<int32_t> caseOperandOffsets; 238 int32_t offset = 0; 239 for (ValueRange operands : caseOperands) { 240 flattenedCaseOperands.append(operands.begin(), operands.end()); 241 caseOperandOffsets.push_back(offset); 242 offset += operands.size(); 243 } 244 ElementsAttr caseValuesAttr; 245 if (!caseValues.empty()) 246 caseValuesAttr = builder.getI32VectorAttr(caseValues); 247 ElementsAttr caseOperandOffsetsAttr; 248 if (!caseOperandOffsets.empty()) 249 caseOperandOffsetsAttr = builder.getI32VectorAttr(caseOperandOffsets); 250 251 ElementsAttr weightsAttr; 252 if (!branchWeights.empty()) 253 weightsAttr = builder.getI32VectorAttr(llvm::to_vector<4>(branchWeights)); 254 255 build(builder, result, value, defaultOperands, flattenedCaseOperands, 256 caseValuesAttr, caseOperandOffsetsAttr, weightsAttr, defaultDestination, 257 caseDestinations); 258 } 259 260 /// <cases> ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)? 261 /// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )? 
262 static ParseResult 263 parseSwitchOpCases(OpAsmParser &parser, ElementsAttr &caseValues, 264 SmallVectorImpl<Block *> &caseDestinations, 265 SmallVectorImpl<OpAsmParser::OperandType> &caseOperands, 266 SmallVectorImpl<Type> &caseOperandTypes, 267 ElementsAttr &caseOperandOffsets) { 268 SmallVector<int32_t> values; 269 SmallVector<int32_t> offsets; 270 int32_t value, offset = 0; 271 do { 272 OptionalParseResult integerParseResult = parser.parseOptionalInteger(value); 273 if (values.empty() && !integerParseResult.hasValue()) 274 return success(); 275 276 if (!integerParseResult.hasValue() || integerParseResult.getValue()) 277 return failure(); 278 values.push_back(value); 279 280 Block *destination; 281 SmallVector<OpAsmParser::OperandType> operands; 282 if (parser.parseColon() || parser.parseSuccessor(destination)) 283 return failure(); 284 if (!parser.parseOptionalLParen()) { 285 if (parser.parseRegionArgumentList(operands) || 286 parser.parseColonTypeList(caseOperandTypes) || parser.parseRParen()) 287 return failure(); 288 } 289 caseDestinations.push_back(destination); 290 caseOperands.append(operands.begin(), operands.end()); 291 offsets.push_back(offset); 292 offset += operands.size(); 293 } while (!parser.parseOptionalComma()); 294 295 Builder &builder = parser.getBuilder(); 296 caseValues = builder.getI32VectorAttr(values); 297 caseOperandOffsets = builder.getI32VectorAttr(offsets); 298 299 return success(); 300 } 301 302 static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, 303 ElementsAttr caseValues, 304 SuccessorRange caseDestinations, 305 OperandRange caseOperands, 306 TypeRange caseOperandTypes, 307 ElementsAttr caseOperandOffsets) { 308 if (!caseValues) 309 return; 310 311 size_t index = 0; 312 llvm::interleave( 313 llvm::zip(caseValues.cast<DenseIntElementsAttr>(), caseDestinations), 314 [&](auto i) { 315 p << " "; 316 p << std::get<0>(i).getLimitedValue(); 317 p << ": "; 318 p.printSuccessorAndUseList(std::get<1>(i), 
op.getCaseOperands(index++)); 319 }, 320 [&] { 321 p << ','; 322 p.printNewline(); 323 }); 324 p.printNewline(); 325 } 326 327 static LogicalResult verify(SwitchOp op) { 328 if ((!op.case_values() && !op.caseDestinations().empty()) || 329 (op.case_values() && 330 op.case_values()->size() != 331 static_cast<int64_t>(op.caseDestinations().size()))) 332 return op.emitOpError("expects number of case values to match number of " 333 "case destinations"); 334 if (op.branch_weights() && 335 op.branch_weights()->size() != op.getNumSuccessors()) 336 return op.emitError("expects number of branch weights to match number of " 337 "successors: ") 338 << op.branch_weights()->size() << " vs " << op.getNumSuccessors(); 339 return success(); 340 } 341 342 OperandRange SwitchOp::getCaseOperands(unsigned index) { 343 return getCaseOperandsMutable(index); 344 } 345 346 MutableOperandRange SwitchOp::getCaseOperandsMutable(unsigned index) { 347 MutableOperandRange caseOperands = caseOperandsMutable(); 348 if (!case_operand_offsets()) { 349 assert(caseOperands.size() == 0 && 350 "non-empty case operands must have offsets"); 351 return caseOperands; 352 } 353 354 ElementsAttr offsets = case_operand_offsets().getValue(); 355 assert(index < offsets.size() && "invalid case operand offset index"); 356 357 int64_t begin = offsets.getValue(index).cast<IntegerAttr>().getInt(); 358 int64_t end = index + 1 == offsets.size() 359 ? caseOperands.size() 360 : offsets.getValue(index + 1).cast<IntegerAttr>().getInt(); 361 return caseOperandsMutable().slice(begin, end - begin); 362 } 363 364 Optional<MutableOperandRange> 365 SwitchOp::getMutableSuccessorOperands(unsigned index) { 366 assert(index < getNumSuccessors() && "invalid successor index"); 367 return index == 0 ? defaultOperandsMutable() 368 : getCaseOperandsMutable(index - 1); 369 } 370 371 //===----------------------------------------------------------------------===// 372 // Builder, printer and parser for for LLVM::LoadOp. 
373 //===----------------------------------------------------------------------===// 374 375 static LogicalResult verifyAccessGroups(Operation *op) { 376 if (Attribute attribute = 377 op->getAttr(LLVMDialect::getAccessGroupsAttrName())) { 378 // The attribute is already verified to be a symbol ref array attribute via 379 // a constraint in the operation definition. 380 for (SymbolRefAttr accessGroupRef : 381 attribute.cast<ArrayAttr>().getAsRange<SymbolRefAttr>()) { 382 StringRef metadataName = accessGroupRef.getRootReference(); 383 auto metadataOp = SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>( 384 op->getParentOp(), metadataName); 385 if (!metadataOp) 386 return op->emitOpError() << "expected '" << accessGroupRef 387 << "' to reference a metadata op"; 388 StringRef accessGroupName = accessGroupRef.getLeafReference(); 389 Operation *accessGroupOp = 390 SymbolTable::lookupNearestSymbolFrom(metadataOp, accessGroupName); 391 if (!accessGroupOp) 392 return op->emitOpError() << "expected '" << accessGroupRef 393 << "' to reference an access_group op"; 394 } 395 } 396 return success(); 397 } 398 399 static LogicalResult verify(LoadOp op) { 400 return verifyAccessGroups(op.getOperation()); 401 } 402 403 void LoadOp::build(OpBuilder &builder, OperationState &result, Type t, 404 Value addr, unsigned alignment, bool isVolatile, 405 bool isNonTemporal) { 406 result.addOperands(addr); 407 result.addTypes(t); 408 if (isVolatile) 409 result.addAttribute(kVolatileAttrName, builder.getUnitAttr()); 410 if (isNonTemporal) 411 result.addAttribute(kNonTemporalAttrName, builder.getUnitAttr()); 412 if (alignment != 0) 413 result.addAttribute("alignment", builder.getI64IntegerAttr(alignment)); 414 } 415 416 static void printLoadOp(OpAsmPrinter &p, LoadOp &op) { 417 p << op.getOperationName() << ' '; 418 if (op.volatile_()) 419 p << "volatile "; 420 p << op.addr(); 421 p.printOptionalAttrDict(op->getAttrs(), {kVolatileAttrName}); 422 p << " : " << op.addr().getType(); 423 } 424 
425 // Extract the pointee type from the LLVM pointer type wrapped in MLIR. Return 426 // the resulting type wrapped in MLIR, or nullptr on error. 427 static Type getLoadStoreElementType(OpAsmParser &parser, Type type, 428 llvm::SMLoc trailingTypeLoc) { 429 auto llvmTy = type.dyn_cast<LLVM::LLVMPointerType>(); 430 if (!llvmTy) 431 return parser.emitError(trailingTypeLoc, "expected LLVM pointer type"), 432 nullptr; 433 return llvmTy.getElementType(); 434 } 435 436 // <operation> ::= `llvm.load` `volatile` ssa-use attribute-dict? `:` type 437 static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) { 438 OpAsmParser::OperandType addr; 439 Type type; 440 llvm::SMLoc trailingTypeLoc; 441 442 if (succeeded(parser.parseOptionalKeyword("volatile"))) 443 result.addAttribute(kVolatileAttrName, parser.getBuilder().getUnitAttr()); 444 445 if (parser.parseOperand(addr) || 446 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 447 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) || 448 parser.resolveOperand(addr, type, result.operands)) 449 return failure(); 450 451 Type elemTy = getLoadStoreElementType(parser, type, trailingTypeLoc); 452 453 result.addTypes(elemTy); 454 return success(); 455 } 456 457 //===----------------------------------------------------------------------===// 458 // Builder, printer and parser for LLVM::StoreOp. 
459 //===----------------------------------------------------------------------===// 460 461 static LogicalResult verify(StoreOp op) { 462 return verifyAccessGroups(op.getOperation()); 463 } 464 465 void StoreOp::build(OpBuilder &builder, OperationState &result, Value value, 466 Value addr, unsigned alignment, bool isVolatile, 467 bool isNonTemporal) { 468 result.addOperands({value, addr}); 469 result.addTypes({}); 470 if (isVolatile) 471 result.addAttribute(kVolatileAttrName, builder.getUnitAttr()); 472 if (isNonTemporal) 473 result.addAttribute(kNonTemporalAttrName, builder.getUnitAttr()); 474 if (alignment != 0) 475 result.addAttribute("alignment", builder.getI64IntegerAttr(alignment)); 476 } 477 478 static void printStoreOp(OpAsmPrinter &p, StoreOp &op) { 479 p << op.getOperationName() << ' '; 480 if (op.volatile_()) 481 p << "volatile "; 482 p << op.value() << ", " << op.addr(); 483 p.printOptionalAttrDict(op->getAttrs(), {kVolatileAttrName}); 484 p << " : " << op.addr().getType(); 485 } 486 487 // <operation> ::= `llvm.store` `volatile` ssa-use `,` ssa-use 488 // attribute-dict? 
`:` type 489 static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) { 490 OpAsmParser::OperandType addr, value; 491 Type type; 492 llvm::SMLoc trailingTypeLoc; 493 494 if (succeeded(parser.parseOptionalKeyword("volatile"))) 495 result.addAttribute(kVolatileAttrName, parser.getBuilder().getUnitAttr()); 496 497 if (parser.parseOperand(value) || parser.parseComma() || 498 parser.parseOperand(addr) || 499 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 500 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type)) 501 return failure(); 502 503 Type elemTy = getLoadStoreElementType(parser, type, trailingTypeLoc); 504 if (!elemTy) 505 return failure(); 506 507 if (parser.resolveOperand(value, elemTy, result.operands) || 508 parser.resolveOperand(addr, type, result.operands)) 509 return failure(); 510 511 return success(); 512 } 513 514 ///===---------------------------------------------------------------------===// 515 /// LLVM::InvokeOp 516 ///===---------------------------------------------------------------------===// 517 518 Optional<MutableOperandRange> 519 InvokeOp::getMutableSuccessorOperands(unsigned index) { 520 assert(index < getNumSuccessors() && "invalid successor index"); 521 return index == 0 ? 
normalDestOperandsMutable() : unwindDestOperandsMutable(); 522 } 523 524 static LogicalResult verify(InvokeOp op) { 525 if (op.getNumResults() > 1) 526 return op.emitOpError("must have 0 or 1 result"); 527 528 Block *unwindDest = op.unwindDest(); 529 if (unwindDest->empty()) 530 return op.emitError( 531 "must have at least one operation in unwind destination"); 532 533 // In unwind destination, first operation must be LandingpadOp 534 if (!isa<LandingpadOp>(unwindDest->front())) 535 return op.emitError("first operation in unwind destination should be a " 536 "llvm.landingpad operation"); 537 538 return success(); 539 } 540 541 static void printInvokeOp(OpAsmPrinter &p, InvokeOp op) { 542 auto callee = op.callee(); 543 bool isDirect = callee.hasValue(); 544 545 p << op.getOperationName() << ' '; 546 547 // Either function name or pointer 548 if (isDirect) 549 p.printSymbolName(callee.getValue()); 550 else 551 p << op.getOperand(0); 552 553 p << '(' << op.getOperands().drop_front(isDirect ? 0 : 1) << ')'; 554 p << " to "; 555 p.printSuccessorAndUseList(op.normalDest(), op.normalDestOperands()); 556 p << " unwind "; 557 p.printSuccessorAndUseList(op.unwindDest(), op.unwindDestOperands()); 558 559 p.printOptionalAttrDict(op->getAttrs(), 560 {InvokeOp::getOperandSegmentSizeAttr(), "callee"}); 561 p << " : "; 562 p.printFunctionalType( 563 llvm::drop_begin(op.getOperandTypes(), isDirect ? 0 : 1), 564 op.getResultTypes()); 565 } 566 567 /// <operation> ::= `llvm.invoke` (function-id | ssa-use) `(` ssa-use-list `)` 568 /// `to` bb-id (`[` ssa-use-and-type-list `]`)? 569 /// `unwind` bb-id (`[` ssa-use-and-type-list `]`)? 570 /// attribute-dict? 
///                  `:` function-type
///
/// The result operand list is laid out as:
/// - the call operands (callee pointer first when the invoke is indirect);
/// - the normal-destination operands;
/// - the unwind-destination operands.
/// The three group sizes are recorded in the operand-segment-size attribute.
static ParseResult parseInvokeOp(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::OperandType, 8> operands;
  FunctionType funcType;
  SymbolRefAttr funcAttr;
  llvm::SMLoc trailingTypeLoc;
  Block *normalDest, *unwindDest;
  SmallVector<Value, 4> normalOperands, unwindOperands;
  Builder &builder = parser.getBuilder();

  // Parse an operand list that will, in practice, contain 0 or 1 operand. In
  // case of an indirect call, there will be 1 operand before `(`. In case of a
  // direct call, there will be no operands and the parser will stop at the
  // function identifier without complaining.
  if (parser.parseOperandList(operands))
    return failure();
  bool isDirect = operands.empty();

  // Optionally parse a function identifier.
  if (isDirect && parser.parseAttribute(funcAttr, "callee", result.attributes))
    return failure();

  // The parenthesized call arguments are appended after any indirect callee
  // operand parsed above.
  if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
      parser.parseKeyword("to") ||
      parser.parseSuccessorAndUseList(normalDest, normalOperands) ||
      parser.parseKeyword("unwind") ||
      parser.parseSuccessorAndUseList(unwindDest, unwindOperands) ||
      parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
      parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(funcType))
    return failure();

  if (isDirect) {
    // Make sure types match.
    if (parser.resolveOperands(operands, funcType.getInputs(),
                               parser.getNameLoc(), result.operands))
      return failure();
    result.addTypes(funcType.getResults());
  } else {
    // Construct the LLVM IR Dialect function type that the first operand
    // should match.
    if (funcType.getNumResults() > 1)
      return parser.emitError(trailingTypeLoc,
                              "expected function with 0 or 1 result");

    Type llvmResultType;
    if (funcType.getNumResults() == 0) {
      llvmResultType = LLVM::LLVMVoidType::get(builder.getContext());
    } else {
      llvmResultType = funcType.getResult(0);
      if (!isCompatibleType(llvmResultType))
        return parser.emitError(trailingTypeLoc,
                                "expected result to have LLVM type");
    }

    SmallVector<Type, 8> argTypes;
    argTypes.reserve(funcType.getNumInputs());
    for (Type ty : funcType.getInputs()) {
      if (isCompatibleType(ty))
        argTypes.push_back(ty);
      else
        return parser.emitError(trailingTypeLoc,
                                "expected LLVM types as inputs");
    }

    auto llvmFuncType = LLVM::LLVMFunctionType::get(llvmResultType, argTypes);
    auto wrappedFuncType = LLVM::LLVMPointerType::get(llvmFuncType);

    auto funcArguments = llvm::makeArrayRef(operands).drop_front();

    // Make sure that the first operand (indirect callee) matches the wrapped
    // LLVM IR function type, and that the types of the other call operands
    // match the types of the function arguments.
    if (parser.resolveOperand(operands[0], wrappedFuncType, result.operands) ||
        parser.resolveOperands(funcArguments, funcType.getInputs(),
                               parser.getNameLoc(), result.operands))
      return failure();

    // NOTE(review): when the trailing function type has no results,
    // llvmResultType is !llvm.void and is still added as a result here, unlike
    // the direct path. Confirm this is intended.
    result.addTypes(llvmResultType);
  }
  result.addSuccessors({normalDest, unwindDest});
  result.addOperands(normalOperands);
  result.addOperands(unwindOperands);

  // Segment sizes: [call operands, normal-dest operands, unwind-dest operands].
  result.addAttribute(
      InvokeOp::getOperandSegmentSizeAttr(),
      builder.getI32VectorAttr({static_cast<int32_t>(operands.size()),
                                static_cast<int32_t>(normalOperands.size()),
                                static_cast<int32_t>(unwindOperands.size())}));
  return success();
}

///===----------------------------------------------------------------------===//
/// Verifying/Printing/Parsing for LLVM::LandingpadOp.
663 ///===----------------------------------------------------------------------===// 664 665 static LogicalResult verify(LandingpadOp op) { 666 Value value; 667 if (LLVMFuncOp func = op->getParentOfType<LLVMFuncOp>()) { 668 if (!func.personality().hasValue()) 669 return op.emitError( 670 "llvm.landingpad needs to be in a function with a personality"); 671 } 672 673 if (!op.cleanup() && op.getOperands().empty()) 674 return op.emitError("landingpad instruction expects at least one clause or " 675 "cleanup attribute"); 676 677 for (unsigned idx = 0, ie = op.getNumOperands(); idx < ie; idx++) { 678 value = op.getOperand(idx); 679 bool isFilter = value.getType().isa<LLVMArrayType>(); 680 if (isFilter) { 681 // FIXME: Verify filter clauses when arrays are appropriately handled 682 } else { 683 // catch - global addresses only. 684 // Bitcast ops should have global addresses as their args. 685 if (auto bcOp = value.getDefiningOp<BitcastOp>()) { 686 if (auto addrOp = bcOp.arg().getDefiningOp<AddressOfOp>()) 687 continue; 688 return op.emitError("constant clauses expected") 689 .attachNote(bcOp.getLoc()) 690 << "global addresses expected as operand to " 691 "bitcast used in clauses for landingpad"; 692 } 693 // NullOp and AddressOfOp allowed 694 if (value.getDefiningOp<NullOp>()) 695 continue; 696 if (value.getDefiningOp<AddressOfOp>()) 697 continue; 698 return op.emitError("clause #") 699 << idx << " is not a known constant - null, addressof, bitcast"; 700 } 701 } 702 return success(); 703 } 704 705 static void printLandingpadOp(OpAsmPrinter &p, LandingpadOp &op) { 706 p << op.getOperationName() << (op.cleanup() ? " cleanup " : " "); 707 708 // Clauses 709 for (auto value : op.getOperands()) { 710 // Similar to llvm - if clause is an array type then it is filter 711 // clause else catch clause 712 bool isArrayTy = value.getType().isa<LLVMArrayType>(); 713 p << '(' << (isArrayTy ? 
"filter " : "catch ") << value << " : " 714 << value.getType() << ") "; 715 } 716 717 p.printOptionalAttrDict(op->getAttrs(), {"cleanup"}); 718 719 p << ": " << op.getType(); 720 } 721 722 /// <operation> ::= `llvm.landingpad` `cleanup`? 723 /// ((`catch` | `filter`) operand-type ssa-use)* attribute-dict? 724 static ParseResult parseLandingpadOp(OpAsmParser &parser, 725 OperationState &result) { 726 // Check for cleanup 727 if (succeeded(parser.parseOptionalKeyword("cleanup"))) 728 result.addAttribute("cleanup", parser.getBuilder().getUnitAttr()); 729 730 // Parse clauses with types 731 while (succeeded(parser.parseOptionalLParen()) && 732 (succeeded(parser.parseOptionalKeyword("filter")) || 733 succeeded(parser.parseOptionalKeyword("catch")))) { 734 OpAsmParser::OperandType operand; 735 Type ty; 736 if (parser.parseOperand(operand) || parser.parseColon() || 737 parser.parseType(ty) || 738 parser.resolveOperand(operand, ty, result.operands) || 739 parser.parseRParen()) 740 return failure(); 741 } 742 743 Type type; 744 if (parser.parseColon() || parser.parseType(type)) 745 return failure(); 746 747 result.addTypes(type); 748 return success(); 749 } 750 751 //===----------------------------------------------------------------------===// 752 // Verifying/Printing/parsing for LLVM::CallOp. 753 //===----------------------------------------------------------------------===// 754 755 static LogicalResult verify(CallOp &op) { 756 if (op.getNumResults() > 1) 757 return op.emitOpError("must have 0 or 1 result"); 758 759 // Type for the callee, we'll get it differently depending if it is a direct 760 // or indirect call. 761 Type fnType; 762 763 bool isIndirect = false; 764 765 // If this is an indirect call, the callee attribute is missing. 
766 Optional<StringRef> calleeName = op.callee(); 767 if (!calleeName) { 768 isIndirect = true; 769 if (!op.getNumOperands()) 770 return op.emitOpError( 771 "must have either a `callee` attribute or at least an operand"); 772 auto ptrType = op.getOperand(0).getType().dyn_cast<LLVMPointerType>(); 773 if (!ptrType) 774 return op.emitOpError("indirect call expects a pointer as callee: ") 775 << ptrType; 776 fnType = ptrType.getElementType(); 777 } else { 778 Operation *callee = SymbolTable::lookupNearestSymbolFrom(op, *calleeName); 779 if (!callee) 780 return op.emitOpError() 781 << "'" << *calleeName 782 << "' does not reference a symbol in the current scope"; 783 auto fn = dyn_cast<LLVMFuncOp>(callee); 784 if (!fn) 785 return op.emitOpError() << "'" << *calleeName 786 << "' does not reference a valid LLVM function"; 787 788 fnType = fn.getType(); 789 } 790 791 LLVMFunctionType funcType = fnType.dyn_cast<LLVMFunctionType>(); 792 if (!funcType) 793 return op.emitOpError("callee does not have a functional type: ") << fnType; 794 795 // Verify that the operand and result types match the callee. 
796 797 if (!funcType.isVarArg() && 798 funcType.getNumParams() != (op.getNumOperands() - isIndirect)) 799 return op.emitOpError() 800 << "incorrect number of operands (" 801 << (op.getNumOperands() - isIndirect) 802 << ") for callee (expecting: " << funcType.getNumParams() << ")"; 803 804 if (funcType.getNumParams() > (op.getNumOperands() - isIndirect)) 805 return op.emitOpError() << "incorrect number of operands (" 806 << (op.getNumOperands() - isIndirect) 807 << ") for varargs callee (expecting at least: " 808 << funcType.getNumParams() << ")"; 809 810 for (unsigned i = 0, e = funcType.getNumParams(); i != e; ++i) 811 if (op.getOperand(i + isIndirect).getType() != funcType.getParamType(i)) 812 return op.emitOpError() << "operand type mismatch for operand " << i 813 << ": " << op.getOperand(i + isIndirect).getType() 814 << " != " << funcType.getParamType(i); 815 816 if (op.getNumResults() && 817 op.getResult(0).getType() != funcType.getReturnType()) 818 return op.emitOpError() 819 << "result type mismatch: " << op.getResult(0).getType() 820 << " != " << funcType.getReturnType(); 821 822 return success(); 823 } 824 825 static void printCallOp(OpAsmPrinter &p, CallOp &op) { 826 auto callee = op.callee(); 827 bool isDirect = callee.hasValue(); 828 829 // Print the direct callee if present as a function attribute, or an indirect 830 // callee (first operand) otherwise. 831 p << op.getOperationName() << ' '; 832 if (isDirect) 833 p.printSymbolName(callee.getValue()); 834 else 835 p << op.getOperand(0); 836 837 auto args = op.getOperands().drop_front(isDirect ? 0 : 1); 838 p << '(' << args << ')'; 839 p.printOptionalAttrDict(processFMFAttr(op->getAttrs()), {"callee"}); 840 841 // Reconstruct the function MLIR function type from operand and result types. 842 p << " : " 843 << FunctionType::get(op.getContext(), args.getTypes(), op.getResultTypes()); 844 } 845 846 // <operation> ::= `llvm.call` (function-id | ssa-use) `(` ssa-use-list `)` 847 // attribute-dict? 
`:` function-type 848 static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) { 849 SmallVector<OpAsmParser::OperandType, 8> operands; 850 Type type; 851 SymbolRefAttr funcAttr; 852 llvm::SMLoc trailingTypeLoc; 853 854 // Parse an operand list that will, in practice, contain 0 or 1 operand. In 855 // case of an indirect call, there will be 1 operand before `(`. In case of a 856 // direct call, there will be no operands and the parser will stop at the 857 // function identifier without complaining. 858 if (parser.parseOperandList(operands)) 859 return failure(); 860 bool isDirect = operands.empty(); 861 862 // Optionally parse a function identifier. 863 if (isDirect) 864 if (parser.parseAttribute(funcAttr, "callee", result.attributes)) 865 return failure(); 866 867 if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || 868 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 869 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type)) 870 return failure(); 871 872 auto funcType = type.dyn_cast<FunctionType>(); 873 if (!funcType) 874 return parser.emitError(trailingTypeLoc, "expected function type"); 875 if (isDirect) { 876 // Make sure types match. 877 if (parser.resolveOperands(operands, funcType.getInputs(), 878 parser.getNameLoc(), result.operands)) 879 return failure(); 880 result.addTypes(funcType.getResults()); 881 } else { 882 // Construct the LLVM IR Dialect function type that the first operand 883 // should match. 
884 if (funcType.getNumResults() > 1) 885 return parser.emitError(trailingTypeLoc, 886 "expected function with 0 or 1 result"); 887 888 Builder &builder = parser.getBuilder(); 889 Type llvmResultType; 890 if (funcType.getNumResults() == 0) { 891 llvmResultType = LLVM::LLVMVoidType::get(builder.getContext()); 892 } else { 893 llvmResultType = funcType.getResult(0); 894 if (!isCompatibleType(llvmResultType)) 895 return parser.emitError(trailingTypeLoc, 896 "expected result to have LLVM type"); 897 } 898 899 SmallVector<Type, 8> argTypes; 900 argTypes.reserve(funcType.getNumInputs()); 901 for (int i = 0, e = funcType.getNumInputs(); i < e; ++i) { 902 auto argType = funcType.getInput(i); 903 if (!isCompatibleType(argType)) 904 return parser.emitError(trailingTypeLoc, 905 "expected LLVM types as inputs"); 906 argTypes.push_back(argType); 907 } 908 auto llvmFuncType = LLVM::LLVMFunctionType::get(llvmResultType, argTypes); 909 auto wrappedFuncType = LLVM::LLVMPointerType::get(llvmFuncType); 910 911 auto funcArguments = 912 ArrayRef<OpAsmParser::OperandType>(operands).drop_front(); 913 914 // Make sure that the first operand (indirect callee) matches the wrapped 915 // LLVM IR function type, and that the types of the other call operands 916 // match the types of the function arguments. 917 if (parser.resolveOperand(operands[0], wrappedFuncType, result.operands) || 918 parser.resolveOperands(funcArguments, funcType.getInputs(), 919 parser.getNameLoc(), result.operands)) 920 return failure(); 921 922 result.addTypes(llvmResultType); 923 } 924 925 return success(); 926 } 927 928 //===----------------------------------------------------------------------===// 929 // Printing/parsing for LLVM::ExtractElementOp. 930 //===----------------------------------------------------------------------===// 931 // Expects vector to be of wrapped LLVM vector type and position to be of 932 // wrapped LLVM i32 type. 
933 void LLVM::ExtractElementOp::build(OpBuilder &b, OperationState &result, 934 Value vector, Value position, 935 ArrayRef<NamedAttribute> attrs) { 936 auto vectorType = vector.getType(); 937 auto llvmType = LLVM::getVectorElementType(vectorType); 938 build(b, result, llvmType, vector, position); 939 result.addAttributes(attrs); 940 } 941 942 static void printExtractElementOp(OpAsmPrinter &p, ExtractElementOp &op) { 943 p << op.getOperationName() << ' ' << op.vector() << "[" << op.position() 944 << " : " << op.position().getType() << "]"; 945 p.printOptionalAttrDict(op->getAttrs()); 946 p << " : " << op.vector().getType(); 947 } 948 949 // <operation> ::= `llvm.extractelement` ssa-use `, ` ssa-use 950 // attribute-dict? `:` type 951 static ParseResult parseExtractElementOp(OpAsmParser &parser, 952 OperationState &result) { 953 llvm::SMLoc loc; 954 OpAsmParser::OperandType vector, position; 955 Type type, positionType; 956 if (parser.getCurrentLocation(&loc) || parser.parseOperand(vector) || 957 parser.parseLSquare() || parser.parseOperand(position) || 958 parser.parseColonType(positionType) || parser.parseRSquare() || 959 parser.parseOptionalAttrDict(result.attributes) || 960 parser.parseColonType(type) || 961 parser.resolveOperand(vector, type, result.operands) || 962 parser.resolveOperand(position, positionType, result.operands)) 963 return failure(); 964 if (!LLVM::isCompatibleVectorType(type)) 965 return parser.emitError( 966 loc, "expected LLVM dialect-compatible vector type for operand #1"); 967 result.addTypes(LLVM::getVectorElementType(type)); 968 return success(); 969 } 970 971 //===----------------------------------------------------------------------===// 972 // Printing/parsing for LLVM::ExtractValueOp. 
973 //===----------------------------------------------------------------------===// 974 975 static void printExtractValueOp(OpAsmPrinter &p, ExtractValueOp &op) { 976 p << op.getOperationName() << ' ' << op.container() << op.position(); 977 p.printOptionalAttrDict(op->getAttrs(), {"position"}); 978 p << " : " << op.container().getType(); 979 } 980 981 // Extract the type at `position` in the wrapped LLVM IR aggregate type 982 // `containerType`. Position is an integer array attribute where each value 983 // is a zero-based position of the element in the aggregate type. Return the 984 // resulting type wrapped in MLIR, or nullptr on error. 985 static Type getInsertExtractValueElementType(OpAsmParser &parser, 986 Type containerType, 987 ArrayAttr positionAttr, 988 llvm::SMLoc attributeLoc, 989 llvm::SMLoc typeLoc) { 990 Type llvmType = containerType; 991 if (!isCompatibleType(containerType)) 992 return parser.emitError(typeLoc, "expected LLVM IR Dialect type"), nullptr; 993 994 // Infer the element type from the structure type: iteratively step inside the 995 // type by taking the element type, indexed by the position attribute for 996 // structures. Check the position index before accessing, it is supposed to 997 // be in bounds. 
998 for (Attribute subAttr : positionAttr) { 999 auto positionElementAttr = subAttr.dyn_cast<IntegerAttr>(); 1000 if (!positionElementAttr) 1001 return parser.emitError(attributeLoc, 1002 "expected an array of integer literals"), 1003 nullptr; 1004 int position = positionElementAttr.getInt(); 1005 if (auto arrayType = llvmType.dyn_cast<LLVMArrayType>()) { 1006 if (position < 0 || 1007 static_cast<unsigned>(position) >= arrayType.getNumElements()) 1008 return parser.emitError(attributeLoc, "position out of bounds"), 1009 nullptr; 1010 llvmType = arrayType.getElementType(); 1011 } else if (auto structType = llvmType.dyn_cast<LLVMStructType>()) { 1012 if (position < 0 || 1013 static_cast<unsigned>(position) >= structType.getBody().size()) 1014 return parser.emitError(attributeLoc, "position out of bounds"), 1015 nullptr; 1016 llvmType = structType.getBody()[position]; 1017 } else { 1018 return parser.emitError(typeLoc, "expected LLVM IR structure/array type"), 1019 nullptr; 1020 } 1021 } 1022 return llvmType; 1023 } 1024 1025 // <operation> ::= `llvm.extractvalue` ssa-use 1026 // `[` integer-literal (`,` integer-literal)* `]` 1027 // attribute-dict? 
`:` type 1028 static ParseResult parseExtractValueOp(OpAsmParser &parser, 1029 OperationState &result) { 1030 OpAsmParser::OperandType container; 1031 Type containerType; 1032 ArrayAttr positionAttr; 1033 llvm::SMLoc attributeLoc, trailingTypeLoc; 1034 1035 if (parser.parseOperand(container) || 1036 parser.getCurrentLocation(&attributeLoc) || 1037 parser.parseAttribute(positionAttr, "position", result.attributes) || 1038 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 1039 parser.getCurrentLocation(&trailingTypeLoc) || 1040 parser.parseType(containerType) || 1041 parser.resolveOperand(container, containerType, result.operands)) 1042 return failure(); 1043 1044 auto elementType = getInsertExtractValueElementType( 1045 parser, containerType, positionAttr, attributeLoc, trailingTypeLoc); 1046 if (!elementType) 1047 return failure(); 1048 1049 result.addTypes(elementType); 1050 return success(); 1051 } 1052 1053 OpFoldResult LLVM::ExtractValueOp::fold(ArrayRef<Attribute> operands) { 1054 auto insertValueOp = container().getDefiningOp<InsertValueOp>(); 1055 while (insertValueOp) { 1056 if (position() == insertValueOp.position()) 1057 return insertValueOp.value(); 1058 insertValueOp = insertValueOp.container().getDefiningOp<InsertValueOp>(); 1059 } 1060 return {}; 1061 } 1062 1063 //===----------------------------------------------------------------------===// 1064 // Printing/parsing for LLVM::InsertElementOp. 1065 //===----------------------------------------------------------------------===// 1066 1067 static void printInsertElementOp(OpAsmPrinter &p, InsertElementOp &op) { 1068 p << op.getOperationName() << ' ' << op.value() << ", " << op.vector() << "[" 1069 << op.position() << " : " << op.position().getType() << "]"; 1070 p.printOptionalAttrDict(op->getAttrs()); 1071 p << " : " << op.vector().getType(); 1072 } 1073 1074 // <operation> ::= `llvm.insertelement` ssa-use `,` ssa-use `,` ssa-use 1075 // attribute-dict? 
`:` type 1076 static ParseResult parseInsertElementOp(OpAsmParser &parser, 1077 OperationState &result) { 1078 llvm::SMLoc loc; 1079 OpAsmParser::OperandType vector, value, position; 1080 Type vectorType, positionType; 1081 if (parser.getCurrentLocation(&loc) || parser.parseOperand(value) || 1082 parser.parseComma() || parser.parseOperand(vector) || 1083 parser.parseLSquare() || parser.parseOperand(position) || 1084 parser.parseColonType(positionType) || parser.parseRSquare() || 1085 parser.parseOptionalAttrDict(result.attributes) || 1086 parser.parseColonType(vectorType)) 1087 return failure(); 1088 1089 if (!LLVM::isCompatibleVectorType(vectorType)) 1090 return parser.emitError( 1091 loc, "expected LLVM dialect-compatible vector type for operand #1"); 1092 Type valueType = LLVM::getVectorElementType(vectorType); 1093 if (!valueType) 1094 return failure(); 1095 1096 if (parser.resolveOperand(vector, vectorType, result.operands) || 1097 parser.resolveOperand(value, valueType, result.operands) || 1098 parser.resolveOperand(position, positionType, result.operands)) 1099 return failure(); 1100 1101 result.addTypes(vectorType); 1102 return success(); 1103 } 1104 1105 //===----------------------------------------------------------------------===// 1106 // Printing/parsing for LLVM::InsertValueOp. 1107 //===----------------------------------------------------------------------===// 1108 1109 static void printInsertValueOp(OpAsmPrinter &p, InsertValueOp &op) { 1110 p << op.getOperationName() << ' ' << op.value() << ", " << op.container() 1111 << op.position(); 1112 p.printOptionalAttrDict(op->getAttrs(), {"position"}); 1113 p << " : " << op.container().getType(); 1114 } 1115 1116 // <operation> ::= `llvm.insertvaluevalue` ssa-use `,` ssa-use 1117 // `[` integer-literal (`,` integer-literal)* `]` 1118 // attribute-dict? 
`:` type 1119 static ParseResult parseInsertValueOp(OpAsmParser &parser, 1120 OperationState &result) { 1121 OpAsmParser::OperandType container, value; 1122 Type containerType; 1123 ArrayAttr positionAttr; 1124 llvm::SMLoc attributeLoc, trailingTypeLoc; 1125 1126 if (parser.parseOperand(value) || parser.parseComma() || 1127 parser.parseOperand(container) || 1128 parser.getCurrentLocation(&attributeLoc) || 1129 parser.parseAttribute(positionAttr, "position", result.attributes) || 1130 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 1131 parser.getCurrentLocation(&trailingTypeLoc) || 1132 parser.parseType(containerType)) 1133 return failure(); 1134 1135 auto valueType = getInsertExtractValueElementType( 1136 parser, containerType, positionAttr, attributeLoc, trailingTypeLoc); 1137 if (!valueType) 1138 return failure(); 1139 1140 if (parser.resolveOperand(container, containerType, result.operands) || 1141 parser.resolveOperand(value, valueType, result.operands)) 1142 return failure(); 1143 1144 result.addTypes(containerType); 1145 return success(); 1146 } 1147 1148 //===----------------------------------------------------------------------===// 1149 // Printing, parsing and verification for LLVM::ReturnOp. 1150 //===----------------------------------------------------------------------===// 1151 1152 static void printReturnOp(OpAsmPrinter &p, ReturnOp op) { 1153 p << op.getOperationName(); 1154 p.printOptionalAttrDict(op->getAttrs()); 1155 assert(op.getNumOperands() <= 1); 1156 1157 if (op.getNumOperands() == 0) 1158 return; 1159 1160 p << ' ' << op.getOperand(0) << " : " << op.getOperand(0).getType(); 1161 } 1162 1163 // <operation> ::= `llvm.return` ssa-use-list attribute-dict? 
`:` 1164 // type-list-no-parens 1165 static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &result) { 1166 SmallVector<OpAsmParser::OperandType, 1> operands; 1167 Type type; 1168 1169 if (parser.parseOperandList(operands) || 1170 parser.parseOptionalAttrDict(result.attributes)) 1171 return failure(); 1172 if (operands.empty()) 1173 return success(); 1174 1175 if (parser.parseColonType(type) || 1176 parser.resolveOperand(operands[0], type, result.operands)) 1177 return failure(); 1178 return success(); 1179 } 1180 1181 static LogicalResult verify(ReturnOp op) { 1182 if (op->getNumOperands() > 1) 1183 return op->emitOpError("expected at most 1 operand"); 1184 1185 if (auto parent = op->getParentOfType<LLVMFuncOp>()) { 1186 Type expectedType = parent.getType().getReturnType(); 1187 if (expectedType.isa<LLVMVoidType>()) { 1188 if (op->getNumOperands() == 0) 1189 return success(); 1190 InFlightDiagnostic diag = op->emitOpError("expected no operands"); 1191 diag.attachNote(parent->getLoc()) << "when returning from function"; 1192 return diag; 1193 } 1194 if (op->getNumOperands() == 0) { 1195 if (expectedType.isa<LLVMVoidType>()) 1196 return success(); 1197 InFlightDiagnostic diag = op->emitOpError("expected 1 operand"); 1198 diag.attachNote(parent->getLoc()) << "when returning from function"; 1199 return diag; 1200 } 1201 if (expectedType != op->getOperand(0).getType()) { 1202 InFlightDiagnostic diag = op->emitOpError("mismatching result types"); 1203 diag.attachNote(parent->getLoc()) << "when returning from function"; 1204 return diag; 1205 } 1206 } 1207 return success(); 1208 } 1209 1210 //===----------------------------------------------------------------------===// 1211 // Verifier for LLVM::AddressOfOp. 
1212 //===----------------------------------------------------------------------===// 1213 1214 template <typename OpTy> 1215 static OpTy lookupSymbolInModule(Operation *parent, StringRef name) { 1216 Operation *module = parent; 1217 while (module && !satisfiesLLVMModule(module)) 1218 module = module->getParentOp(); 1219 assert(module && "unexpected operation outside of a module"); 1220 return dyn_cast_or_null<OpTy>( 1221 mlir::SymbolTable::lookupSymbolIn(module, name)); 1222 } 1223 1224 GlobalOp AddressOfOp::getGlobal() { 1225 return lookupSymbolInModule<LLVM::GlobalOp>((*this)->getParentOp(), 1226 global_name()); 1227 } 1228 1229 LLVMFuncOp AddressOfOp::getFunction() { 1230 return lookupSymbolInModule<LLVM::LLVMFuncOp>((*this)->getParentOp(), 1231 global_name()); 1232 } 1233 1234 static LogicalResult verify(AddressOfOp op) { 1235 auto global = op.getGlobal(); 1236 auto function = op.getFunction(); 1237 if (!global && !function) 1238 return op.emitOpError( 1239 "must reference a global defined by 'llvm.mlir.global' or 'llvm.func'"); 1240 1241 if (global && 1242 LLVM::LLVMPointerType::get(global.getType(), global.addr_space()) != 1243 op.getResult().getType()) 1244 return op.emitOpError( 1245 "the type must be a pointer to the type of the referenced global"); 1246 1247 if (function && LLVM::LLVMPointerType::get(function.getType()) != 1248 op.getResult().getType()) 1249 return op.emitOpError( 1250 "the type must be a pointer to the type of the referenced function"); 1251 1252 return success(); 1253 } 1254 1255 //===----------------------------------------------------------------------===// 1256 // Builder, printer and verifier for LLVM::GlobalOp. 1257 //===----------------------------------------------------------------------===// 1258 1259 /// Returns the name used for the linkage attribute. This *must* correspond to 1260 /// the name of the attribute in ODS. 
1261 static StringRef getLinkageAttrName() { return "linkage"; } 1262 1263 /// Returns the name used for the unnamed_addr attribute. This *must* correspond 1264 /// to the name of the attribute in ODS. 1265 static StringRef getUnnamedAddrAttrName() { return "unnamed_addr"; } 1266 1267 void GlobalOp::build(OpBuilder &builder, OperationState &result, Type type, 1268 bool isConstant, Linkage linkage, StringRef name, 1269 Attribute value, uint64_t alignment, unsigned addrSpace, 1270 bool dsoLocal, ArrayRef<NamedAttribute> attrs) { 1271 result.addAttribute(SymbolTable::getSymbolAttrName(), 1272 builder.getStringAttr(name)); 1273 result.addAttribute("type", TypeAttr::get(type)); 1274 if (isConstant) 1275 result.addAttribute("constant", builder.getUnitAttr()); 1276 if (value) 1277 result.addAttribute("value", value); 1278 if (dsoLocal) 1279 result.addAttribute("dso_local", builder.getUnitAttr()); 1280 1281 // Only add an alignment attribute if the "alignment" input 1282 // is different from 0. The value must also be a power of two, but 1283 // this is tested in GlobalOp::verify, not here. 
1284 if (alignment != 0) 1285 result.addAttribute("alignment", builder.getI64IntegerAttr(alignment)); 1286 1287 result.addAttribute(getLinkageAttrName(), 1288 builder.getI64IntegerAttr(static_cast<int64_t>(linkage))); 1289 if (addrSpace != 0) 1290 result.addAttribute("addr_space", builder.getI32IntegerAttr(addrSpace)); 1291 result.attributes.append(attrs.begin(), attrs.end()); 1292 result.addRegion(); 1293 } 1294 1295 static void printGlobalOp(OpAsmPrinter &p, GlobalOp op) { 1296 p << op.getOperationName() << ' ' << stringifyLinkage(op.linkage()) << ' '; 1297 if (op.unnamed_addr()) 1298 p << stringifyUnnamedAddr(*op.unnamed_addr()) << ' '; 1299 if (op.constant()) 1300 p << "constant "; 1301 p.printSymbolName(op.sym_name()); 1302 p << '('; 1303 if (auto value = op.getValueOrNull()) 1304 p.printAttribute(value); 1305 p << ')'; 1306 // Note that the alignment attribute is printed using the 1307 // default syntax here, even though it is an inherent attribute 1308 // (as defined in https://mlir.llvm.org/docs/LangRef/#attributes) 1309 p.printOptionalAttrDict(op->getAttrs(), 1310 {SymbolTable::getSymbolAttrName(), "type", "constant", 1311 "value", getLinkageAttrName(), 1312 getUnnamedAddrAttrName()}); 1313 1314 // Print the trailing type unless it's a string global. 1315 if (op.getValueOrNull().dyn_cast_or_null<StringAttr>()) 1316 return; 1317 p << " : " << op.type(); 1318 1319 Region &initializer = op.getInitializerRegion(); 1320 if (!initializer.empty()) 1321 p.printRegion(initializer, /*printEntryBlockArgs=*/false); 1322 } 1323 1324 //===----------------------------------------------------------------------===// 1325 // Verifier for LLVM::DialectCastOp. 1326 //===----------------------------------------------------------------------===// 1327 1328 /// Checks if `llvmType` is dialect cast-compatible with `index` type. Does not 1329 /// report the error, the user is expected to produce an appropriate message. 
// TODO: make the size depend on data layout rather than on the conversion
// pass option, and pull that information here.
// Currently any builtin integer type, of any width, is accepted.
static LogicalResult verifyCastWithIndex(Type llvmType) {
  return success(llvmType.isa<IntegerType>());
}

/// Checks if `llvmType` is dialect cast-compatible with built-in `type` and
/// reports errors to the location of `op`. `isElement` indicates whether the
/// verification is performed for types that are element types inside a
/// container; we don't want casts from X to X at the top level, but c1<X> to
/// c2<X> may be fine.
/// NOTE(review): the top-level X-to-X restriction is only enforced for vector
/// types below; an integer-to-same-integer cast at the top level passes through
/// the IntegerType branch.
static LogicalResult verifyCast(DialectCastOp op, Type llvmType, Type type,
                                bool isElement = false) {
  // Equal element types are directly compatible.
  if (isElement && llvmType == type)
    return success();

  // Index is compatible with any integer.
  if (type.isIndex()) {
    if (succeeded(verifyCastWithIndex(llvmType)))
      return success();

    return op.emitOpError("invalid cast between index and non-integer type");
  }

  // Integers must map to an integer of the same bit width.
  if (type.isa<IntegerType>()) {
    auto llvmIntegerType = llvmType.dyn_cast<IntegerType>();
    if (!llvmIntegerType)
      return op->emitOpError("invalid cast between integer and non-integer");
    if (llvmIntegerType.getWidth() != type.getIntOrFloatBitWidth())
      return op.emitOpError("invalid cast changing integer width");
    return success();
  }

  // Vectors are compatible if they are 1D non-scalable, and their element types
  // are compatible. nD vectors are compatible with (n-1)D arrays containing 1D
  // vector.
  // NOTE(review): scalability is not actually checked here.
  if (auto vectorType = type.dyn_cast<VectorType>()) {
    if (vectorType == llvmType && !isElement)
      return op.emitOpError("vector types should not be casted");

    if (vectorType.getRank() == 1) {
      auto llvmVectorType = llvmType.dyn_cast<VectorType>();
      if (!llvmVectorType || llvmVectorType.getRank() != 1)
        return op.emitOpError("invalid cast for vector types");

      return verifyCast(op, llvmVectorType.getElementType(),
                        vectorType.getElementType(), /*isElement=*/true);
    }

    // Peel the outermost dimension: it must match an LLVM array of the same
    // length whose element is checked against the remaining (n-1)D vector.
    auto arrayType = llvmType.dyn_cast<LLVM::LLVMArrayType>();
    if (!arrayType ||
        arrayType.getNumElements() != vectorType.getShape().front())
      return op.emitOpError("invalid cast for vector, expected array");
    return verifyCast(op, arrayType.getElementType(),
                      VectorType::get(vectorType.getShape().drop_front(),
                                      vectorType.getElementType()),
                      /*isElement=*/true);
  }

  if (auto memrefType = type.dyn_cast<MemRefType>()) {
    // Bare pointer convention: statically-shaped memref is compatible with an
    // LLVM pointer to the element type.
    // NOTE(review): the address-space mismatch below uses emitError rather
    // than the emitOpError used by the neighbouring checks.
    if (auto ptrType = llvmType.dyn_cast<LLVMPointerType>()) {
      if (!memrefType.hasStaticShape())
        return op->emitOpError(
            "unexpected bare pointer for dynamically shaped memref");
      if (memrefType.getMemorySpaceAsInt() != ptrType.getAddressSpace())
        return op->emitError("invalid conversion between memref and pointer in "
                             "different memory spaces");

      return verifyCast(op, ptrType.getElementType(),
                        memrefType.getElementType(), /*isElement=*/true);
    }

    // Otherwise, memrefs are convertible to a descriptor, which is a structure
    // type: {allocated ptr, aligned ptr, offset[, sizes, strides]}.
    auto structType = llvmType.dyn_cast<LLVMStructType>();
    if (!structType)
      return op->emitOpError("invalid cast between a memref and a type other "
                             "than pointer or memref descriptor");

    // Rank-0 memrefs have no sizes/strides arrays, hence 3 instead of 5.
    unsigned expectedNumElements = memrefType.getRank() == 0 ? 3 : 5;
    if (structType.getBody().size() != expectedNumElements) {
      return op->emitOpError() << "expected memref descriptor with "
                               << expectedNumElements << " elements";
    }

    // The first two elements are pointers to the element type.
    auto allocatedPtr = structType.getBody()[0].dyn_cast<LLVMPointerType>();
    if (!allocatedPtr ||
        allocatedPtr.getAddressSpace() != memrefType.getMemorySpaceAsInt())
      return op->emitOpError("expected first element of a memref descriptor to "
                             "be a pointer in the address space of the memref");
    if (failed(verifyCast(op, allocatedPtr.getElementType(),
                          memrefType.getElementType(), /*isElement=*/true)))
      return failure();

    auto alignedPtr = structType.getBody()[1].dyn_cast<LLVMPointerType>();
    if (!alignedPtr ||
        alignedPtr.getAddressSpace() != memrefType.getMemorySpaceAsInt())
      return op->emitOpError(
          "expected second element of a memref descriptor to "
          "be a pointer in the address space of the memref");
    if (failed(verifyCast(op, alignedPtr.getElementType(),
                          memrefType.getElementType(), /*isElement=*/true)))
      return failure();

    // The third element (offset) is an equivalent of index.
    if (failed(verifyCastWithIndex(structType.getBody()[2])))
      return op->emitOpError("expected third element of a memref descriptor to "
                             "be index-compatible integers");

    // 0D memrefs don't have sizes/strides.
    if (memrefType.getRank() == 0)
      return success();

    // Sizes and strides are rank-sized arrays of `index` equivalents.
    auto sizes = structType.getBody()[3].dyn_cast<LLVMArrayType>();
    if (!sizes || failed(verifyCastWithIndex(sizes.getElementType())) ||
        sizes.getNumElements() != memrefType.getRank())
      return op->emitOpError(
          "expected fourth element of a memref descriptor "
          "to be an array of <rank> index-compatible integers");

    auto strides = structType.getBody()[4].dyn_cast<LLVMArrayType>();
    if (!strides || failed(verifyCastWithIndex(strides.getElementType())) ||
        strides.getNumElements() != memrefType.getRank())
      return op->emitOpError(
          "expected fifth element of a memref descriptor "
          "to be an array of <rank> index-compatible integers");

    return success();
  }

  // Unranked memrefs are compatible with their descriptors: a two-element
  // struct of an index-compatible integer and a pointer to an 8-bit integer.
  if (auto unrankedMemrefType = type.dyn_cast<UnrankedMemRefType>()) {
    auto structType = llvmType.dyn_cast<LLVMStructType>();
    if (!structType || structType.getBody().size() != 2)
      return op->emitOpError(
          "expected descriptor to be a struct with two elements");

    if (failed(verifyCastWithIndex(structType.getBody()[0])))
      return op->emitOpError("expected first element of a memref descriptor to "
                             "be an index-compatible integer");

    auto ptrType = structType.getBody()[1].dyn_cast<LLVMPointerType>();
    auto ptrElementType =
        ptrType ? ptrType.getElementType().dyn_cast<IntegerType>() : nullptr;
    if (!ptrElementType || ptrElementType.getWidth() != 8)
      return op->emitOpError("expected second element of a memref descriptor "
                             "to be an !llvm.ptr<i8>");

    return success();
  }

  // Complex types are compatible with the two-element structs.
  if (auto complexType = type.dyn_cast<ComplexType>()) {
    auto structType = llvmType.dyn_cast<LLVMStructType>();
    if (!structType || structType.getBody().size() != 2 ||
        structType.getBody()[0] != structType.getBody()[1] ||
        structType.getBody()[0] != complexType.getElementType())
      return op->emitOpError("expected 'complex' to map to two-element struct "
                             "with identical element types");
    return success();
  }

  // Everything else is not supported.
  return op->emitError("unsupported cast");
}

/// Verifies `llvm.mlir.cast`. If the result type is LLVM-compatible it is
/// treated as the LLVM side and checked against the input type. Otherwise the
/// input must be LLVM-compatible and is checked against the builtin result
/// type.
static LogicalResult verify(DialectCastOp op) {
  if (isCompatibleType(op.getType()))
    return verifyCast(op, op.getType(), op.in().getType());

  if (!isCompatibleType(op.in().getType()))
    return op->emitOpError("expected one LLVM type and one built-in type");

  return verifyCast(op, op.in().getType(), op.getType());
}

// Parses one of the keywords provided in the list `keywords` and returns the
// position of the parsed keyword in the list. If none of the keywords from the
// list is parsed, returns -1.
1514 static int parseOptionalKeywordAlternative(OpAsmParser &parser, 1515 ArrayRef<StringRef> keywords) { 1516 for (auto en : llvm::enumerate(keywords)) { 1517 if (succeeded(parser.parseOptionalKeyword(en.value()))) 1518 return en.index(); 1519 } 1520 return -1; 1521 } 1522 1523 namespace { 1524 template <typename Ty> 1525 struct EnumTraits {}; 1526 1527 #define REGISTER_ENUM_TYPE(Ty) \ 1528 template <> \ 1529 struct EnumTraits<Ty> { \ 1530 static StringRef stringify(Ty value) { return stringify##Ty(value); } \ 1531 static unsigned getMaxEnumVal() { return getMaxEnumValFor##Ty(); } \ 1532 } 1533 1534 REGISTER_ENUM_TYPE(Linkage); 1535 REGISTER_ENUM_TYPE(UnnamedAddr); 1536 } // end namespace 1537 1538 template <typename EnumTy> 1539 static ParseResult parseOptionalLLVMKeyword(OpAsmParser &parser, 1540 OperationState &result, 1541 StringRef name) { 1542 SmallVector<StringRef, 10> names; 1543 for (unsigned i = 0, e = getMaxEnumValForLinkage(); i <= e; ++i) 1544 names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i))); 1545 1546 int index = parseOptionalKeywordAlternative(parser, names); 1547 if (index == -1) 1548 return failure(); 1549 result.addAttribute(name, parser.getBuilder().getI64IntegerAttr(index)); 1550 return success(); 1551 } 1552 1553 // operation ::= `llvm.mlir.global` linkage? `constant`? `@` identifier 1554 // `(` attribute? `)` align? attribute-list? (`:` type)? region? 1555 // align ::= `align` `=` UINT64 1556 // 1557 // The type can be omitted for string attributes, in which case it will be 1558 // inferred from the value of the string as [strlen(value) x i8]. 
1559 static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) { 1560 if (failed(parseOptionalLLVMKeyword<Linkage>(parser, result, 1561 getLinkageAttrName()))) 1562 result.addAttribute(getLinkageAttrName(), 1563 parser.getBuilder().getI64IntegerAttr( 1564 static_cast<int64_t>(LLVM::Linkage::External))); 1565 1566 if (failed(parseOptionalLLVMKeyword<UnnamedAddr>(parser, result, 1567 getUnnamedAddrAttrName()))) 1568 result.addAttribute(getUnnamedAddrAttrName(), 1569 parser.getBuilder().getI64IntegerAttr( 1570 static_cast<int64_t>(LLVM::UnnamedAddr::None))); 1571 1572 if (succeeded(parser.parseOptionalKeyword("constant"))) 1573 result.addAttribute("constant", parser.getBuilder().getUnitAttr()); 1574 1575 StringAttr name; 1576 if (parser.parseSymbolName(name, SymbolTable::getSymbolAttrName(), 1577 result.attributes) || 1578 parser.parseLParen()) 1579 return failure(); 1580 1581 Attribute value; 1582 if (parser.parseOptionalRParen()) { 1583 if (parser.parseAttribute(value, "value", result.attributes) || 1584 parser.parseRParen()) 1585 return failure(); 1586 } 1587 1588 SmallVector<Type, 1> types; 1589 if (parser.parseOptionalAttrDict(result.attributes) || 1590 parser.parseOptionalColonTypeList(types)) 1591 return failure(); 1592 1593 if (types.size() > 1) 1594 return parser.emitError(parser.getNameLoc(), "expected zero or one type"); 1595 1596 Region &initRegion = *result.addRegion(); 1597 if (types.empty()) { 1598 if (auto strAttr = value.dyn_cast_or_null<StringAttr>()) { 1599 MLIRContext *context = parser.getBuilder().getContext(); 1600 auto arrayType = LLVM::LLVMArrayType::get(IntegerType::get(context, 8), 1601 strAttr.getValue().size()); 1602 types.push_back(arrayType); 1603 } else { 1604 return parser.emitError(parser.getNameLoc(), 1605 "type can only be omitted for string globals"); 1606 } 1607 } else { 1608 OptionalParseResult parseResult = 1609 parser.parseOptionalRegion(initRegion, /*arguments=*/{}, 1610 /*argTypes=*/{}); 1611 if 
(parseResult.hasValue() && failed(*parseResult)) 1612 return failure(); 1613 } 1614 1615 result.addAttribute("type", TypeAttr::get(types[0])); 1616 return success(); 1617 } 1618 1619 static bool isZeroAttribute(Attribute value) { 1620 if (auto intValue = value.dyn_cast<IntegerAttr>()) 1621 return intValue.getValue().isNullValue(); 1622 if (auto fpValue = value.dyn_cast<FloatAttr>()) 1623 return fpValue.getValue().isZero(); 1624 if (auto splatValue = value.dyn_cast<SplatElementsAttr>()) 1625 return isZeroAttribute(splatValue.getSplatValue()); 1626 if (auto elementsValue = value.dyn_cast<ElementsAttr>()) 1627 return llvm::all_of(elementsValue.getValues<Attribute>(), isZeroAttribute); 1628 if (auto arrayValue = value.dyn_cast<ArrayAttr>()) 1629 return llvm::all_of(arrayValue.getValue(), isZeroAttribute); 1630 return false; 1631 } 1632 1633 static LogicalResult verify(GlobalOp op) { 1634 if (!LLVMPointerType::isValidElementType(op.getType())) 1635 return op.emitOpError( 1636 "expects type to be a valid element type for an LLVM pointer"); 1637 if (op->getParentOp() && !satisfiesLLVMModule(op->getParentOp())) 1638 return op.emitOpError("must appear at the module level"); 1639 1640 if (auto strAttr = op.getValueOrNull().dyn_cast_or_null<StringAttr>()) { 1641 auto type = op.getType().dyn_cast<LLVMArrayType>(); 1642 IntegerType elementType = 1643 type ? 
type.getElementType().dyn_cast<IntegerType>() : nullptr; 1644 if (!elementType || elementType.getWidth() != 8 || 1645 type.getNumElements() != strAttr.getValue().size()) 1646 return op.emitOpError( 1647 "requires an i8 array type of the length equal to that of the string " 1648 "attribute"); 1649 } 1650 1651 if (Block *b = op.getInitializerBlock()) { 1652 ReturnOp ret = cast<ReturnOp>(b->getTerminator()); 1653 if (ret.operand_type_begin() == ret.operand_type_end()) 1654 return op.emitOpError("initializer region cannot return void"); 1655 if (*ret.operand_type_begin() != op.getType()) 1656 return op.emitOpError("initializer region type ") 1657 << *ret.operand_type_begin() << " does not match global type " 1658 << op.getType(); 1659 1660 if (op.getValueOrNull()) 1661 return op.emitOpError("cannot have both initializer value and region"); 1662 } 1663 1664 if (op.linkage() == Linkage::Common) { 1665 if (Attribute value = op.getValueOrNull()) { 1666 if (!isZeroAttribute(value)) { 1667 return op.emitOpError() 1668 << "expected zero value for '" 1669 << stringifyLinkage(Linkage::Common) << "' linkage"; 1670 } 1671 } 1672 } 1673 1674 if (op.linkage() == Linkage::Appending) { 1675 if (!op.getType().isa<LLVMArrayType>()) { 1676 return op.emitOpError() 1677 << "expected array type for '" 1678 << stringifyLinkage(Linkage::Appending) << "' linkage"; 1679 } 1680 } 1681 1682 Optional<uint64_t> alignAttr = op.alignment(); 1683 if (alignAttr.hasValue()) { 1684 uint64_t value = alignAttr.getValue(); 1685 if (!llvm::isPowerOf2_64(value)) 1686 return op->emitError() << "alignment attribute is not a power of 2"; 1687 } 1688 1689 return success(); 1690 } 1691 1692 //===----------------------------------------------------------------------===// 1693 // Printing/parsing for LLVM::ShuffleVectorOp. 1694 //===----------------------------------------------------------------------===// 1695 // Expects vector to be of wrapped LLVM vector type and position to be of 1696 // wrapped LLVM i32 type. 
1697 void LLVM::ShuffleVectorOp::build(OpBuilder &b, OperationState &result, 1698 Value v1, Value v2, ArrayAttr mask, 1699 ArrayRef<NamedAttribute> attrs) { 1700 auto containerType = v1.getType(); 1701 auto vType = LLVM::getFixedVectorType( 1702 LLVM::getVectorElementType(containerType), mask.size()); 1703 build(b, result, vType, v1, v2, mask); 1704 result.addAttributes(attrs); 1705 } 1706 1707 static void printShuffleVectorOp(OpAsmPrinter &p, ShuffleVectorOp &op) { 1708 p << op.getOperationName() << ' ' << op.v1() << ", " << op.v2() << " " 1709 << op.mask(); 1710 p.printOptionalAttrDict(op->getAttrs(), {"mask"}); 1711 p << " : " << op.v1().getType() << ", " << op.v2().getType(); 1712 } 1713 1714 // <operation> ::= `llvm.shufflevector` ssa-use `, ` ssa-use 1715 // `[` integer-literal (`,` integer-literal)* `]` 1716 // attribute-dict? `:` type 1717 static ParseResult parseShuffleVectorOp(OpAsmParser &parser, 1718 OperationState &result) { 1719 llvm::SMLoc loc; 1720 OpAsmParser::OperandType v1, v2; 1721 ArrayAttr maskAttr; 1722 Type typeV1, typeV2; 1723 if (parser.getCurrentLocation(&loc) || parser.parseOperand(v1) || 1724 parser.parseComma() || parser.parseOperand(v2) || 1725 parser.parseAttribute(maskAttr, "mask", result.attributes) || 1726 parser.parseOptionalAttrDict(result.attributes) || 1727 parser.parseColonType(typeV1) || parser.parseComma() || 1728 parser.parseType(typeV2) || 1729 parser.resolveOperand(v1, typeV1, result.operands) || 1730 parser.resolveOperand(v2, typeV2, result.operands)) 1731 return failure(); 1732 if (!LLVM::isCompatibleVectorType(typeV1)) 1733 return parser.emitError( 1734 loc, "expected LLVM IR dialect vector type for operand #1"); 1735 auto vType = LLVM::getFixedVectorType(LLVM::getVectorElementType(typeV1), 1736 maskAttr.size()); 1737 result.addTypes(vType); 1738 return success(); 1739 } 1740 1741 //===----------------------------------------------------------------------===// 1742 // Implementations for LLVM::LLVMFuncOp. 
1743 //===----------------------------------------------------------------------===// 1744 1745 // Add the entry block to the function. 1746 Block *LLVMFuncOp::addEntryBlock() { 1747 assert(empty() && "function already has an entry block"); 1748 assert(!isVarArg() && "unimplemented: non-external variadic functions"); 1749 1750 auto *entry = new Block; 1751 push_back(entry); 1752 1753 LLVMFunctionType type = getType(); 1754 for (unsigned i = 0, e = type.getNumParams(); i < e; ++i) 1755 entry->addArgument(type.getParamType(i)); 1756 return entry; 1757 } 1758 1759 void LLVMFuncOp::build(OpBuilder &builder, OperationState &result, 1760 StringRef name, Type type, LLVM::Linkage linkage, 1761 bool dsoLocal, ArrayRef<NamedAttribute> attrs, 1762 ArrayRef<DictionaryAttr> argAttrs) { 1763 result.addRegion(); 1764 result.addAttribute(SymbolTable::getSymbolAttrName(), 1765 builder.getStringAttr(name)); 1766 result.addAttribute("type", TypeAttr::get(type)); 1767 result.addAttribute(getLinkageAttrName(), 1768 builder.getI64IntegerAttr(static_cast<int64_t>(linkage))); 1769 result.attributes.append(attrs.begin(), attrs.end()); 1770 if (dsoLocal) 1771 result.addAttribute("dso_local", builder.getUnitAttr()); 1772 if (argAttrs.empty()) 1773 return; 1774 1775 assert(type.cast<LLVMFunctionType>().getNumParams() == argAttrs.size() && 1776 "expected as many argument attribute lists as arguments"); 1777 function_like_impl::addArgAndResultAttrs(builder, result, argAttrs, 1778 /*resultAttrs=*/llvm::None); 1779 } 1780 1781 // Builds an LLVM function type from the given lists of input and output types. 1782 // Returns a null type if any of the types provided are non-LLVM types, or if 1783 // there is more than one output type. 
1784 static Type 1785 buildLLVMFunctionType(OpAsmParser &parser, llvm::SMLoc loc, 1786 ArrayRef<Type> inputs, ArrayRef<Type> outputs, 1787 function_like_impl::VariadicFlag variadicFlag) { 1788 Builder &b = parser.getBuilder(); 1789 if (outputs.size() > 1) { 1790 parser.emitError(loc, "failed to construct function type: expected zero or " 1791 "one function result"); 1792 return {}; 1793 } 1794 1795 // Convert inputs to LLVM types, exit early on error. 1796 SmallVector<Type, 4> llvmInputs; 1797 for (auto t : inputs) { 1798 if (!isCompatibleType(t)) { 1799 parser.emitError(loc, "failed to construct function type: expected LLVM " 1800 "type for function arguments"); 1801 return {}; 1802 } 1803 llvmInputs.push_back(t); 1804 } 1805 1806 // No output is denoted as "void" in LLVM type system. 1807 Type llvmOutput = 1808 outputs.empty() ? LLVMVoidType::get(b.getContext()) : outputs.front(); 1809 if (!isCompatibleType(llvmOutput)) { 1810 parser.emitError(loc, "failed to construct function type: expected LLVM " 1811 "type for function results") 1812 << llvmOutput; 1813 return {}; 1814 } 1815 return LLVMFunctionType::get(llvmOutput, llvmInputs, 1816 variadicFlag.isVariadic()); 1817 } 1818 1819 // Parses an LLVM function. 1820 // 1821 // operation ::= `llvm.func` linkage? function-signature function-attributes? 1822 // function-body 1823 // 1824 static ParseResult parseLLVMFuncOp(OpAsmParser &parser, 1825 OperationState &result) { 1826 // Default to external linkage if no keyword is provided. 
1827 if (failed(parseOptionalLLVMKeyword<Linkage>(parser, result, 1828 getLinkageAttrName()))) 1829 result.addAttribute(getLinkageAttrName(), 1830 parser.getBuilder().getI64IntegerAttr( 1831 static_cast<int64_t>(LLVM::Linkage::External))); 1832 1833 StringAttr nameAttr; 1834 SmallVector<OpAsmParser::OperandType, 8> entryArgs; 1835 SmallVector<NamedAttrList, 1> argAttrs; 1836 SmallVector<NamedAttrList, 1> resultAttrs; 1837 SmallVector<Type, 8> argTypes; 1838 SmallVector<Type, 4> resultTypes; 1839 bool isVariadic; 1840 1841 auto signatureLocation = parser.getCurrentLocation(); 1842 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), 1843 result.attributes) || 1844 function_like_impl::parseFunctionSignature( 1845 parser, /*allowVariadic=*/true, entryArgs, argTypes, argAttrs, 1846 isVariadic, resultTypes, resultAttrs)) 1847 return failure(); 1848 1849 auto type = 1850 buildLLVMFunctionType(parser, signatureLocation, argTypes, resultTypes, 1851 function_like_impl::VariadicFlag(isVariadic)); 1852 if (!type) 1853 return failure(); 1854 result.addAttribute(function_like_impl::getTypeAttrName(), 1855 TypeAttr::get(type)); 1856 1857 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) 1858 return failure(); 1859 function_like_impl::addArgAndResultAttrs(parser.getBuilder(), result, 1860 argAttrs, resultAttrs); 1861 1862 auto *body = result.addRegion(); 1863 OptionalParseResult parseResult = parser.parseOptionalRegion( 1864 *body, entryArgs, entryArgs.empty() ? ArrayRef<Type>() : argTypes); 1865 return failure(parseResult.hasValue() && failed(*parseResult)); 1866 } 1867 1868 // Print the LLVMFuncOp. Collects argument and result types and passes them to 1869 // helper functions. Drops "void" result since it cannot be parsed back. Skips 1870 // the external linkage since it is the default value. 
1871 static void printLLVMFuncOp(OpAsmPrinter &p, LLVMFuncOp op) { 1872 p << op.getOperationName() << ' '; 1873 if (op.linkage() != LLVM::Linkage::External) 1874 p << stringifyLinkage(op.linkage()) << ' '; 1875 p.printSymbolName(op.getName()); 1876 1877 LLVMFunctionType fnType = op.getType(); 1878 SmallVector<Type, 8> argTypes; 1879 SmallVector<Type, 1> resTypes; 1880 argTypes.reserve(fnType.getNumParams()); 1881 for (unsigned i = 0, e = fnType.getNumParams(); i < e; ++i) 1882 argTypes.push_back(fnType.getParamType(i)); 1883 1884 Type returnType = fnType.getReturnType(); 1885 if (!returnType.isa<LLVMVoidType>()) 1886 resTypes.push_back(returnType); 1887 1888 function_like_impl::printFunctionSignature(p, op, argTypes, op.isVarArg(), 1889 resTypes); 1890 function_like_impl::printFunctionAttributes( 1891 p, op, argTypes.size(), resTypes.size(), {getLinkageAttrName()}); 1892 1893 // Print the body if this is not an external function. 1894 Region &body = op.body(); 1895 if (!body.empty()) 1896 p.printRegion(body, /*printEntryBlockArgs=*/false, 1897 /*printBlockTerminators=*/true); 1898 } 1899 1900 // Hook for OpTrait::FunctionLike, called after verifying that the 'type' 1901 // attribute is present. This can check for preconditions of the 1902 // getNumArguments hook not failing. 1903 LogicalResult LLVMFuncOp::verifyType() { 1904 auto llvmType = getTypeAttr().getValue().dyn_cast_or_null<LLVMFunctionType>(); 1905 if (!llvmType) 1906 return emitOpError("requires '" + getTypeAttrName() + 1907 "' attribute of wrapped LLVM function type"); 1908 1909 return success(); 1910 } 1911 1912 // Hook for OpTrait::FunctionLike, returns the number of function arguments. 1913 // Depends on the type attribute being correct as checked by verifyType 1914 unsigned LLVMFuncOp::getNumFuncArguments() { return getType().getNumParams(); } 1915 1916 // Hook for OpTrait::FunctionLike, returns the number of function results. 
1917 // Depends on the type attribute being correct as checked by verifyType 1918 unsigned LLVMFuncOp::getNumFuncResults() { 1919 // We model LLVM functions that return void as having zero results, 1920 // and all others as having one result. 1921 // If we modeled a void return as one result, then it would be possible to 1922 // attach an MLIR result attribute to it, and it isn't clear what semantics we 1923 // would assign to that. 1924 if (getType().getReturnType().isa<LLVMVoidType>()) 1925 return 0; 1926 return 1; 1927 } 1928 1929 // Verifies LLVM- and implementation-specific properties of the LLVM func Op: 1930 // - functions don't have 'common' linkage 1931 // - external functions have 'external' or 'extern_weak' linkage; 1932 // - vararg is (currently) only supported for external functions; 1933 // - entry block arguments are of LLVM types and match the function signature. 1934 static LogicalResult verify(LLVMFuncOp op) { 1935 if (op.linkage() == LLVM::Linkage::Common) 1936 return op.emitOpError() 1937 << "functions cannot have '" 1938 << stringifyLinkage(LLVM::Linkage::Common) << "' linkage"; 1939 1940 if (op.isExternal()) { 1941 if (op.linkage() != LLVM::Linkage::External && 1942 op.linkage() != LLVM::Linkage::ExternWeak) 1943 return op.emitOpError() 1944 << "external functions must have '" 1945 << stringifyLinkage(LLVM::Linkage::External) << "' or '" 1946 << stringifyLinkage(LLVM::Linkage::ExternWeak) << "' linkage"; 1947 return success(); 1948 } 1949 1950 if (op.isVarArg()) 1951 return op.emitOpError("only external functions can be variadic"); 1952 1953 unsigned numArguments = op.getType().getNumParams(); 1954 Block &entryBlock = op.front(); 1955 for (unsigned i = 0; i < numArguments; ++i) { 1956 Type argType = entryBlock.getArgument(i).getType(); 1957 if (!isCompatibleType(argType)) 1958 return op.emitOpError("entry block argument #") 1959 << i << " is not of LLVM type"; 1960 if (op.getType().getParamType(i) != argType) 1961 return op.emitOpError("the 
type of entry block argument #") 1962 << i << " does not match the function signature"; 1963 } 1964 1965 return success(); 1966 } 1967 1968 //===----------------------------------------------------------------------===// 1969 // Verification for LLVM::ConstantOp. 1970 //===----------------------------------------------------------------------===// 1971 1972 static LogicalResult verify(LLVM::ConstantOp op) { 1973 if (StringAttr sAttr = op.value().dyn_cast<StringAttr>()) { 1974 auto arrayType = op.getType().dyn_cast<LLVMArrayType>(); 1975 if (!arrayType || arrayType.getNumElements() != sAttr.getValue().size() || 1976 !arrayType.getElementType().isInteger(8)) { 1977 return op->emitOpError() 1978 << "expected array type of " << sAttr.getValue().size() 1979 << " i8 elements for the string constant"; 1980 } 1981 return success(); 1982 } 1983 if (auto structType = op.getType().dyn_cast<LLVMStructType>()) { 1984 if (structType.getBody().size() != 2 || 1985 structType.getBody()[0] != structType.getBody()[1]) { 1986 return op.emitError() << "expected struct type with two elements of the " 1987 "same type, the type of a complex constant"; 1988 } 1989 1990 auto arrayAttr = op.value().dyn_cast<ArrayAttr>(); 1991 if (!arrayAttr || arrayAttr.size() != 2 || 1992 arrayAttr[0].getType() != arrayAttr[1].getType()) { 1993 return op.emitOpError() << "expected array attribute with two elements, " 1994 "representing a complex constant"; 1995 } 1996 1997 Type elementType = structType.getBody()[0]; 1998 if (!elementType 1999 .isa<IntegerType, Float16Type, Float32Type, Float64Type>()) { 2000 return op.emitError() 2001 << "expected struct element types to be floating point type or " 2002 "integer type"; 2003 } 2004 return success(); 2005 } 2006 if (!op.value().isa<IntegerAttr, ArrayAttr, FloatAttr, ElementsAttr>()) 2007 return op.emitOpError() 2008 << "only supports integer, float, string or elements attributes"; 2009 return success(); 2010 } 2011 2012 
//===----------------------------------------------------------------------===// 2013 // Utility functions for parsing atomic ops 2014 //===----------------------------------------------------------------------===// 2015 2016 // Helper function to parse a keyword into the specified attribute named by 2017 // `attrName`. The keyword must match one of the string values defined by the 2018 // AtomicBinOp enum. The resulting I64 attribute is added to the `result` 2019 // state. 2020 static ParseResult parseAtomicBinOp(OpAsmParser &parser, OperationState &result, 2021 StringRef attrName) { 2022 llvm::SMLoc loc; 2023 StringRef keyword; 2024 if (parser.getCurrentLocation(&loc) || parser.parseKeyword(&keyword)) 2025 return failure(); 2026 2027 // Replace the keyword `keyword` with an integer attribute. 2028 auto kind = symbolizeAtomicBinOp(keyword); 2029 if (!kind) { 2030 return parser.emitError(loc) 2031 << "'" << keyword << "' is an incorrect value of the '" << attrName 2032 << "' attribute"; 2033 } 2034 2035 auto value = static_cast<int64_t>(kind.getValue()); 2036 auto attr = parser.getBuilder().getI64IntegerAttr(value); 2037 result.addAttribute(attrName, attr); 2038 2039 return success(); 2040 } 2041 2042 // Helper function to parse a keyword into the specified attribute named by 2043 // `attrName`. The keyword must match one of the string values defined by the 2044 // AtomicOrdering enum. The resulting I64 attribute is added to the `result` 2045 // state. 2046 static ParseResult parseAtomicOrdering(OpAsmParser &parser, 2047 OperationState &result, 2048 StringRef attrName) { 2049 llvm::SMLoc loc; 2050 StringRef ordering; 2051 if (parser.getCurrentLocation(&loc) || parser.parseKeyword(&ordering)) 2052 return failure(); 2053 2054 // Replace the keyword `ordering` with an integer attribute. 
2055 auto kind = symbolizeAtomicOrdering(ordering); 2056 if (!kind) { 2057 return parser.emitError(loc) 2058 << "'" << ordering << "' is an incorrect value of the '" << attrName 2059 << "' attribute"; 2060 } 2061 2062 auto value = static_cast<int64_t>(kind.getValue()); 2063 auto attr = parser.getBuilder().getI64IntegerAttr(value); 2064 result.addAttribute(attrName, attr); 2065 2066 return success(); 2067 } 2068 2069 //===----------------------------------------------------------------------===// 2070 // Printer, parser and verifier for LLVM::AtomicRMWOp. 2071 //===----------------------------------------------------------------------===// 2072 2073 static void printAtomicRMWOp(OpAsmPrinter &p, AtomicRMWOp &op) { 2074 p << op.getOperationName() << ' ' << stringifyAtomicBinOp(op.bin_op()) << ' ' 2075 << op.ptr() << ", " << op.val() << ' ' 2076 << stringifyAtomicOrdering(op.ordering()) << ' '; 2077 p.printOptionalAttrDict(op->getAttrs(), {"bin_op", "ordering"}); 2078 p << " : " << op.res().getType(); 2079 } 2080 2081 // <operation> ::= `llvm.atomicrmw` keyword ssa-use `,` ssa-use keyword 2082 // attribute-dict? 
`:` type 2083 static ParseResult parseAtomicRMWOp(OpAsmParser &parser, 2084 OperationState &result) { 2085 Type type; 2086 OpAsmParser::OperandType ptr, val; 2087 if (parseAtomicBinOp(parser, result, "bin_op") || parser.parseOperand(ptr) || 2088 parser.parseComma() || parser.parseOperand(val) || 2089 parseAtomicOrdering(parser, result, "ordering") || 2090 parser.parseOptionalAttrDict(result.attributes) || 2091 parser.parseColonType(type) || 2092 parser.resolveOperand(ptr, LLVM::LLVMPointerType::get(type), 2093 result.operands) || 2094 parser.resolveOperand(val, type, result.operands)) 2095 return failure(); 2096 2097 result.addTypes(type); 2098 return success(); 2099 } 2100 2101 static LogicalResult verify(AtomicRMWOp op) { 2102 auto ptrType = op.ptr().getType().cast<LLVM::LLVMPointerType>(); 2103 auto valType = op.val().getType(); 2104 if (valType != ptrType.getElementType()) 2105 return op.emitOpError("expected LLVM IR element type for operand #0 to " 2106 "match type for operand #1"); 2107 auto resType = op.res().getType(); 2108 if (resType != valType) 2109 return op.emitOpError( 2110 "expected LLVM IR result type to match type for operand #1"); 2111 if (op.bin_op() == AtomicBinOp::fadd || op.bin_op() == AtomicBinOp::fsub) { 2112 if (!mlir::LLVM::isCompatibleFloatingPointType(valType)) 2113 return op.emitOpError("expected LLVM IR floating point type"); 2114 } else if (op.bin_op() == AtomicBinOp::xchg) { 2115 auto intType = valType.dyn_cast<IntegerType>(); 2116 unsigned intBitWidth = intType ? intType.getWidth() : 0; 2117 if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 && 2118 intBitWidth != 64 && !valType.isa<BFloat16Type>() && 2119 !valType.isa<Float16Type>() && !valType.isa<Float32Type>() && 2120 !valType.isa<Float64Type>()) 2121 return op.emitOpError("unexpected LLVM IR type for 'xchg' bin_op"); 2122 } else { 2123 auto intType = valType.dyn_cast<IntegerType>(); 2124 unsigned intBitWidth = intType ? 
intType.getWidth() : 0; 2125 if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 && 2126 intBitWidth != 64) 2127 return op.emitOpError("expected LLVM IR integer type"); 2128 } 2129 2130 if (static_cast<unsigned>(op.ordering()) < 2131 static_cast<unsigned>(AtomicOrdering::monotonic)) 2132 return op.emitOpError() 2133 << "expected at least '" 2134 << stringifyAtomicOrdering(AtomicOrdering::monotonic) 2135 << "' ordering"; 2136 2137 return success(); 2138 } 2139 2140 //===----------------------------------------------------------------------===// 2141 // Printer, parser and verifier for LLVM::AtomicCmpXchgOp. 2142 //===----------------------------------------------------------------------===// 2143 2144 static void printAtomicCmpXchgOp(OpAsmPrinter &p, AtomicCmpXchgOp &op) { 2145 p << op.getOperationName() << ' ' << op.ptr() << ", " << op.cmp() << ", " 2146 << op.val() << ' ' << stringifyAtomicOrdering(op.success_ordering()) << ' ' 2147 << stringifyAtomicOrdering(op.failure_ordering()); 2148 p.printOptionalAttrDict(op->getAttrs(), 2149 {"success_ordering", "failure_ordering"}); 2150 p << " : " << op.val().getType(); 2151 } 2152 2153 // <operation> ::= `llvm.cmpxchg` ssa-use `,` ssa-use `,` ssa-use 2154 // keyword keyword attribute-dict? 
`:` type 2155 static ParseResult parseAtomicCmpXchgOp(OpAsmParser &parser, 2156 OperationState &result) { 2157 auto &builder = parser.getBuilder(); 2158 Type type; 2159 OpAsmParser::OperandType ptr, cmp, val; 2160 if (parser.parseOperand(ptr) || parser.parseComma() || 2161 parser.parseOperand(cmp) || parser.parseComma() || 2162 parser.parseOperand(val) || 2163 parseAtomicOrdering(parser, result, "success_ordering") || 2164 parseAtomicOrdering(parser, result, "failure_ordering") || 2165 parser.parseOptionalAttrDict(result.attributes) || 2166 parser.parseColonType(type) || 2167 parser.resolveOperand(ptr, LLVM::LLVMPointerType::get(type), 2168 result.operands) || 2169 parser.resolveOperand(cmp, type, result.operands) || 2170 parser.resolveOperand(val, type, result.operands)) 2171 return failure(); 2172 2173 auto boolType = IntegerType::get(builder.getContext(), 1); 2174 auto resultType = 2175 LLVMStructType::getLiteral(builder.getContext(), {type, boolType}); 2176 result.addTypes(resultType); 2177 2178 return success(); 2179 } 2180 2181 static LogicalResult verify(AtomicCmpXchgOp op) { 2182 auto ptrType = op.ptr().getType().cast<LLVM::LLVMPointerType>(); 2183 if (!ptrType) 2184 return op.emitOpError("expected LLVM IR pointer type for operand #0"); 2185 auto cmpType = op.cmp().getType(); 2186 auto valType = op.val().getType(); 2187 if (cmpType != ptrType.getElementType() || cmpType != valType) 2188 return op.emitOpError("expected LLVM IR element type for operand #0 to " 2189 "match type for all other operands"); 2190 auto intType = valType.dyn_cast<IntegerType>(); 2191 unsigned intBitWidth = intType ? 
intType.getWidth() : 0; 2192 if (!valType.isa<LLVMPointerType>() && intBitWidth != 8 && 2193 intBitWidth != 16 && intBitWidth != 32 && intBitWidth != 64 && 2194 !valType.isa<BFloat16Type>() && !valType.isa<Float16Type>() && 2195 !valType.isa<Float32Type>() && !valType.isa<Float64Type>()) 2196 return op.emitOpError("unexpected LLVM IR type"); 2197 if (op.success_ordering() < AtomicOrdering::monotonic || 2198 op.failure_ordering() < AtomicOrdering::monotonic) 2199 return op.emitOpError("ordering must be at least 'monotonic'"); 2200 if (op.failure_ordering() == AtomicOrdering::release || 2201 op.failure_ordering() == AtomicOrdering::acq_rel) 2202 return op.emitOpError("failure ordering cannot be 'release' or 'acq_rel'"); 2203 return success(); 2204 } 2205 2206 //===----------------------------------------------------------------------===// 2207 // Printer, parser and verifier for LLVM::FenceOp. 2208 //===----------------------------------------------------------------------===// 2209 2210 // <operation> ::= `llvm.fence` (`syncscope(`strAttr`)`)? keyword 2211 // attribute-dict? 
2212 static ParseResult parseFenceOp(OpAsmParser &parser, OperationState &result) { 2213 StringAttr sScope; 2214 StringRef syncscopeKeyword = "syncscope"; 2215 if (!failed(parser.parseOptionalKeyword(syncscopeKeyword))) { 2216 if (parser.parseLParen() || 2217 parser.parseAttribute(sScope, syncscopeKeyword, result.attributes) || 2218 parser.parseRParen()) 2219 return failure(); 2220 } else { 2221 result.addAttribute(syncscopeKeyword, 2222 parser.getBuilder().getStringAttr("")); 2223 } 2224 if (parseAtomicOrdering(parser, result, "ordering") || 2225 parser.parseOptionalAttrDict(result.attributes)) 2226 return failure(); 2227 return success(); 2228 } 2229 2230 static void printFenceOp(OpAsmPrinter &p, FenceOp &op) { 2231 StringRef syncscopeKeyword = "syncscope"; 2232 p << op.getOperationName() << ' '; 2233 if (!op->getAttr(syncscopeKeyword).cast<StringAttr>().getValue().empty()) 2234 p << "syncscope(" << op->getAttr(syncscopeKeyword) << ") "; 2235 p << stringifyAtomicOrdering(op.ordering()); 2236 } 2237 2238 static LogicalResult verify(FenceOp &op) { 2239 if (op.ordering() == AtomicOrdering::not_atomic || 2240 op.ordering() == AtomicOrdering::unordered || 2241 op.ordering() == AtomicOrdering::monotonic) 2242 return op.emitOpError("can be given only acquire, release, acq_rel, " 2243 "and seq_cst orderings"); 2244 return success(); 2245 } 2246 2247 //===----------------------------------------------------------------------===// 2248 // LLVMDialect initialization, type parsing, and registration. 
2249 //===----------------------------------------------------------------------===// 2250 2251 void LLVMDialect::initialize() { 2252 addAttributes<FMFAttr, LoopOptionsAttr>(); 2253 2254 // clang-format off 2255 addTypes<LLVMVoidType, 2256 LLVMPPCFP128Type, 2257 LLVMX86MMXType, 2258 LLVMTokenType, 2259 LLVMLabelType, 2260 LLVMMetadataType, 2261 LLVMFunctionType, 2262 LLVMPointerType, 2263 LLVMFixedVectorType, 2264 LLVMScalableVectorType, 2265 LLVMArrayType, 2266 LLVMStructType>(); 2267 // clang-format on 2268 addOperations< 2269 #define GET_OP_LIST 2270 #include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc" 2271 >(); 2272 2273 // Support unknown operations because not all LLVM operations are registered. 2274 allowUnknownOperations(); 2275 } 2276 2277 #define GET_OP_CLASSES 2278 #include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc" 2279 2280 /// Parse a type registered to this dialect. 2281 Type LLVMDialect::parseType(DialectAsmParser &parser) const { 2282 return detail::parseType(parser); 2283 } 2284 2285 /// Print a type registered to this dialect. 2286 void LLVMDialect::printType(Type type, DialectAsmPrinter &os) const { 2287 return detail::printType(type, os); 2288 } 2289 2290 LogicalResult LLVMDialect::verifyDataLayoutString( 2291 StringRef descr, llvm::function_ref<void(const Twine &)> reportError) { 2292 llvm::Expected<llvm::DataLayout> maybeDataLayout = 2293 llvm::DataLayout::parse(descr); 2294 if (maybeDataLayout) 2295 return success(); 2296 2297 std::string message; 2298 llvm::raw_string_ostream messageStream(message); 2299 llvm::logAllUnhandledErrors(maybeDataLayout.takeError(), messageStream); 2300 reportError("invalid data layout descriptor: " + messageStream.str()); 2301 return failure(); 2302 } 2303 2304 /// Verify LLVM dialect attributes. 2305 LogicalResult LLVMDialect::verifyOperationAttribute(Operation *op, 2306 NamedAttribute attr) { 2307 // If the `llvm.loop` attribute is present, enforce the following structure, 2308 // which the module translation can assume. 
2309 if (attr.first.strref() == LLVMDialect::getLoopAttrName()) { 2310 auto loopAttr = attr.second.dyn_cast<DictionaryAttr>(); 2311 if (!loopAttr) 2312 return op->emitOpError() << "expected '" << LLVMDialect::getLoopAttrName() 2313 << "' to be a dictionary attribute"; 2314 Optional<NamedAttribute> parallelAccessGroup = 2315 loopAttr.getNamed(LLVMDialect::getParallelAccessAttrName()); 2316 if (parallelAccessGroup.hasValue()) { 2317 auto accessGroups = parallelAccessGroup->second.dyn_cast<ArrayAttr>(); 2318 if (!accessGroups) 2319 return op->emitOpError() 2320 << "expected '" << LLVMDialect::getParallelAccessAttrName() 2321 << "' to be an array attribute"; 2322 for (Attribute attr : accessGroups) { 2323 auto accessGroupRef = attr.dyn_cast<SymbolRefAttr>(); 2324 if (!accessGroupRef) 2325 return op->emitOpError() 2326 << "expected '" << attr << "' to be a symbol reference"; 2327 StringRef metadataName = accessGroupRef.getRootReference(); 2328 auto metadataOp = 2329 SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>( 2330 op->getParentOp(), metadataName); 2331 if (!metadataOp) 2332 return op->emitOpError() 2333 << "expected '" << attr << "' to reference a metadata op"; 2334 StringRef accessGroupName = accessGroupRef.getLeafReference(); 2335 Operation *accessGroupOp = 2336 SymbolTable::lookupNearestSymbolFrom(metadataOp, accessGroupName); 2337 if (!accessGroupOp) 2338 return op->emitOpError() 2339 << "expected '" << attr << "' to reference an access_group op"; 2340 } 2341 } 2342 2343 Optional<NamedAttribute> loopOptions = 2344 loopAttr.getNamed(LLVMDialect::getLoopOptionsAttrName()); 2345 if (loopOptions.hasValue() && !loopOptions->second.isa<LoopOptionsAttr>()) 2346 return op->emitOpError() 2347 << "expected '" << LLVMDialect::getLoopOptionsAttrName() 2348 << "' to be a `loopopts` attribute"; 2349 } 2350 2351 // If the data layout attribute is present, it must use the LLVM data layout 2352 // syntax. Try parsing it and report errors in case of failure. 
Users of this 2353 // attribute may assume it is well-formed and can pass it to the (asserting) 2354 // llvm::DataLayout constructor. 2355 if (attr.first.strref() != LLVM::LLVMDialect::getDataLayoutAttrName()) 2356 return success(); 2357 if (auto stringAttr = attr.second.dyn_cast<StringAttr>()) 2358 return verifyDataLayoutString( 2359 stringAttr.getValue(), 2360 [op](const Twine &message) { op->emitOpError() << message.str(); }); 2361 2362 return op->emitOpError() << "expected '" 2363 << LLVM::LLVMDialect::getDataLayoutAttrName() 2364 << "' to be a string attribute"; 2365 } 2366 2367 /// Verify LLVMIR function argument attributes. 2368 LogicalResult LLVMDialect::verifyRegionArgAttribute(Operation *op, 2369 unsigned regionIdx, 2370 unsigned argIdx, 2371 NamedAttribute argAttr) { 2372 // Check that llvm.noalias is a unit attribute. 2373 if (argAttr.first == LLVMDialect::getNoAliasAttrName() && 2374 !argAttr.second.isa<UnitAttr>()) 2375 return op->emitError() 2376 << "expected llvm.noalias argument attribute to be a unit attribute"; 2377 // Check that llvm.align is an integer attribute. 2378 if (argAttr.first == LLVMDialect::getAlignAttrName() && 2379 !argAttr.second.isa<IntegerAttr>()) 2380 return op->emitError() 2381 << "llvm.align argument attribute of non integer type"; 2382 return success(); 2383 } 2384 2385 //===----------------------------------------------------------------------===// 2386 // Utility functions. 
2387 //===----------------------------------------------------------------------===// 2388 2389 Value mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder, 2390 StringRef name, StringRef value, 2391 LLVM::Linkage linkage) { 2392 assert(builder.getInsertionBlock() && 2393 builder.getInsertionBlock()->getParentOp() && 2394 "expected builder to point to a block constrained in an op"); 2395 auto module = 2396 builder.getInsertionBlock()->getParentOp()->getParentOfType<ModuleOp>(); 2397 assert(module && "builder points to an op outside of a module"); 2398 2399 // Create the global at the entry of the module. 2400 OpBuilder moduleBuilder(module.getBodyRegion(), builder.getListener()); 2401 MLIRContext *ctx = builder.getContext(); 2402 auto type = LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), value.size()); 2403 auto global = moduleBuilder.create<LLVM::GlobalOp>( 2404 loc, type, /*isConstant=*/true, linkage, name, 2405 builder.getStringAttr(value), /*alignment=*/0); 2406 2407 // Get the pointer to the first character in the global string. 
2408 Value globalPtr = builder.create<LLVM::AddressOfOp>(loc, global); 2409 Value cst0 = builder.create<LLVM::ConstantOp>( 2410 loc, IntegerType::get(ctx, 64), 2411 builder.getIntegerAttr(builder.getIndexType(), 0)); 2412 return builder.create<LLVM::GEPOp>( 2413 loc, LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)), globalPtr, 2414 ValueRange{cst0, cst0}); 2415 } 2416 2417 bool mlir::LLVM::satisfiesLLVMModule(Operation *op) { 2418 return op->hasTrait<OpTrait::SymbolTable>() && 2419 op->hasTrait<OpTrait::IsIsolatedFromAbove>(); 2420 } 2421 2422 static constexpr const FastmathFlags FastmathFlagsList[] = { 2423 // clang-format off 2424 FastmathFlags::nnan, 2425 FastmathFlags::ninf, 2426 FastmathFlags::nsz, 2427 FastmathFlags::arcp, 2428 FastmathFlags::contract, 2429 FastmathFlags::afn, 2430 FastmathFlags::reassoc, 2431 FastmathFlags::fast, 2432 // clang-format on 2433 }; 2434 2435 void FMFAttr::print(DialectAsmPrinter &printer) const { 2436 printer << "fastmath<"; 2437 auto flags = llvm::make_filter_range(FastmathFlagsList, [&](auto flag) { 2438 return bitEnumContains(this->getFlags(), flag); 2439 }); 2440 llvm::interleaveComma(flags, printer, 2441 [&](auto flag) { printer << stringifyEnum(flag); }); 2442 printer << ">"; 2443 } 2444 2445 Attribute FMFAttr::parse(MLIRContext *context, DialectAsmParser &parser, 2446 Type type) { 2447 if (failed(parser.parseLess())) 2448 return {}; 2449 2450 FastmathFlags flags = {}; 2451 if (failed(parser.parseOptionalGreater())) { 2452 do { 2453 StringRef elemName; 2454 if (failed(parser.parseKeyword(&elemName))) 2455 return {}; 2456 2457 auto elem = symbolizeFastmathFlags(elemName); 2458 if (!elem) { 2459 parser.emitError(parser.getNameLoc(), "Unknown fastmath flag: ") 2460 << elemName; 2461 return {}; 2462 } 2463 2464 flags = flags | *elem; 2465 } while (succeeded(parser.parseOptionalComma())); 2466 2467 if (failed(parser.parseGreater())) 2468 return {}; 2469 } 2470 2471 return FMFAttr::get(parser.getBuilder().getContext(), 
flags); 2472 } 2473 2474 LoopOptionsAttrBuilder::LoopOptionsAttrBuilder(LoopOptionsAttr attr) 2475 : options(attr.getOptions().begin(), attr.getOptions().end()) {} 2476 2477 template <typename T> 2478 LoopOptionsAttrBuilder &LoopOptionsAttrBuilder::setOption(LoopOptionCase tag, 2479 Optional<T> value) { 2480 auto option = llvm::find_if( 2481 options, [tag](auto option) { return option.first == tag; }); 2482 if (option != options.end()) { 2483 if (value.hasValue()) 2484 option->second = *value; 2485 else 2486 options.erase(option); 2487 } else { 2488 options.push_back(LoopOptionsAttr::OptionValuePair(tag, *value)); 2489 } 2490 return *this; 2491 } 2492 2493 LoopOptionsAttrBuilder & 2494 LoopOptionsAttrBuilder::setDisableLICM(Optional<bool> value) { 2495 return setOption(LoopOptionCase::disable_licm, value); 2496 } 2497 2498 /// Set the `interleave_count` option to the provided value. If no value 2499 /// is provided the option is deleted. 2500 LoopOptionsAttrBuilder & 2501 LoopOptionsAttrBuilder::setInterleaveCount(Optional<uint64_t> count) { 2502 return setOption(LoopOptionCase::interleave_count, count); 2503 } 2504 2505 /// Set the `disable_unroll` option to the provided value. If no value 2506 /// is provided the option is deleted. 2507 LoopOptionsAttrBuilder & 2508 LoopOptionsAttrBuilder::setDisableUnroll(Optional<bool> value) { 2509 return setOption(LoopOptionCase::disable_unroll, value); 2510 } 2511 2512 /// Set the `disable_pipeline` option to the provided value. If no value 2513 /// is provided the option is deleted. 2514 LoopOptionsAttrBuilder & 2515 LoopOptionsAttrBuilder::setDisablePipeline(Optional<bool> value) { 2516 return setOption(LoopOptionCase::disable_pipeline, value); 2517 } 2518 2519 /// Set the `pipeline_initiation_interval` option to the provided value. 2520 /// If no value is provided the option is deleted. 
2521 LoopOptionsAttrBuilder &LoopOptionsAttrBuilder::setPipelineInitiationInterval( 2522 Optional<uint64_t> count) { 2523 return setOption(LoopOptionCase::pipeline_initiation_interval, count); 2524 } 2525 2526 template <typename T> 2527 static Optional<T> 2528 getOption(ArrayRef<std::pair<LoopOptionCase, int64_t>> options, 2529 LoopOptionCase option) { 2530 auto it = 2531 lower_bound(options, option, [](auto optionPair, LoopOptionCase option) { 2532 return optionPair.first < option; 2533 }); 2534 if (it == options.end()) 2535 return {}; 2536 return static_cast<T>(it->second); 2537 } 2538 2539 Optional<bool> LoopOptionsAttr::disableUnroll() { 2540 return getOption<bool>(getOptions(), LoopOptionCase::disable_unroll); 2541 } 2542 2543 Optional<bool> LoopOptionsAttr::disableLICM() { 2544 return getOption<bool>(getOptions(), LoopOptionCase::disable_licm); 2545 } 2546 2547 Optional<int64_t> LoopOptionsAttr::interleaveCount() { 2548 return getOption<int64_t>(getOptions(), LoopOptionCase::interleave_count); 2549 } 2550 2551 /// Build the LoopOptions Attribute from a sorted array of individual options. 2552 LoopOptionsAttr LoopOptionsAttr::get( 2553 MLIRContext *context, 2554 ArrayRef<std::pair<LoopOptionCase, int64_t>> sortedOptions) { 2555 assert(llvm::is_sorted(sortedOptions, llvm::less_first()) && 2556 "LoopOptionsAttr ctor expects a sorted options array"); 2557 return Base::get(context, sortedOptions); 2558 } 2559 2560 /// Build the LoopOptions Attribute from a sorted array of individual options. 
2561 LoopOptionsAttr LoopOptionsAttr::get(MLIRContext *context, 2562 LoopOptionsAttrBuilder &optionBuilders) { 2563 llvm::sort(optionBuilders.options, llvm::less_first()); 2564 return Base::get(context, optionBuilders.options); 2565 } 2566 2567 void LoopOptionsAttr::print(DialectAsmPrinter &printer) const { 2568 printer << getMnemonic() << "<"; 2569 llvm::interleaveComma(getOptions(), printer, [&](auto option) { 2570 printer << stringifyEnum(option.first) << " = "; 2571 switch (option.first) { 2572 case LoopOptionCase::disable_licm: 2573 case LoopOptionCase::disable_unroll: 2574 case LoopOptionCase::disable_pipeline: 2575 printer << (option.second ? "true" : "false"); 2576 break; 2577 case LoopOptionCase::interleave_count: 2578 case LoopOptionCase::pipeline_initiation_interval: 2579 printer << option.second; 2580 break; 2581 } 2582 }); 2583 printer << ">"; 2584 } 2585 2586 Attribute LoopOptionsAttr::parse(MLIRContext *context, DialectAsmParser &parser, 2587 Type type) { 2588 if (failed(parser.parseLess())) 2589 return {}; 2590 2591 SmallVector<std::pair<LoopOptionCase, int64_t>> options; 2592 llvm::SmallDenseSet<LoopOptionCase> seenOptions; 2593 do { 2594 StringRef optionName; 2595 if (parser.parseKeyword(&optionName)) 2596 return {}; 2597 2598 auto option = symbolizeLoopOptionCase(optionName); 2599 if (!option) { 2600 parser.emitError(parser.getNameLoc(), "unknown loop option: ") 2601 << optionName; 2602 return {}; 2603 } 2604 if (!seenOptions.insert(*option).second) { 2605 parser.emitError(parser.getNameLoc(), "loop option present twice"); 2606 return {}; 2607 } 2608 if (failed(parser.parseEqual())) 2609 return {}; 2610 2611 int64_t value; 2612 switch (*option) { 2613 case LoopOptionCase::disable_licm: 2614 case LoopOptionCase::disable_unroll: 2615 case LoopOptionCase::disable_pipeline: 2616 if (succeeded(parser.parseOptionalKeyword("true"))) 2617 value = 1; 2618 else if (succeeded(parser.parseOptionalKeyword("false"))) 2619 value = 0; 2620 else { 2621 
parser.emitError(parser.getNameLoc(), 2622 "expected boolean value 'true' or 'false'"); 2623 return {}; 2624 } 2625 break; 2626 case LoopOptionCase::interleave_count: 2627 case LoopOptionCase::pipeline_initiation_interval: 2628 if (failed(parser.parseInteger(value))) { 2629 parser.emitError(parser.getNameLoc(), "expected integer value"); 2630 return {}; 2631 } 2632 break; 2633 } 2634 options.push_back(std::make_pair(*option, value)); 2635 } while (succeeded(parser.parseOptionalComma())); 2636 if (failed(parser.parseGreater())) 2637 return {}; 2638 2639 llvm::sort(options, llvm::less_first()); 2640 return get(parser.getBuilder().getContext(), options); 2641 } 2642 2643 Attribute LLVMDialect::parseAttribute(DialectAsmParser &parser, 2644 Type type) const { 2645 if (type) { 2646 parser.emitError(parser.getNameLoc(), "unexpected type"); 2647 return {}; 2648 } 2649 StringRef attrKind; 2650 if (parser.parseKeyword(&attrKind)) 2651 return {}; 2652 { 2653 Attribute attr; 2654 auto parseResult = 2655 generatedAttributeParser(getContext(), parser, attrKind, type, attr); 2656 if (parseResult.hasValue()) 2657 return attr; 2658 } 2659 parser.emitError(parser.getNameLoc(), "unknown attribute type: ") << attrKind; 2660 return {}; 2661 } 2662 2663 void LLVMDialect::printAttribute(Attribute attr, DialectAsmPrinter &os) const { 2664 if (succeeded(generatedAttributePrinter(attr, os))) 2665 return; 2666 llvm_unreachable("Unknown attribute type"); 2667 } 2668