1 //===- LLVMDialect.cpp - LLVM IR Ops and Dialect registration -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the types and operation details for the LLVM IR dialect in 10 // MLIR, and the LLVM IR dialect. It also registers the dialect. 11 // 12 //===----------------------------------------------------------------------===// 13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 14 #include "TypeDetail.h" 15 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 16 #include "mlir/IR/Builders.h" 17 #include "mlir/IR/BuiltinOps.h" 18 #include "mlir/IR/BuiltinTypes.h" 19 #include "mlir/IR/DialectImplementation.h" 20 #include "mlir/IR/FunctionImplementation.h" 21 #include "mlir/IR/MLIRContext.h" 22 #include "mlir/IR/Matchers.h" 23 24 #include "llvm/ADT/StringSwitch.h" 25 #include "llvm/ADT/TypeSwitch.h" 26 #include "llvm/AsmParser/Parser.h" 27 #include "llvm/Bitcode/BitcodeReader.h" 28 #include "llvm/Bitcode/BitcodeWriter.h" 29 #include "llvm/IR/Attributes.h" 30 #include "llvm/IR/Function.h" 31 #include "llvm/IR/Type.h" 32 #include "llvm/Support/Mutex.h" 33 #include "llvm/Support/SourceMgr.h" 34 35 #include <numeric> 36 37 using namespace mlir; 38 using namespace mlir::LLVM; 39 using mlir::LLVM::linkage::getMaxEnumValForLinkage; 40 41 #include "mlir/Dialect/LLVMIR/LLVMOpsDialect.cpp.inc" 42 43 static constexpr const char kVolatileAttrName[] = "volatile_"; 44 static constexpr const char kNonTemporalAttrName[] = "nontemporal"; 45 static constexpr const char kElemTypeAttrName[] = "elem_type"; 46 47 #include "mlir/Dialect/LLVMIR/LLVMOpsEnums.cpp.inc" 48 #include "mlir/Dialect/LLVMIR/LLVMOpsInterfaces.cpp.inc" 49 #define GET_ATTRDEF_CLASSES 50 #include "mlir/Dialect/LLVMIR/LLVMOpsAttrDefs.cpp.inc" 51 52 static auto processFMFAttr(ArrayRef<NamedAttribute> attrs) { 53 SmallVector<NamedAttribute, 8> filteredAttrs( 54 llvm::make_filter_range(attrs, [&](NamedAttribute attr) { 55 if (attr.getName() == "fastmathFlags") { 56 auto defAttr = FMFAttr::get(attr.getValue().getContext(), {}); 57 return defAttr != attr.getValue(); 58 } 59 return true; 60 })); 61 return filteredAttrs; 62 } 63 64 static ParseResult parseLLVMOpAttrs(OpAsmParser &parser, 65 NamedAttrList &result) { 66 return parser.parseOptionalAttrDict(result); 67 } 68 69 static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op, 70 DictionaryAttr attrs) { 71 printer.printOptionalAttrDict(processFMFAttr(attrs.getValue())); 72 } 73 74 /// Verifies `symbol`'s use in `op` to ensure the symbol is a valid and 75 /// fully defined llvm.func. 76 static LogicalResult verifySymbolAttrUse(FlatSymbolRefAttr symbol, 77 Operation *op, 78 SymbolTableCollection &symbolTable) { 79 StringRef name = symbol.getValue(); 80 auto func = 81 symbolTable.lookupNearestSymbolFrom<LLVMFuncOp>(op, symbol.getAttr()); 82 if (!func) 83 return op->emitOpError("'") 84 << name << "' does not reference a valid LLVM function"; 85 if (func.isExternal()) 86 return op->emitOpError("'") << name << "' does not have a definition"; 87 return success(); 88 } 89 90 //===----------------------------------------------------------------------===// 91 // Printing/parsing for LLVM::CmpOp. 92 //===----------------------------------------------------------------------===// 93 94 void ICmpOp::print(OpAsmPrinter &p) { 95 p << " \"" << stringifyICmpPredicate(getPredicate()) << "\" " << getOperand(0) 96 << ", " << getOperand(1); 97 p.printOptionalAttrDict((*this)->getAttrs(), {"predicate"}); 98 p << " : " << getLhs().getType(); 99 } 100 101 void FCmpOp::print(OpAsmPrinter &p) { 102 p << " \"" << stringifyFCmpPredicate(getPredicate()) << "\" " << getOperand(0) 103 << ", " << getOperand(1); 104 p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()), {"predicate"}); 105 p << " : " << getLhs().getType(); 106 } 107 108 // <operation> ::= `llvm.icmp` string-literal ssa-use `,` ssa-use 109 // attribute-dict? `:` type 110 // <operation> ::= `llvm.fcmp` string-literal ssa-use `,` ssa-use 111 // attribute-dict? `:` type 112 template <typename CmpPredicateType> 113 static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) { 114 Builder &builder = parser.getBuilder(); 115 116 StringAttr predicateAttr; 117 OpAsmParser::UnresolvedOperand lhs, rhs; 118 Type type; 119 SMLoc predicateLoc, trailingTypeLoc; 120 if (parser.getCurrentLocation(&predicateLoc) || 121 parser.parseAttribute(predicateAttr, "predicate", result.attributes) || 122 parser.parseOperand(lhs) || parser.parseComma() || 123 parser.parseOperand(rhs) || 124 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 125 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) || 126 parser.resolveOperand(lhs, type, result.operands) || 127 parser.resolveOperand(rhs, type, result.operands)) 128 return failure(); 129 130 // Replace the string attribute `predicate` with an integer attribute. 131 int64_t predicateValue = 0; 132 if (std::is_same<CmpPredicateType, ICmpPredicate>()) { 133 Optional<ICmpPredicate> predicate = 134 symbolizeICmpPredicate(predicateAttr.getValue()); 135 if (!predicate) 136 return parser.emitError(predicateLoc) 137 << "'" << predicateAttr.getValue() 138 << "' is an incorrect value of the 'predicate' attribute"; 139 predicateValue = static_cast<int64_t>(predicate.getValue()); 140 } else { 141 Optional<FCmpPredicate> predicate = 142 symbolizeFCmpPredicate(predicateAttr.getValue()); 143 if (!predicate) 144 return parser.emitError(predicateLoc) 145 << "'" << predicateAttr.getValue() 146 << "' is an incorrect value of the 'predicate' attribute"; 147 predicateValue = static_cast<int64_t>(predicate.getValue()); 148 } 149 150 result.attributes.set("predicate", 151 parser.getBuilder().getI64IntegerAttr(predicateValue)); 152 153 // The result type is either i1 or a vector type <? x i1> if the inputs are 154 // vectors. 155 Type resultType = IntegerType::get(builder.getContext(), 1); 156 if (!isCompatibleType(type)) 157 return parser.emitError(trailingTypeLoc, 158 "expected LLVM dialect-compatible type"); 159 if (LLVM::isCompatibleVectorType(type)) { 160 if (LLVM::isScalableVectorType(type)) { 161 resultType = LLVM::getVectorType( 162 resultType, LLVM::getVectorNumElements(type).getKnownMinValue(), 163 /*isScalable=*/true); 164 } else { 165 resultType = LLVM::getVectorType( 166 resultType, LLVM::getVectorNumElements(type).getFixedValue(), 167 /*isScalable=*/false); 168 } 169 } 170 171 result.addTypes({resultType}); 172 return success(); 173 } 174 175 ParseResult ICmpOp::parse(OpAsmParser &parser, OperationState &result) { 176 return parseCmpOp<ICmpPredicate>(parser, result); 177 } 178 179 ParseResult FCmpOp::parse(OpAsmParser &parser, OperationState &result) { 180 return parseCmpOp<FCmpPredicate>(parser, result); 181 } 182 183 //===----------------------------------------------------------------------===// 184 // Printing, parsing and verification for LLVM::AllocaOp. 185 //===----------------------------------------------------------------------===// 186 187 void AllocaOp::print(OpAsmPrinter &p) { 188 Type elemTy = getType().cast<LLVM::LLVMPointerType>().getElementType(); 189 if (!elemTy) 190 elemTy = *getElemType(); 191 192 auto funcTy = 193 FunctionType::get(getContext(), {getArraySize().getType()}, {getType()}); 194 195 p << ' ' << getArraySize() << " x " << elemTy; 196 if (getAlignment().hasValue() && *getAlignment() != 0) 197 p.printOptionalAttrDict((*this)->getAttrs(), {kElemTypeAttrName}); 198 else 199 p.printOptionalAttrDict((*this)->getAttrs(), 200 {"alignment", kElemTypeAttrName}); 201 p << " : " << funcTy; 202 } 203 204 // <operation> ::= `llvm.alloca` ssa-use `x` type attribute-dict? 205 // `:` type `,` type 206 ParseResult AllocaOp::parse(OpAsmParser &parser, OperationState &result) { 207 OpAsmParser::UnresolvedOperand arraySize; 208 Type type, elemType; 209 SMLoc trailingTypeLoc; 210 if (parser.parseOperand(arraySize) || parser.parseKeyword("x") || 211 parser.parseType(elemType) || 212 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 213 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type)) 214 return failure(); 215 216 Optional<NamedAttribute> alignmentAttr = 217 result.attributes.getNamed("alignment"); 218 if (alignmentAttr.hasValue()) { 219 auto alignmentInt = 220 alignmentAttr.getValue().getValue().dyn_cast<IntegerAttr>(); 221 if (!alignmentInt) 222 return parser.emitError(parser.getNameLoc(), 223 "expected integer alignment"); 224 if (alignmentInt.getValue().isNullValue()) 225 result.attributes.erase("alignment"); 226 } 227 228 // Extract the result type from the trailing function type. 229 auto funcType = type.dyn_cast<FunctionType>(); 230 if (!funcType || funcType.getNumInputs() != 1 || 231 funcType.getNumResults() != 1) 232 return parser.emitError( 233 trailingTypeLoc, 234 "expected trailing function type with one argument and one result"); 235 236 if (parser.resolveOperand(arraySize, funcType.getInput(0), result.operands)) 237 return failure(); 238 239 Type resultType = funcType.getResult(0); 240 if (auto ptrResultType = resultType.dyn_cast<LLVMPointerType>()) { 241 if (ptrResultType.isOpaque()) 242 result.addAttribute(kElemTypeAttrName, TypeAttr::get(elemType)); 243 } 244 245 result.addTypes({funcType.getResult(0)}); 246 return success(); 247 } 248 249 /// Checks that the elemental type is present in either the pointer type or 250 /// the attribute, but not both. 251 static LogicalResult verifyOpaquePtr(Operation *op, LLVMPointerType ptrType, 252 Optional<Type> ptrElementType) { 253 if (ptrType.isOpaque() && !ptrElementType.hasValue()) { 254 return op->emitOpError() << "expected '" << kElemTypeAttrName 255 << "' attribute if opaque pointer type is used"; 256 } 257 if (!ptrType.isOpaque() && ptrElementType.hasValue()) { 258 return op->emitOpError() 259 << "unexpected '" << kElemTypeAttrName 260 << "' attribute when non-opaque pointer type is used"; 261 } 262 return success(); 263 } 264 265 LogicalResult AllocaOp::verify() { 266 return verifyOpaquePtr(getOperation(), getType().cast<LLVMPointerType>(), 267 getElemType()); 268 } 269 270 //===----------------------------------------------------------------------===// 271 // LLVM::BrOp 272 //===----------------------------------------------------------------------===// 273 274 SuccessorOperands BrOp::getSuccessorOperands(unsigned index) { 275 assert(index == 0 && "invalid successor index"); 276 return SuccessorOperands(getDestOperandsMutable()); 277 } 278 279 //===----------------------------------------------------------------------===// 280 // LLVM::CondBrOp 281 //===----------------------------------------------------------------------===// 282 283 SuccessorOperands CondBrOp::getSuccessorOperands(unsigned index) { 284 assert(index < getNumSuccessors() && "invalid successor index"); 285 return SuccessorOperands(index == 0 ? getTrueDestOperandsMutable() 286 : getFalseDestOperandsMutable()); 287 } 288 289 //===----------------------------------------------------------------------===// 290 // LLVM::SwitchOp 291 //===----------------------------------------------------------------------===// 292 293 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value, 294 Block *defaultDestination, ValueRange defaultOperands, 295 ArrayRef<int32_t> caseValues, BlockRange caseDestinations, 296 ArrayRef<ValueRange> caseOperands, 297 ArrayRef<int32_t> branchWeights) { 298 ElementsAttr caseValuesAttr; 299 if (!caseValues.empty()) 300 caseValuesAttr = builder.getI32VectorAttr(caseValues); 301 302 ElementsAttr weightsAttr; 303 if (!branchWeights.empty()) 304 weightsAttr = builder.getI32VectorAttr(llvm::to_vector<4>(branchWeights)); 305 306 build(builder, result, value, defaultOperands, caseOperands, caseValuesAttr, 307 weightsAttr, defaultDestination, caseDestinations); 308 } 309 310 /// <cases> ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)? 311 /// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )? 312 static ParseResult parseSwitchOpCases( 313 OpAsmParser &parser, Type flagType, ElementsAttr &caseValues, 314 SmallVectorImpl<Block *> &caseDestinations, 315 SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> &caseOperands, 316 SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) { 317 SmallVector<APInt> values; 318 unsigned bitWidth = flagType.getIntOrFloatBitWidth(); 319 do { 320 int64_t value = 0; 321 OptionalParseResult integerParseResult = parser.parseOptionalInteger(value); 322 if (values.empty() && !integerParseResult.hasValue()) 323 return success(); 324 325 if (!integerParseResult.hasValue() || integerParseResult.getValue()) 326 return failure(); 327 values.push_back(APInt(bitWidth, value)); 328 329 Block *destination; 330 SmallVector<OpAsmParser::UnresolvedOperand> operands; 331 SmallVector<Type> operandTypes; 332 if (parser.parseColon() || parser.parseSuccessor(destination)) 333 return failure(); 334 if (!parser.parseOptionalLParen()) { 335 if (parser.parseRegionArgumentList(operands) || 336 parser.parseColonTypeList(operandTypes) || parser.parseRParen()) 337 return failure(); 338 } 339 caseDestinations.push_back(destination); 340 caseOperands.emplace_back(operands); 341 caseOperandTypes.emplace_back(operandTypes); 342 } while (!parser.parseOptionalComma()); 343 344 ShapedType caseValueType = 345 VectorType::get(static_cast<int64_t>(values.size()), flagType); 346 caseValues = DenseIntElementsAttr::get(caseValueType, values); 347 return success(); 348 } 349 350 static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type flagType, 351 ElementsAttr caseValues, 352 SuccessorRange caseDestinations, 353 OperandRangeRange caseOperands, 354 const TypeRangeRange &caseOperandTypes) { 355 if (!caseValues) 356 return; 357 358 size_t index = 0; 359 llvm::interleave( 360 llvm::zip(caseValues.cast<DenseIntElementsAttr>(), caseDestinations), 361 [&](auto i) { 362 p << " "; 363 p << std::get<0>(i).getLimitedValue(); 364 p << ": "; 365 p.printSuccessorAndUseList(std::get<1>(i), caseOperands[index++]); 366 }, 367 [&] { 368 p << ','; 369 p.printNewline(); 370 }); 371 p.printNewline(); 372 } 373 374 LogicalResult SwitchOp::verify() { 375 if ((!getCaseValues() && !getCaseDestinations().empty()) || 376 (getCaseValues() && 377 getCaseValues()->size() != 378 static_cast<int64_t>(getCaseDestinations().size()))) 379 return emitOpError("expects number of case values to match number of " 380 "case destinations"); 381 if (getBranchWeights() && getBranchWeights()->size() != getNumSuccessors()) 382 return emitError("expects number of branch weights to match number of " 383 "successors: ") 384 << getBranchWeights()->size() << " vs " << getNumSuccessors(); 385 return success(); 386 } 387 388 SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) { 389 assert(index < getNumSuccessors() && "invalid successor index"); 390 return SuccessorOperands(index == 0 ? getDefaultOperandsMutable() 391 : getCaseOperandsMutable(index - 1)); 392 } 393 394 //===----------------------------------------------------------------------===// 395 // Code for LLVM::GEPOp. 396 //===----------------------------------------------------------------------===// 397 398 constexpr int GEPOp::kDynamicIndex; 399 400 /// Populates `indices` with positions of GEP indices that would correspond to 401 /// LLVMStructTypes potentially nested in the given type. The type currently 402 /// visited gets `currentIndex` and LLVM container types are visited 403 /// recursively. The recursion is bounded and takes care of recursive types by 404 /// means of the `visited` set. 405 static void recordStructIndices(Type type, unsigned currentIndex, 406 SmallVectorImpl<unsigned> &indices, 407 SmallVectorImpl<unsigned> *structSizes, 408 SmallPtrSet<Type, 4> &visited) { 409 if (visited.contains(type)) 410 return; 411 412 visited.insert(type); 413 414 llvm::TypeSwitch<Type>(type) 415 .Case<LLVMStructType>([&](LLVMStructType structType) { 416 indices.push_back(currentIndex); 417 if (structSizes) 418 structSizes->push_back(structType.getBody().size()); 419 for (Type elementType : structType.getBody()) 420 recordStructIndices(elementType, currentIndex + 1, indices, 421 structSizes, visited); 422 }) 423 .Case<VectorType, LLVMScalableVectorType, LLVMFixedVectorType, 424 LLVMArrayType>([&](auto containerType) { 425 recordStructIndices(containerType.getElementType(), currentIndex + 1, 426 indices, structSizes, visited); 427 }); 428 } 429 430 /// Populates `indices` with positions of GEP indices that correspond to 431 /// LLVMStructTypes potentially nested in the given `baseGEPType`, which must 432 /// be either an LLVMPointer type or a vector thereof. If `structSizes` is 433 /// provided, it is populated with sizes of the indexed structs for bounds 434 /// verification purposes. 435 void GEPOp::findKnownStructIndices(Type sourceElementType, 436 SmallVectorImpl<unsigned> &indices, 437 SmallVectorImpl<unsigned> *structSizes) { 438 SmallPtrSet<Type, 4> visited; 439 recordStructIndices(sourceElementType, /*currentIndex=*/1, indices, 440 structSizes, visited); 441 } 442 443 void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType, 444 Value basePtr, ValueRange operands, 445 ArrayRef<NamedAttribute> attributes) { 446 build(builder, result, resultType, basePtr, operands, 447 SmallVector<int32_t>(operands.size(), LLVM::GEPOp::kDynamicIndex), 448 attributes); 449 } 450 451 /// Returns the elemental type of any LLVM-compatible vector type or self. 452 static Type extractVectorElementType(Type type) { 453 if (auto vectorType = type.dyn_cast<VectorType>()) 454 return vectorType.getElementType(); 455 if (auto scalableVectorType = type.dyn_cast<LLVMScalableVectorType>()) 456 return scalableVectorType.getElementType(); 457 if (auto fixedVectorType = type.dyn_cast<LLVMFixedVectorType>()) 458 return fixedVectorType.getElementType(); 459 return type; 460 } 461 462 void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType, 463 Value basePtr, ValueRange indices, 464 ArrayRef<int32_t> structIndices, 465 ArrayRef<NamedAttribute> attributes) { 466 auto ptrType = 467 extractVectorElementType(basePtr.getType()).cast<LLVMPointerType>(); 468 assert(!ptrType.isOpaque() && 469 "expected non-opaque pointer, provide elementType explicitly when " 470 "opaque pointers are used"); 471 build(builder, result, resultType, ptrType.getElementType(), basePtr, indices, 472 structIndices, attributes); 473 } 474 475 void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType, 476 Type elementType, Value basePtr, ValueRange indices, 477 ArrayRef<int32_t> structIndices, 478 ArrayRef<NamedAttribute> attributes) { 479 SmallVector<Value> remainingIndices; 480 SmallVector<int32_t> updatedStructIndices(structIndices.begin(), 481 structIndices.end()); 482 SmallVector<unsigned> structRelatedPositions; 483 findKnownStructIndices(elementType, structRelatedPositions); 484 485 SmallVector<unsigned> operandsToErase; 486 for (unsigned pos : structRelatedPositions) { 487 // GEP may not be indexing as deep as some structs are located. 488 if (pos >= structIndices.size()) 489 continue; 490 491 // If the index is already static, it's fine. 492 if (structIndices[pos] != kDynamicIndex) 493 continue; 494 495 // Find the corresponding operand. 496 unsigned operandPos = 497 std::count(structIndices.begin(), std::next(structIndices.begin(), pos), 498 kDynamicIndex); 499 500 // Extract the constant value from the operand and put it into the attribute 501 // instead. 502 APInt staticIndexValue; 503 bool matched = 504 matchPattern(indices[operandPos], m_ConstantInt(&staticIndexValue)); 505 (void)matched; 506 assert(matched && "index into a struct must be a constant"); 507 assert(staticIndexValue.sge(APInt::getSignedMinValue(/*numBits=*/32)) && 508 "struct index underflows 32-bit integer"); 509 assert(staticIndexValue.sle(APInt::getSignedMaxValue(/*numBits=*/32)) && 510 "struct index overflows 32-bit integer"); 511 auto staticIndex = static_cast<int32_t>(staticIndexValue.getSExtValue()); 512 updatedStructIndices[pos] = staticIndex; 513 operandsToErase.push_back(operandPos); 514 } 515 516 for (unsigned i = 0, e = indices.size(); i < e; ++i) { 517 if (!llvm::is_contained(operandsToErase, i)) 518 remainingIndices.push_back(indices[i]); 519 } 520 521 assert(remainingIndices.size() == static_cast<size_t>(llvm::count( 522 updatedStructIndices, kDynamicIndex)) && 523 "expected as many index operands as dynamic index attr elements"); 524 525 result.addTypes(resultType); 526 result.addAttributes(attributes); 527 result.addAttribute("structIndices", 528 builder.getI32TensorAttr(updatedStructIndices)); 529 if (extractVectorElementType(basePtr.getType()) 530 .cast<LLVMPointerType>() 531 .isOpaque()) 532 result.addAttribute(kElemTypeAttrName, TypeAttr::get(elementType)); 533 result.addOperands(basePtr); 534 result.addOperands(remainingIndices); 535 } 536 537 static ParseResult 538 parseGEPIndices(OpAsmParser &parser, 539 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &indices, 540 DenseIntElementsAttr &structIndices) { 541 SmallVector<int32_t> constantIndices; 542 do { 543 int32_t constantIndex; 544 OptionalParseResult parsedInteger = 545 parser.parseOptionalInteger(constantIndex); 546 if (parsedInteger.hasValue()) { 547 if (failed(parsedInteger.getValue())) 548 return failure(); 549 constantIndices.push_back(constantIndex); 550 continue; 551 } 552 553 constantIndices.push_back(LLVM::GEPOp::kDynamicIndex); 554 if (failed(parser.parseOperand(indices.emplace_back()))) 555 return failure(); 556 } while (succeeded(parser.parseOptionalComma())); 557 558 structIndices = parser.getBuilder().getI32TensorAttr(constantIndices); 559 return success(); 560 } 561 562 static void printGEPIndices(OpAsmPrinter &printer, LLVM::GEPOp gepOp, 563 OperandRange indices, 564 DenseIntElementsAttr structIndices) { 565 unsigned operandIdx = 0; 566 llvm::interleaveComma(structIndices.getValues<int32_t>(), printer, 567 [&](int32_t cst) { 568 if (cst == LLVM::GEPOp::kDynamicIndex) 569 printer.printOperand(indices[operandIdx++]); 570 else 571 printer << cst; 572 }); 573 } 574 575 LogicalResult LLVM::GEPOp::verify() { 576 if (failed(verifyOpaquePtr( 577 getOperation(), 578 extractVectorElementType(getType()).cast<LLVMPointerType>(), 579 getElemType()))) 580 return failure(); 581 582 SmallVector<unsigned> indices; 583 SmallVector<unsigned> structSizes; 584 findKnownStructIndices(getSourceElementType(), indices, &structSizes); 585 DenseIntElementsAttr structIndices = getStructIndices(); 586 for (unsigned i : llvm::seq<unsigned>(0, indices.size())) { 587 unsigned index = indices[i]; 588 // GEP may not be indexing as deep as some structs nested in the type. 589 if (index >= structIndices.getNumElements()) 590 continue; 591 592 int32_t staticIndex = structIndices.getValues<int32_t>()[index]; 593 if (staticIndex == LLVM::GEPOp::kDynamicIndex) 594 return emitOpError() << "expected index " << index 595 << " indexing a struct to be constant"; 596 if (staticIndex < 0 || static_cast<unsigned>(staticIndex) >= structSizes[i]) 597 return emitOpError() << "index " << index 598 << " indexing a struct is out of bounds"; 599 } 600 return success(); 601 } 602 603 Type LLVM::GEPOp::getSourceElementType() { 604 if (Optional<Type> elemType = getElemType()) 605 return *elemType; 606 607 return extractVectorElementType(getBase().getType()) 608 .cast<LLVMPointerType>() 609 .getElementType(); 610 } 611 612 //===----------------------------------------------------------------------===// 613 // Builder, printer and parser for for LLVM::LoadOp. 614 //===----------------------------------------------------------------------===// 615 616 LogicalResult verifySymbolAttribute( 617 Operation *op, StringRef attributeName, 618 llvm::function_ref<LogicalResult(Operation *, SymbolRefAttr)> 619 verifySymbolType) { 620 if (Attribute attribute = op->getAttr(attributeName)) { 621 // The attribute is already verified to be a symbol ref array attribute via 622 // a constraint in the operation definition. 623 for (SymbolRefAttr symbolRef : 624 attribute.cast<ArrayAttr>().getAsRange<SymbolRefAttr>()) { 625 StringAttr metadataName = symbolRef.getRootReference(); 626 StringAttr symbolName = symbolRef.getLeafReference(); 627 // We want @metadata::@symbol, not just @symbol 628 if (metadataName == symbolName) { 629 return op->emitOpError() << "expected '" << symbolRef 630 << "' to specify a fully qualified reference"; 631 } 632 auto metadataOp = SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>( 633 op->getParentOp(), metadataName); 634 if (!metadataOp) 635 return op->emitOpError() 636 << "expected '" << symbolRef << "' to reference a metadata op"; 637 Operation *symbolOp = 638 SymbolTable::lookupNearestSymbolFrom(metadataOp, symbolName); 639 if (!symbolOp) 640 return op->emitOpError() 641 << "expected '" << symbolRef << "' to be a valid reference"; 642 if (failed(verifySymbolType(symbolOp, symbolRef))) { 643 return failure(); 644 } 645 } 646 } 647 return success(); 648 } 649 650 // Verifies that metadata ops are wired up properly. 651 template <typename OpTy> 652 static LogicalResult verifyOpMetadata(Operation *op, StringRef attributeName) { 653 auto verifySymbolType = [op](Operation *symbolOp, 654 SymbolRefAttr symbolRef) -> LogicalResult { 655 if (!isa<OpTy>(symbolOp)) { 656 return op->emitOpError() 657 << "expected '" << symbolRef << "' to resolve to a " 658 << OpTy::getOperationName(); 659 } 660 return success(); 661 }; 662 663 return verifySymbolAttribute(op, attributeName, verifySymbolType); 664 } 665 666 static LogicalResult verifyMemoryOpMetadata(Operation *op) { 667 // access_groups 668 if (failed(verifyOpMetadata<LLVM::AccessGroupMetadataOp>( 669 op, LLVMDialect::getAccessGroupsAttrName()))) 670 return failure(); 671 672 // alias_scopes 673 if (failed(verifyOpMetadata<LLVM::AliasScopeMetadataOp>( 674 op, LLVMDialect::getAliasScopesAttrName()))) 675 return failure(); 676 677 // noalias_scopes 678 if (failed(verifyOpMetadata<LLVM::AliasScopeMetadataOp>( 679 op, LLVMDialect::getNoAliasScopesAttrName()))) 680 return failure(); 681 682 return success(); 683 } 684 685 LogicalResult LoadOp::verify() { return verifyMemoryOpMetadata(*this); } 686 687 void LoadOp::build(OpBuilder &builder, OperationState &result, Type t, 688 Value addr, unsigned alignment, bool isVolatile, 689 bool isNonTemporal) { 690 result.addOperands(addr); 691 result.addTypes(t); 692 if (isVolatile) 693 result.addAttribute(kVolatileAttrName, builder.getUnitAttr()); 694 if (isNonTemporal) 695 result.addAttribute(kNonTemporalAttrName, builder.getUnitAttr()); 696 if (alignment != 0) 697 result.addAttribute("alignment", builder.getI64IntegerAttr(alignment)); 698 } 699 700 void LoadOp::print(OpAsmPrinter &p) { 701 p << ' '; 702 if (getVolatile_()) 703 p << "volatile "; 704 p << getAddr(); 705 p.printOptionalAttrDict((*this)->getAttrs(), 706 {kVolatileAttrName, kElemTypeAttrName}); 707 p << " : " << getAddr().getType(); 708 if (getAddr().getType().cast<LLVMPointerType>().isOpaque()) 709 p << " -> " << getType(); 710 } 711 712 // Extract the pointee type from the LLVM pointer type wrapped in MLIR. Return 713 // the resulting type if any, null type if opaque pointers are used, and None 714 // if the given type is not the pointer type. 715 static Optional<Type> getLoadStoreElementType(OpAsmParser &parser, Type type, 716 SMLoc trailingTypeLoc) { 717 auto llvmTy = type.dyn_cast<LLVM::LLVMPointerType>(); 718 if (!llvmTy) { 719 parser.emitError(trailingTypeLoc, "expected LLVM pointer type"); 720 return llvm::None; 721 } 722 return llvmTy.getElementType(); 723 } 724 725 // <operation> ::= `llvm.load` `volatile` ssa-use attribute-dict? `:` type 726 // (`->` type)? 727 ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) { 728 OpAsmParser::UnresolvedOperand addr; 729 Type type; 730 SMLoc trailingTypeLoc; 731 732 if (succeeded(parser.parseOptionalKeyword("volatile"))) 733 result.addAttribute(kVolatileAttrName, parser.getBuilder().getUnitAttr()); 734 735 if (parser.parseOperand(addr) || 736 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 737 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) || 738 parser.resolveOperand(addr, type, result.operands)) 739 return failure(); 740 741 Optional<Type> elemTy = 742 getLoadStoreElementType(parser, type, trailingTypeLoc); 743 if (!elemTy) 744 return failure(); 745 if (*elemTy) { 746 result.addTypes(*elemTy); 747 return success(); 748 } 749 750 Type trailingType; 751 if (parser.parseArrow() || parser.parseType(trailingType)) 752 return failure(); 753 result.addTypes(trailingType); 754 return success(); 755 } 756 757 //===----------------------------------------------------------------------===// 758 // Builder, printer and parser for LLVM::StoreOp. 759 //===----------------------------------------------------------------------===// 760 761 LogicalResult StoreOp::verify() { return verifyMemoryOpMetadata(*this); } 762 763 void StoreOp::build(OpBuilder &builder, OperationState &result, Value value, 764 Value addr, unsigned alignment, bool isVolatile, 765 bool isNonTemporal) { 766 result.addOperands({value, addr}); 767 result.addTypes({}); 768 if (isVolatile) 769 result.addAttribute(kVolatileAttrName, builder.getUnitAttr()); 770 if (isNonTemporal) 771 result.addAttribute(kNonTemporalAttrName, builder.getUnitAttr()); 772 if (alignment != 0) 773 result.addAttribute("alignment", builder.getI64IntegerAttr(alignment)); 774 } 775 776 void StoreOp::print(OpAsmPrinter &p) { 777 p << ' '; 778 if (getVolatile_()) 779 p << "volatile "; 780 p << getValue() << ", " << getAddr(); 781 p.printOptionalAttrDict((*this)->getAttrs(), {kVolatileAttrName}); 782 p << " : "; 783 if (getAddr().getType().cast<LLVMPointerType>().isOpaque()) 784 p << getValue().getType() << ", "; 785 p << getAddr().getType(); 786 } 787 788 // <operation> ::= `llvm.store` `volatile` ssa-use `,` ssa-use 789 // attribute-dict? `:` type (`,` type)? 790 ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) { 791 OpAsmParser::UnresolvedOperand addr, value; 792 Type type; 793 SMLoc trailingTypeLoc; 794 795 if (succeeded(parser.parseOptionalKeyword("volatile"))) 796 result.addAttribute(kVolatileAttrName, parser.getBuilder().getUnitAttr()); 797 798 if (parser.parseOperand(value) || parser.parseComma() || 799 parser.parseOperand(addr) || 800 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 801 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type)) 802 return failure(); 803 804 Type operandType; 805 if (succeeded(parser.parseOptionalComma())) { 806 operandType = type; 807 if (parser.parseType(type)) 808 return failure(); 809 } else { 810 Optional<Type> maybeOperandType = 811 getLoadStoreElementType(parser, type, trailingTypeLoc); 812 if (!maybeOperandType) 813 return failure(); 814 operandType = *maybeOperandType; 815 } 816 817 if (parser.resolveOperand(value, operandType, result.operands) || 818 parser.resolveOperand(addr, type, result.operands)) 819 return failure(); 820 821 return success(); 822 } 823 824 ///===---------------------------------------------------------------------===// 825 /// LLVM::InvokeOp 826 ///===---------------------------------------------------------------------===// 827 828 SuccessorOperands InvokeOp::getSuccessorOperands(unsigned index) { 829 assert(index < getNumSuccessors() && "invalid successor index"); 830 return SuccessorOperands(index == 0 ? getNormalDestOperandsMutable() 831 : getUnwindDestOperandsMutable()); 832 } 833 834 LogicalResult InvokeOp::verify() { 835 if (getNumResults() > 1) 836 return emitOpError("must have 0 or 1 result"); 837 838 Block *unwindDest = getUnwindDest(); 839 if (unwindDest->empty()) 840 return emitError("must have at least one operation in unwind destination"); 841 842 // In unwind destination, first operation must be LandingpadOp 843 if (!isa<LandingpadOp>(unwindDest->front())) 844 return emitError("first operation in unwind destination should be a " 845 "llvm.landingpad operation"); 846 847 return success(); 848 } 849 850 void InvokeOp::print(OpAsmPrinter &p) { 851 auto callee = getCallee(); 852 bool isDirect = callee.hasValue(); 853 854 p << ' '; 855 856 // Either function name or pointer 857 if (isDirect) 858 p.printSymbolName(callee.getValue()); 859 else 860 p << getOperand(0); 861 862 p << '(' << getOperands().drop_front(isDirect ? 0 : 1) << ')'; 863 p << " to "; 864 p.printSuccessorAndUseList(getNormalDest(), getNormalDestOperands()); 865 p << " unwind "; 866 p.printSuccessorAndUseList(getUnwindDest(), getUnwindDestOperands()); 867 868 p.printOptionalAttrDict((*this)->getAttrs(), 869 {InvokeOp::getOperandSegmentSizeAttr(), "callee"}); 870 p << " : "; 871 p.printFunctionalType(llvm::drop_begin(getOperandTypes(), isDirect ? 0 : 1), 872 getResultTypes()); 873 } 874 875 /// <operation> ::= `llvm.invoke` (function-id | ssa-use) `(` ssa-use-list `)` 876 /// `to` bb-id (`[` ssa-use-and-type-list `]`)? 877 /// `unwind` bb-id (`[` ssa-use-and-type-list `]`)? 878 /// attribute-dict? `:` function-type 879 ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) { 880 SmallVector<OpAsmParser::UnresolvedOperand, 8> operands; 881 FunctionType funcType; 882 SymbolRefAttr funcAttr; 883 SMLoc trailingTypeLoc; 884 Block *normalDest, *unwindDest; 885 SmallVector<Value, 4> normalOperands, unwindOperands; 886 Builder &builder = parser.getBuilder(); 887 888 // Parse an operand list that will, in practice, contain 0 or 1 operand. In 889 // case of an indirect call, there will be 1 operand before `(`. In case of a 890 // direct call, there will be no operands and the parser will stop at the 891 // function identifier without complaining. 892 if (parser.parseOperandList(operands)) 893 return failure(); 894 bool isDirect = operands.empty(); 895 896 // Optionally parse a function identifier. 897 if (isDirect && parser.parseAttribute(funcAttr, "callee", result.attributes)) 898 return failure(); 899 900 if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || 901 parser.parseKeyword("to") || 902 parser.parseSuccessorAndUseList(normalDest, normalOperands) || 903 parser.parseKeyword("unwind") || 904 parser.parseSuccessorAndUseList(unwindDest, unwindOperands) || 905 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 906 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(funcType)) 907 return failure(); 908 909 if (isDirect) { 910 // Make sure types match. 911 if (parser.resolveOperands(operands, funcType.getInputs(), 912 parser.getNameLoc(), result.operands)) 913 return failure(); 914 result.addTypes(funcType.getResults()); 915 } else { 916 // Construct the LLVM IR Dialect function type that the first operand 917 // should match. 918 if (funcType.getNumResults() > 1) 919 return parser.emitError(trailingTypeLoc, 920 "expected function with 0 or 1 result"); 921 922 Type llvmResultType; 923 if (funcType.getNumResults() == 0) { 924 llvmResultType = LLVM::LLVMVoidType::get(builder.getContext()); 925 } else { 926 llvmResultType = funcType.getResult(0); 927 if (!isCompatibleType(llvmResultType)) 928 return parser.emitError(trailingTypeLoc, 929 "expected result to have LLVM type"); 930 } 931 932 SmallVector<Type, 8> argTypes; 933 argTypes.reserve(funcType.getNumInputs()); 934 for (Type ty : funcType.getInputs()) { 935 if (isCompatibleType(ty)) 936 argTypes.push_back(ty); 937 else 938 return parser.emitError(trailingTypeLoc, 939 "expected LLVM types as inputs"); 940 } 941 942 auto llvmFuncType = LLVM::LLVMFunctionType::get(llvmResultType, argTypes); 943 auto wrappedFuncType = LLVM::LLVMPointerType::get(llvmFuncType); 944 945 auto funcArguments = llvm::makeArrayRef(operands).drop_front(); 946 947 // Make sure that the first operand (indirect callee) matches the wrapped 948 // LLVM IR function type, and that the types of the other call operands 949 // match the types of the function arguments. 950 if (parser.resolveOperand(operands[0], wrappedFuncType, result.operands) || 951 parser.resolveOperands(funcArguments, funcType.getInputs(), 952 parser.getNameLoc(), result.operands)) 953 return failure(); 954 955 result.addTypes(llvmResultType); 956 } 957 result.addSuccessors({normalDest, unwindDest}); 958 result.addOperands(normalOperands); 959 result.addOperands(unwindOperands); 960 961 result.addAttribute( 962 InvokeOp::getOperandSegmentSizeAttr(), 963 builder.getI32VectorAttr({static_cast<int32_t>(operands.size()), 964 static_cast<int32_t>(normalOperands.size()), 965 static_cast<int32_t>(unwindOperands.size())})); 966 return success(); 967 } 968 969 ///===----------------------------------------------------------------------===// 970 /// Verifying/Printing/Parsing for LLVM::LandingpadOp. 971 ///===----------------------------------------------------------------------===// 972 973 LogicalResult LandingpadOp::verify() { 974 Value value; 975 if (LLVMFuncOp func = (*this)->getParentOfType<LLVMFuncOp>()) { 976 if (!func.getPersonality().hasValue()) 977 return emitError( 978 "llvm.landingpad needs to be in a function with a personality"); 979 } 980 981 if (!getCleanup() && getOperands().empty()) 982 return emitError("landingpad instruction expects at least one clause or " 983 "cleanup attribute"); 984 985 for (unsigned idx = 0, ie = getNumOperands(); idx < ie; idx++) { 986 value = getOperand(idx); 987 bool isFilter = value.getType().isa<LLVMArrayType>(); 988 if (isFilter) { 989 // FIXME: Verify filter clauses when arrays are appropriately handled 990 } else { 991 // catch - global addresses only. 992 // Bitcast ops should have global addresses as their args. 993 if (auto bcOp = value.getDefiningOp<BitcastOp>()) { 994 if (auto addrOp = bcOp.getArg().getDefiningOp<AddressOfOp>()) 995 continue; 996 return emitError("constant clauses expected").attachNote(bcOp.getLoc()) 997 << "global addresses expected as operand to " 998 "bitcast used in clauses for landingpad"; 999 } 1000 // NullOp and AddressOfOp allowed 1001 if (value.getDefiningOp<NullOp>()) 1002 continue; 1003 if (value.getDefiningOp<AddressOfOp>()) 1004 continue; 1005 return emitError("clause #") 1006 << idx << " is not a known constant - null, addressof, bitcast"; 1007 } 1008 } 1009 return success(); 1010 } 1011 1012 void LandingpadOp::print(OpAsmPrinter &p) { 1013 p << (getCleanup() ? " cleanup " : " "); 1014 1015 // Clauses 1016 for (auto value : getOperands()) { 1017 // Similar to llvm - if clause is an array type then it is filter 1018 // clause else catch clause 1019 bool isArrayTy = value.getType().isa<LLVMArrayType>(); 1020 p << '(' << (isArrayTy ? "filter " : "catch ") << value << " : " 1021 << value.getType() << ") "; 1022 } 1023 1024 p.printOptionalAttrDict((*this)->getAttrs(), {"cleanup"}); 1025 1026 p << ": " << getType(); 1027 } 1028 1029 /// <operation> ::= `llvm.landingpad` `cleanup`? 1030 /// ((`catch` | `filter`) operand-type ssa-use)* attribute-dict? 1031 ParseResult LandingpadOp::parse(OpAsmParser &parser, OperationState &result) { 1032 // Check for cleanup 1033 if (succeeded(parser.parseOptionalKeyword("cleanup"))) 1034 result.addAttribute("cleanup", parser.getBuilder().getUnitAttr()); 1035 1036 // Parse clauses with types 1037 while (succeeded(parser.parseOptionalLParen()) && 1038 (succeeded(parser.parseOptionalKeyword("filter")) || 1039 succeeded(parser.parseOptionalKeyword("catch")))) { 1040 OpAsmParser::UnresolvedOperand operand; 1041 Type ty; 1042 if (parser.parseOperand(operand) || parser.parseColon() || 1043 parser.parseType(ty) || 1044 parser.resolveOperand(operand, ty, result.operands) || 1045 parser.parseRParen()) 1046 return failure(); 1047 } 1048 1049 Type type; 1050 if (parser.parseColon() || parser.parseType(type)) 1051 return failure(); 1052 1053 result.addTypes(type); 1054 return success(); 1055 } 1056 1057 //===----------------------------------------------------------------------===// 1058 // Verifying/Printing/parsing for LLVM::CallOp. 1059 //===----------------------------------------------------------------------===// 1060 1061 LogicalResult CallOp::verify() { 1062 if (getNumResults() > 1) 1063 return emitOpError("must have 0 or 1 result"); 1064 1065 // Type for the callee, we'll get it differently depending if it is a direct 1066 // or indirect call. 1067 Type fnType; 1068 1069 bool isIndirect = false; 1070 1071 // If this is an indirect call, the callee attribute is missing. 1072 FlatSymbolRefAttr calleeName = getCalleeAttr(); 1073 if (!calleeName) { 1074 isIndirect = true; 1075 if (!getNumOperands()) 1076 return emitOpError( 1077 "must have either a `callee` attribute or at least an operand"); 1078 auto ptrType = getOperand(0).getType().dyn_cast<LLVMPointerType>(); 1079 if (!ptrType) 1080 return emitOpError("indirect call expects a pointer as callee: ") 1081 << ptrType; 1082 fnType = ptrType.getElementType(); 1083 } else { 1084 Operation *callee = 1085 SymbolTable::lookupNearestSymbolFrom(*this, calleeName.getAttr()); 1086 if (!callee) 1087 return emitOpError() 1088 << "'" << calleeName.getValue() 1089 << "' does not reference a symbol in the current scope"; 1090 auto fn = dyn_cast<LLVMFuncOp>(callee); 1091 if (!fn) 1092 return emitOpError() << "'" << calleeName.getValue() 1093 << "' does not reference a valid LLVM function"; 1094 1095 fnType = fn.getFunctionType(); 1096 } 1097 1098 LLVMFunctionType funcType = fnType.dyn_cast<LLVMFunctionType>(); 1099 if (!funcType) 1100 return emitOpError("callee does not have a functional type: ") << fnType; 1101 1102 // Verify that the operand and result types match the callee. 1103 1104 if (!funcType.isVarArg() && 1105 funcType.getNumParams() != (getNumOperands() - isIndirect)) 1106 return emitOpError() << "incorrect number of operands (" 1107 << (getNumOperands() - isIndirect) 1108 << ") for callee (expecting: " 1109 << funcType.getNumParams() << ")"; 1110 1111 if (funcType.getNumParams() > (getNumOperands() - isIndirect)) 1112 return emitOpError() << "incorrect number of operands (" 1113 << (getNumOperands() - isIndirect) 1114 << ") for varargs callee (expecting at least: " 1115 << funcType.getNumParams() << ")"; 1116 1117 for (unsigned i = 0, e = funcType.getNumParams(); i != e; ++i) 1118 if (getOperand(i + isIndirect).getType() != funcType.getParamType(i)) 1119 return emitOpError() << "operand type mismatch for operand " << i << ": " 1120 << getOperand(i + isIndirect).getType() 1121 << " != " << funcType.getParamType(i); 1122 1123 if (getNumResults() == 0 && 1124 !funcType.getReturnType().isa<LLVM::LLVMVoidType>()) 1125 return emitOpError() << "expected function call to produce a value"; 1126 1127 if (getNumResults() != 0 && 1128 funcType.getReturnType().isa<LLVM::LLVMVoidType>()) 1129 return emitOpError() 1130 << "calling function with void result must not produce values"; 1131 1132 if (getNumResults() > 1) 1133 return emitOpError() 1134 << "expected LLVM function call to produce 0 or 1 result"; 1135 1136 if (getNumResults() && getResult(0).getType() != funcType.getReturnType()) 1137 return emitOpError() << "result type mismatch: " << getResult(0).getType() 1138 << " != " << funcType.getReturnType(); 1139 1140 return success(); 1141 } 1142 1143 void CallOp::print(OpAsmPrinter &p) { 1144 auto callee = getCallee(); 1145 bool isDirect = callee.hasValue(); 1146 1147 // Print the direct callee if present as a function attribute, or an indirect 1148 // callee (first operand) otherwise. 1149 p << ' '; 1150 if (isDirect) 1151 p.printSymbolName(callee.getValue()); 1152 else 1153 p << getOperand(0); 1154 1155 auto args = getOperands().drop_front(isDirect ? 0 : 1); 1156 p << '(' << args << ')'; 1157 p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()), {"callee"}); 1158 1159 // Reconstruct the function MLIR function type from operand and result types. 1160 p << " : "; 1161 p.printFunctionalType(args.getTypes(), getResultTypes()); 1162 } 1163 1164 // <operation> ::= `llvm.call` (function-id | ssa-use) `(` ssa-use-list `)` 1165 // attribute-dict? `:` function-type 1166 ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) { 1167 SmallVector<OpAsmParser::UnresolvedOperand, 8> operands; 1168 Type type; 1169 SymbolRefAttr funcAttr; 1170 SMLoc trailingTypeLoc; 1171 1172 // Parse an operand list that will, in practice, contain 0 or 1 operand. In 1173 // case of an indirect call, there will be 1 operand before `(`. In case of a 1174 // direct call, there will be no operands and the parser will stop at the 1175 // function identifier without complaining. 1176 if (parser.parseOperandList(operands)) 1177 return failure(); 1178 bool isDirect = operands.empty(); 1179 1180 // Optionally parse a function identifier. 1181 if (isDirect) 1182 if (parser.parseAttribute(funcAttr, "callee", result.attributes)) 1183 return failure(); 1184 1185 if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || 1186 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 1187 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type)) 1188 return failure(); 1189 1190 auto funcType = type.dyn_cast<FunctionType>(); 1191 if (!funcType) 1192 return parser.emitError(trailingTypeLoc, "expected function type"); 1193 if (funcType.getNumResults() > 1) 1194 return parser.emitError(trailingTypeLoc, 1195 "expected function with 0 or 1 result"); 1196 if (isDirect) { 1197 // Make sure types match. 1198 if (parser.resolveOperands(operands, funcType.getInputs(), 1199 parser.getNameLoc(), result.operands)) 1200 return failure(); 1201 if (funcType.getNumResults() != 0 && 1202 !funcType.getResult(0).isa<LLVM::LLVMVoidType>()) 1203 result.addTypes(funcType.getResults()); 1204 } else { 1205 Builder &builder = parser.getBuilder(); 1206 Type llvmResultType; 1207 if (funcType.getNumResults() == 0) { 1208 llvmResultType = LLVM::LLVMVoidType::get(builder.getContext()); 1209 } else { 1210 llvmResultType = funcType.getResult(0); 1211 if (!isCompatibleType(llvmResultType)) 1212 return parser.emitError(trailingTypeLoc, 1213 "expected result to have LLVM type"); 1214 } 1215 1216 SmallVector<Type, 8> argTypes; 1217 argTypes.reserve(funcType.getNumInputs()); 1218 for (int i = 0, e = funcType.getNumInputs(); i < e; ++i) { 1219 auto argType = funcType.getInput(i); 1220 if (!isCompatibleType(argType)) 1221 return parser.emitError(trailingTypeLoc, 1222 "expected LLVM types as inputs"); 1223 argTypes.push_back(argType); 1224 } 1225 auto llvmFuncType = LLVM::LLVMFunctionType::get(llvmResultType, argTypes); 1226 auto wrappedFuncType = LLVM::LLVMPointerType::get(llvmFuncType); 1227 1228 auto funcArguments = 1229 ArrayRef<OpAsmParser::UnresolvedOperand>(operands).drop_front(); 1230 1231 // Make sure that the first operand (indirect callee) matches the wrapped 1232 // LLVM IR function type, and that the types of the other call operands 1233 // match the types of the function arguments. 1234 if (parser.resolveOperand(operands[0], wrappedFuncType, result.operands) || 1235 parser.resolveOperands(funcArguments, funcType.getInputs(), 1236 parser.getNameLoc(), result.operands)) 1237 return failure(); 1238 1239 if (!llvmResultType.isa<LLVM::LLVMVoidType>()) 1240 result.addTypes(llvmResultType); 1241 } 1242 1243 return success(); 1244 } 1245 1246 //===----------------------------------------------------------------------===// 1247 // Printing/parsing for LLVM::ExtractElementOp. 1248 //===----------------------------------------------------------------------===// 1249 // Expects vector to be of wrapped LLVM vector type and position to be of 1250 // wrapped LLVM i32 type. 1251 void LLVM::ExtractElementOp::build(OpBuilder &b, OperationState &result, 1252 Value vector, Value position, 1253 ArrayRef<NamedAttribute> attrs) { 1254 auto vectorType = vector.getType(); 1255 auto llvmType = LLVM::getVectorElementType(vectorType); 1256 build(b, result, llvmType, vector, position); 1257 result.addAttributes(attrs); 1258 } 1259 1260 void ExtractElementOp::print(OpAsmPrinter &p) { 1261 p << ' ' << getVector() << "[" << getPosition() << " : " 1262 << getPosition().getType() << "]"; 1263 p.printOptionalAttrDict((*this)->getAttrs()); 1264 p << " : " << getVector().getType(); 1265 } 1266 1267 // <operation> ::= `llvm.extractelement` ssa-use `, ` ssa-use 1268 // attribute-dict? `:` type 1269 ParseResult ExtractElementOp::parse(OpAsmParser &parser, 1270 OperationState &result) { 1271 SMLoc loc; 1272 OpAsmParser::UnresolvedOperand vector, position; 1273 Type type, positionType; 1274 if (parser.getCurrentLocation(&loc) || parser.parseOperand(vector) || 1275 parser.parseLSquare() || parser.parseOperand(position) || 1276 parser.parseColonType(positionType) || parser.parseRSquare() || 1277 parser.parseOptionalAttrDict(result.attributes) || 1278 parser.parseColonType(type) || 1279 parser.resolveOperand(vector, type, result.operands) || 1280 parser.resolveOperand(position, positionType, result.operands)) 1281 return failure(); 1282 if (!LLVM::isCompatibleVectorType(type)) 1283 return parser.emitError( 1284 loc, "expected LLVM dialect-compatible vector type for operand #1"); 1285 result.addTypes(LLVM::getVectorElementType(type)); 1286 return success(); 1287 } 1288 1289 LogicalResult ExtractElementOp::verify() { 1290 Type vectorType = getVector().getType(); 1291 if (!LLVM::isCompatibleVectorType(vectorType)) 1292 return emitOpError("expected LLVM dialect-compatible vector type for " 1293 "operand #1, got") 1294 << vectorType; 1295 Type valueType = LLVM::getVectorElementType(vectorType); 1296 if (valueType != getRes().getType()) 1297 return emitOpError() << "Type mismatch: extracting from " << vectorType 1298 << " should produce " << valueType 1299 << " but this op returns " << getRes().getType(); 1300 return success(); 1301 } 1302 1303 //===----------------------------------------------------------------------===// 1304 // Printing/parsing for LLVM::ExtractValueOp. 1305 //===----------------------------------------------------------------------===// 1306 1307 void ExtractValueOp::print(OpAsmPrinter &p) { 1308 p << ' ' << getContainer() << getPosition(); 1309 p.printOptionalAttrDict((*this)->getAttrs(), {"position"}); 1310 p << " : " << getContainer().getType(); 1311 } 1312 1313 // Extract the type at `position` in the wrapped LLVM IR aggregate type 1314 // `containerType`. Position is an integer array attribute where each value 1315 // is a zero-based position of the element in the aggregate type. Return the 1316 // resulting type wrapped in MLIR, or nullptr on error. 1317 static Type getInsertExtractValueElementType(OpAsmParser &parser, 1318 Type containerType, 1319 ArrayAttr positionAttr, 1320 SMLoc attributeLoc, 1321 SMLoc typeLoc) { 1322 Type llvmType = containerType; 1323 if (!isCompatibleType(containerType)) 1324 return parser.emitError(typeLoc, "expected LLVM IR Dialect type"), nullptr; 1325 1326 // Infer the element type from the structure type: iteratively step inside the 1327 // type by taking the element type, indexed by the position attribute for 1328 // structures. Check the position index before accessing, it is supposed to 1329 // be in bounds. 1330 for (Attribute subAttr : positionAttr) { 1331 auto positionElementAttr = subAttr.dyn_cast<IntegerAttr>(); 1332 if (!positionElementAttr) 1333 return parser.emitError(attributeLoc, 1334 "expected an array of integer literals"), 1335 nullptr; 1336 int position = positionElementAttr.getInt(); 1337 if (auto arrayType = llvmType.dyn_cast<LLVMArrayType>()) { 1338 if (position < 0 || 1339 static_cast<unsigned>(position) >= arrayType.getNumElements()) 1340 return parser.emitError(attributeLoc, "position out of bounds"), 1341 nullptr; 1342 llvmType = arrayType.getElementType(); 1343 } else if (auto structType = llvmType.dyn_cast<LLVMStructType>()) { 1344 if (position < 0 || 1345 static_cast<unsigned>(position) >= structType.getBody().size()) 1346 return parser.emitError(attributeLoc, "position out of bounds"), 1347 nullptr; 1348 llvmType = structType.getBody()[position]; 1349 } else { 1350 return parser.emitError(typeLoc, "expected LLVM IR structure/array type"), 1351 nullptr; 1352 } 1353 } 1354 return llvmType; 1355 } 1356 1357 // Extract the type at `position` in the wrapped LLVM IR aggregate type 1358 // `containerType`. Returns null on failure. 1359 static Type getInsertExtractValueElementType(Type containerType, 1360 ArrayAttr positionAttr, 1361 Operation *op) { 1362 Type llvmType = containerType; 1363 if (!isCompatibleType(containerType)) { 1364 op->emitError("expected LLVM IR Dialect type, got ") << containerType; 1365 return {}; 1366 } 1367 1368 // Infer the element type from the structure type: iteratively step inside the 1369 // type by taking the element type, indexed by the position attribute for 1370 // structures. Check the position index before accessing, it is supposed to 1371 // be in bounds. 1372 for (Attribute subAttr : positionAttr) { 1373 auto positionElementAttr = subAttr.dyn_cast<IntegerAttr>(); 1374 if (!positionElementAttr) { 1375 op->emitOpError("expected an array of integer literals, got: ") 1376 << subAttr; 1377 return {}; 1378 } 1379 int position = positionElementAttr.getInt(); 1380 if (auto arrayType = llvmType.dyn_cast<LLVMArrayType>()) { 1381 if (position < 0 || 1382 static_cast<unsigned>(position) >= arrayType.getNumElements()) { 1383 op->emitOpError("position out of bounds: ") << position; 1384 return {}; 1385 } 1386 llvmType = arrayType.getElementType(); 1387 } else if (auto structType = llvmType.dyn_cast<LLVMStructType>()) { 1388 if (position < 0 || 1389 static_cast<unsigned>(position) >= structType.getBody().size()) { 1390 op->emitOpError("position out of bounds") << position; 1391 return {}; 1392 } 1393 llvmType = structType.getBody()[position]; 1394 } else { 1395 op->emitOpError("expected LLVM IR structure/array type, got: ") 1396 << llvmType; 1397 return {}; 1398 } 1399 } 1400 return llvmType; 1401 } 1402 1403 // <operation> ::= `llvm.extractvalue` ssa-use 1404 // `[` integer-literal (`,` integer-literal)* `]` 1405 // attribute-dict? `:` type 1406 ParseResult ExtractValueOp::parse(OpAsmParser &parser, OperationState &result) { 1407 OpAsmParser::UnresolvedOperand container; 1408 Type containerType; 1409 ArrayAttr positionAttr; 1410 SMLoc attributeLoc, trailingTypeLoc; 1411 1412 if (parser.parseOperand(container) || 1413 parser.getCurrentLocation(&attributeLoc) || 1414 parser.parseAttribute(positionAttr, "position", result.attributes) || 1415 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 1416 parser.getCurrentLocation(&trailingTypeLoc) || 1417 parser.parseType(containerType) || 1418 parser.resolveOperand(container, containerType, result.operands)) 1419 return failure(); 1420 1421 auto elementType = getInsertExtractValueElementType( 1422 parser, containerType, positionAttr, attributeLoc, trailingTypeLoc); 1423 if (!elementType) 1424 return failure(); 1425 1426 result.addTypes(elementType); 1427 return success(); 1428 } 1429 1430 OpFoldResult LLVM::ExtractValueOp::fold(ArrayRef<Attribute> operands) { 1431 auto insertValueOp = getContainer().getDefiningOp<InsertValueOp>(); 1432 OpFoldResult result = {}; 1433 while (insertValueOp) { 1434 if (getPosition() == insertValueOp.getPosition()) 1435 return insertValueOp.getValue(); 1436 unsigned min = 1437 std::min(getPosition().size(), insertValueOp.getPosition().size()); 1438 // If one is fully prefix of the other, stop propagating back as it will 1439 // miss dependencies. For instance, %3 should not fold to %f0 in the 1440 // following example: 1441 // ``` 1442 // %1 = llvm.insertvalue %f0, %0[0, 0] : 1443 // !llvm.array<4 x !llvm.array<4xf32>> 1444 // %2 = llvm.insertvalue %arr, %1[0] : 1445 // !llvm.array<4 x !llvm.array<4xf32>> 1446 // %3 = llvm.extractvalue %2[0, 0] : !llvm.array<4 x !llvm.array<4xf32>> 1447 // ``` 1448 if (getPosition().getValue().take_front(min) == 1449 insertValueOp.getPosition().getValue().take_front(min)) 1450 return result; 1451 1452 // If neither a prefix, nor the exact position, we can extract out of the 1453 // value being inserted into. Moreover, we can try again if that operand 1454 // is itself an insertvalue expression. 1455 getContainerMutable().assign(insertValueOp.getContainer()); 1456 result = getResult(); 1457 insertValueOp = insertValueOp.getContainer().getDefiningOp<InsertValueOp>(); 1458 } 1459 return result; 1460 } 1461 1462 LogicalResult ExtractValueOp::verify() { 1463 Type valueType = getInsertExtractValueElementType(getContainer().getType(), 1464 getPositionAttr(), *this); 1465 if (!valueType) 1466 return failure(); 1467 1468 if (getRes().getType() != valueType) 1469 return emitOpError() << "Type mismatch: extracting from " 1470 << getContainer().getType() << " should produce " 1471 << valueType << " but this op returns " 1472 << getRes().getType(); 1473 return success(); 1474 } 1475 1476 //===----------------------------------------------------------------------===// 1477 // Printing/parsing for LLVM::InsertElementOp. 1478 //===----------------------------------------------------------------------===// 1479 1480 void InsertElementOp::print(OpAsmPrinter &p) { 1481 p << ' ' << getValue() << ", " << getVector() << "[" << getPosition() << " : " 1482 << getPosition().getType() << "]"; 1483 p.printOptionalAttrDict((*this)->getAttrs()); 1484 p << " : " << getVector().getType(); 1485 } 1486 1487 // <operation> ::= `llvm.insertelement` ssa-use `,` ssa-use `,` ssa-use 1488 // attribute-dict? `:` type 1489 ParseResult InsertElementOp::parse(OpAsmParser &parser, 1490 OperationState &result) { 1491 SMLoc loc; 1492 OpAsmParser::UnresolvedOperand vector, value, position; 1493 Type vectorType, positionType; 1494 if (parser.getCurrentLocation(&loc) || parser.parseOperand(value) || 1495 parser.parseComma() || parser.parseOperand(vector) || 1496 parser.parseLSquare() || parser.parseOperand(position) || 1497 parser.parseColonType(positionType) || parser.parseRSquare() || 1498 parser.parseOptionalAttrDict(result.attributes) || 1499 parser.parseColonType(vectorType)) 1500 return failure(); 1501 1502 if (!LLVM::isCompatibleVectorType(vectorType)) 1503 return parser.emitError( 1504 loc, "expected LLVM dialect-compatible vector type for operand #1"); 1505 Type valueType = LLVM::getVectorElementType(vectorType); 1506 if (!valueType) 1507 return failure(); 1508 1509 if (parser.resolveOperand(vector, vectorType, result.operands) || 1510 parser.resolveOperand(value, valueType, result.operands) || 1511 parser.resolveOperand(position, positionType, result.operands)) 1512 return failure(); 1513 1514 result.addTypes(vectorType); 1515 return success(); 1516 } 1517 1518 LogicalResult InsertElementOp::verify() { 1519 Type valueType = LLVM::getVectorElementType(getVector().getType()); 1520 if (valueType != getValue().getType()) 1521 return emitOpError() << "Type mismatch: cannot insert " 1522 << getValue().getType() << " into " 1523 << getVector().getType(); 1524 return success(); 1525 } 1526 1527 //===----------------------------------------------------------------------===// 1528 // Printing/parsing for LLVM::InsertValueOp. 1529 //===----------------------------------------------------------------------===// 1530 1531 void InsertValueOp::print(OpAsmPrinter &p) { 1532 p << ' ' << getValue() << ", " << getContainer() << getPosition(); 1533 p.printOptionalAttrDict((*this)->getAttrs(), {"position"}); 1534 p << " : " << getContainer().getType(); 1535 } 1536 1537 // <operation> ::= `llvm.insertvaluevalue` ssa-use `,` ssa-use 1538 // `[` integer-literal (`,` integer-literal)* `]` 1539 // attribute-dict? `:` type 1540 ParseResult InsertValueOp::parse(OpAsmParser &parser, OperationState &result) { 1541 OpAsmParser::UnresolvedOperand container, value; 1542 Type containerType; 1543 ArrayAttr positionAttr; 1544 SMLoc attributeLoc, trailingTypeLoc; 1545 1546 if (parser.parseOperand(value) || parser.parseComma() || 1547 parser.parseOperand(container) || 1548 parser.getCurrentLocation(&attributeLoc) || 1549 parser.parseAttribute(positionAttr, "position", result.attributes) || 1550 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 1551 parser.getCurrentLocation(&trailingTypeLoc) || 1552 parser.parseType(containerType)) 1553 return failure(); 1554 1555 auto valueType = getInsertExtractValueElementType( 1556 parser, containerType, positionAttr, attributeLoc, trailingTypeLoc); 1557 if (!valueType) 1558 return failure(); 1559 1560 if (parser.resolveOperand(container, containerType, result.operands) || 1561 parser.resolveOperand(value, valueType, result.operands)) 1562 return failure(); 1563 1564 result.addTypes(containerType); 1565 return success(); 1566 } 1567 1568 LogicalResult InsertValueOp::verify() { 1569 Type valueType = getInsertExtractValueElementType(getContainer().getType(), 1570 getPositionAttr(), *this); 1571 if (!valueType) 1572 return failure(); 1573 1574 if (getValue().getType() != valueType) 1575 return emitOpError() << "Type mismatch: cannot insert " 1576 << getValue().getType() << " into " 1577 << getContainer().getType(); 1578 1579 return success(); 1580 } 1581 1582 //===----------------------------------------------------------------------===// 1583 // Printing, parsing and verification for LLVM::ReturnOp. 1584 //===----------------------------------------------------------------------===// 1585 1586 LogicalResult ReturnOp::verify() { 1587 if (getNumOperands() > 1) 1588 return emitOpError("expected at most 1 operand"); 1589 1590 if (auto parent = (*this)->getParentOfType<LLVMFuncOp>()) { 1591 Type expectedType = parent.getFunctionType().getReturnType(); 1592 if (expectedType.isa<LLVMVoidType>()) { 1593 if (getNumOperands() == 0) 1594 return success(); 1595 InFlightDiagnostic diag = emitOpError("expected no operands"); 1596 diag.attachNote(parent->getLoc()) << "when returning from function"; 1597 return diag; 1598 } 1599 if (getNumOperands() == 0) { 1600 if (expectedType.isa<LLVMVoidType>()) 1601 return success(); 1602 InFlightDiagnostic diag = emitOpError("expected 1 operand"); 1603 diag.attachNote(parent->getLoc()) << "when returning from function"; 1604 return diag; 1605 } 1606 if (expectedType != getOperand(0).getType()) { 1607 InFlightDiagnostic diag = emitOpError("mismatching result types"); 1608 diag.attachNote(parent->getLoc()) << "when returning from function"; 1609 return diag; 1610 } 1611 } 1612 return success(); 1613 } 1614 1615 //===----------------------------------------------------------------------===// 1616 // ResumeOp 1617 //===----------------------------------------------------------------------===// 1618 1619 LogicalResult ResumeOp::verify() { 1620 if (!getValue().getDefiningOp<LandingpadOp>()) 1621 return emitOpError("expects landingpad value as operand"); 1622 // No check for personality of function - landingpad op verifies it. 1623 return success(); 1624 } 1625 1626 //===----------------------------------------------------------------------===// 1627 // Verifier for LLVM::AddressOfOp. 1628 //===----------------------------------------------------------------------===// 1629 1630 template <typename OpTy> 1631 static OpTy lookupSymbolInModule(Operation *parent, StringRef name) { 1632 Operation *module = parent; 1633 while (module && !satisfiesLLVMModule(module)) 1634 module = module->getParentOp(); 1635 assert(module && "unexpected operation outside of a module"); 1636 return dyn_cast_or_null<OpTy>( 1637 mlir::SymbolTable::lookupSymbolIn(module, name)); 1638 } 1639 1640 GlobalOp AddressOfOp::getGlobal() { 1641 return lookupSymbolInModule<LLVM::GlobalOp>((*this)->getParentOp(), 1642 getGlobalName()); 1643 } 1644 1645 LLVMFuncOp AddressOfOp::getFunction() { 1646 return lookupSymbolInModule<LLVM::LLVMFuncOp>((*this)->getParentOp(), 1647 getGlobalName()); 1648 } 1649 1650 LogicalResult AddressOfOp::verify() { 1651 auto global = getGlobal(); 1652 auto function = getFunction(); 1653 if (!global && !function) 1654 return emitOpError( 1655 "must reference a global defined by 'llvm.mlir.global' or 'llvm.func'"); 1656 1657 LLVMPointerType type = getType(); 1658 if (global && global.getAddrSpace() != type.getAddressSpace()) 1659 return emitOpError("pointer address space must match address space of the " 1660 "referenced global"); 1661 1662 if (type.isOpaque()) 1663 return success(); 1664 1665 if (global && type.getElementType() != global.getType()) 1666 return emitOpError( 1667 "the type must be a pointer to the type of the referenced global"); 1668 1669 if (function && type.getElementType() != function.getFunctionType()) 1670 return emitOpError( 1671 "the type must be a pointer to the type of the referenced function"); 1672 1673 return success(); 1674 } 1675 1676 //===----------------------------------------------------------------------===// 1677 // Builder, printer and verifier for LLVM::GlobalOp. 1678 //===----------------------------------------------------------------------===// 1679 1680 void GlobalOp::build(OpBuilder &builder, OperationState &result, Type type, 1681 bool isConstant, Linkage linkage, StringRef name, 1682 Attribute value, uint64_t alignment, unsigned addrSpace, 1683 bool dsoLocal, bool threadLocal, 1684 ArrayRef<NamedAttribute> attrs) { 1685 result.addAttribute(getSymNameAttrName(result.name), 1686 builder.getStringAttr(name)); 1687 result.addAttribute(getGlobalTypeAttrName(result.name), TypeAttr::get(type)); 1688 if (isConstant) 1689 result.addAttribute(getConstantAttrName(result.name), 1690 builder.getUnitAttr()); 1691 if (value) 1692 result.addAttribute(getValueAttrName(result.name), value); 1693 if (dsoLocal) 1694 result.addAttribute(getDsoLocalAttrName(result.name), 1695 builder.getUnitAttr()); 1696 if (threadLocal) 1697 result.addAttribute(getThreadLocal_AttrName(result.name), 1698 builder.getUnitAttr()); 1699 1700 // Only add an alignment attribute if the "alignment" input 1701 // is different from 0. The value must also be a power of two, but 1702 // this is tested in GlobalOp::verify, not here. 1703 if (alignment != 0) 1704 result.addAttribute(getAlignmentAttrName(result.name), 1705 builder.getI64IntegerAttr(alignment)); 1706 1707 result.addAttribute(getLinkageAttrName(result.name), 1708 LinkageAttr::get(builder.getContext(), linkage)); 1709 if (addrSpace != 0) 1710 result.addAttribute(getAddrSpaceAttrName(result.name), 1711 builder.getI32IntegerAttr(addrSpace)); 1712 result.attributes.append(attrs.begin(), attrs.end()); 1713 result.addRegion(); 1714 } 1715 1716 void GlobalOp::print(OpAsmPrinter &p) { 1717 p << ' ' << stringifyLinkage(getLinkage()) << ' '; 1718 if (auto unnamedAddr = getUnnamedAddr()) { 1719 StringRef str = stringifyUnnamedAddr(*unnamedAddr); 1720 if (!str.empty()) 1721 p << str << ' '; 1722 } 1723 if (getThreadLocal_()) 1724 p << "thread_local "; 1725 if (getConstant()) 1726 p << "constant "; 1727 p.printSymbolName(getSymName()); 1728 p << '('; 1729 if (auto value = getValueOrNull()) 1730 p.printAttribute(value); 1731 p << ')'; 1732 // Note that the alignment attribute is printed using the 1733 // default syntax here, even though it is an inherent attribute 1734 // (as defined in https://mlir.llvm.org/docs/LangRef/#attributes) 1735 p.printOptionalAttrDict( 1736 (*this)->getAttrs(), 1737 {SymbolTable::getSymbolAttrName(), getGlobalTypeAttrName(), 1738 getConstantAttrName(), getValueAttrName(), getLinkageAttrName(), 1739 getUnnamedAddrAttrName(), getThreadLocal_AttrName()}); 1740 1741 // Print the trailing type unless it's a string global. 1742 if (getValueOrNull().dyn_cast_or_null<StringAttr>()) 1743 return; 1744 p << " : " << getType(); 1745 1746 Region &initializer = getInitializerRegion(); 1747 if (!initializer.empty()) { 1748 p << ' '; 1749 p.printRegion(initializer, /*printEntryBlockArgs=*/false); 1750 } 1751 } 1752 1753 // Parses one of the keywords provided in the list `keywords` and returns the 1754 // position of the parsed keyword in the list. If none of the keywords from the 1755 // list is parsed, returns -1. 1756 static int parseOptionalKeywordAlternative(OpAsmParser &parser, 1757 ArrayRef<StringRef> keywords) { 1758 for (const auto &en : llvm::enumerate(keywords)) { 1759 if (succeeded(parser.parseOptionalKeyword(en.value()))) 1760 return en.index(); 1761 } 1762 return -1; 1763 } 1764 1765 namespace { 1766 template <typename Ty> 1767 struct EnumTraits {}; 1768 1769 #define REGISTER_ENUM_TYPE(Ty) \ 1770 template <> \ 1771 struct EnumTraits<Ty> { \ 1772 static StringRef stringify(Ty value) { return stringify##Ty(value); } \ 1773 static unsigned getMaxEnumVal() { return getMaxEnumValFor##Ty(); } \ 1774 } 1775 1776 REGISTER_ENUM_TYPE(Linkage); 1777 REGISTER_ENUM_TYPE(UnnamedAddr); 1778 } // namespace 1779 1780 /// Parse an enum from the keyword, or default to the provided default value. 1781 /// The return type is the enum type by default, unless overriden with the 1782 /// second template argument. 1783 template <typename EnumTy, typename RetTy = EnumTy> 1784 static RetTy parseOptionalLLVMKeyword(OpAsmParser &parser, 1785 OperationState &result, 1786 EnumTy defaultValue) { 1787 SmallVector<StringRef, 10> names; 1788 for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i) 1789 names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i))); 1790 1791 int index = parseOptionalKeywordAlternative(parser, names); 1792 if (index == -1) 1793 return static_cast<RetTy>(defaultValue); 1794 return static_cast<RetTy>(index); 1795 } 1796 1797 // operation ::= `llvm.mlir.global` linkage? `constant`? `@` identifier 1798 // `(` attribute? `)` align? attribute-list? (`:` type)? region? 1799 // align ::= `align` `=` UINT64 1800 // 1801 // The type can be omitted for string attributes, in which case it will be 1802 // inferred from the value of the string as [strlen(value) x i8]. 1803 ParseResult GlobalOp::parse(OpAsmParser &parser, OperationState &result) { 1804 MLIRContext *ctx = parser.getContext(); 1805 // Parse optional linkage, default to External. 1806 result.addAttribute(getLinkageAttrName(result.name), 1807 LLVM::LinkageAttr::get( 1808 ctx, parseOptionalLLVMKeyword<Linkage>( 1809 parser, result, LLVM::Linkage::External))); 1810 1811 if (succeeded(parser.parseOptionalKeyword("thread_local"))) 1812 result.addAttribute(getThreadLocal_AttrName(result.name), 1813 parser.getBuilder().getUnitAttr()); 1814 1815 // Parse optional UnnamedAddr, default to None. 1816 result.addAttribute(getUnnamedAddrAttrName(result.name), 1817 parser.getBuilder().getI64IntegerAttr( 1818 parseOptionalLLVMKeyword<UnnamedAddr, int64_t>( 1819 parser, result, LLVM::UnnamedAddr::None))); 1820 1821 if (succeeded(parser.parseOptionalKeyword("constant"))) 1822 result.addAttribute(getConstantAttrName(result.name), 1823 parser.getBuilder().getUnitAttr()); 1824 1825 StringAttr name; 1826 if (parser.parseSymbolName(name, getSymNameAttrName(result.name), 1827 result.attributes) || 1828 parser.parseLParen()) 1829 return failure(); 1830 1831 Attribute value; 1832 if (parser.parseOptionalRParen()) { 1833 if (parser.parseAttribute(value, getValueAttrName(result.name), 1834 result.attributes) || 1835 parser.parseRParen()) 1836 return failure(); 1837 } 1838 1839 SmallVector<Type, 1> types; 1840 if (parser.parseOptionalAttrDict(result.attributes) || 1841 parser.parseOptionalColonTypeList(types)) 1842 return failure(); 1843 1844 if (types.size() > 1) 1845 return parser.emitError(parser.getNameLoc(), "expected zero or one type"); 1846 1847 Region &initRegion = *result.addRegion(); 1848 if (types.empty()) { 1849 if (auto strAttr = value.dyn_cast_or_null<StringAttr>()) { 1850 MLIRContext *context = parser.getContext(); 1851 auto arrayType = LLVM::LLVMArrayType::get(IntegerType::get(context, 8), 1852 strAttr.getValue().size()); 1853 types.push_back(arrayType); 1854 } else { 1855 return parser.emitError(parser.getNameLoc(), 1856 "type can only be omitted for string globals"); 1857 } 1858 } else { 1859 OptionalParseResult parseResult = 1860 parser.parseOptionalRegion(initRegion, /*arguments=*/{}, 1861 /*argTypes=*/{}); 1862 if (parseResult.hasValue() && failed(*parseResult)) 1863 return failure(); 1864 } 1865 1866 result.addAttribute(getGlobalTypeAttrName(result.name), 1867 TypeAttr::get(types[0])); 1868 return success(); 1869 } 1870 1871 static bool isZeroAttribute(Attribute value) { 1872 if (auto intValue = value.dyn_cast<IntegerAttr>()) 1873 return intValue.getValue().isNullValue(); 1874 if (auto fpValue = value.dyn_cast<FloatAttr>()) 1875 return fpValue.getValue().isZero(); 1876 if (auto splatValue = value.dyn_cast<SplatElementsAttr>()) 1877 return isZeroAttribute(splatValue.getSplatValue<Attribute>()); 1878 if (auto elementsValue = value.dyn_cast<ElementsAttr>()) 1879 return llvm::all_of(elementsValue.getValues<Attribute>(), isZeroAttribute); 1880 if (auto arrayValue = value.dyn_cast<ArrayAttr>()) 1881 return llvm::all_of(arrayValue.getValue(), isZeroAttribute); 1882 return false; 1883 } 1884 1885 LogicalResult GlobalOp::verify() { 1886 if (!LLVMPointerType::isValidElementType(getType())) 1887 return emitOpError( 1888 "expects type to be a valid element type for an LLVM pointer"); 1889 if ((*this)->getParentOp() && !satisfiesLLVMModule((*this)->getParentOp())) 1890 return emitOpError("must appear at the module level"); 1891 1892 if (auto strAttr = getValueOrNull().dyn_cast_or_null<StringAttr>()) { 1893 auto type = getType().dyn_cast<LLVMArrayType>(); 1894 IntegerType elementType = 1895 type ? type.getElementType().dyn_cast<IntegerType>() : nullptr; 1896 if (!elementType || elementType.getWidth() != 8 || 1897 type.getNumElements() != strAttr.getValue().size()) 1898 return emitOpError( 1899 "requires an i8 array type of the length equal to that of the string " 1900 "attribute"); 1901 } 1902 1903 if (getLinkage() == Linkage::Common) { 1904 if (Attribute value = getValueOrNull()) { 1905 if (!isZeroAttribute(value)) { 1906 return emitOpError() 1907 << "expected zero value for '" 1908 << stringifyLinkage(Linkage::Common) << "' linkage"; 1909 } 1910 } 1911 } 1912 1913 if (getLinkage() == Linkage::Appending) { 1914 if (!getType().isa<LLVMArrayType>()) { 1915 return emitOpError() << "expected array type for '" 1916 << stringifyLinkage(Linkage::Appending) 1917 << "' linkage"; 1918 } 1919 } 1920 1921 Optional<uint64_t> alignAttr = getAlignment(); 1922 if (alignAttr.hasValue()) { 1923 uint64_t value = alignAttr.getValue(); 1924 if (!llvm::isPowerOf2_64(value)) 1925 return emitError() << "alignment attribute is not a power of 2"; 1926 } 1927 1928 return success(); 1929 } 1930 1931 LogicalResult GlobalOp::verifyRegions() { 1932 if (Block *b = getInitializerBlock()) { 1933 ReturnOp ret = cast<ReturnOp>(b->getTerminator()); 1934 if (ret.operand_type_begin() == ret.operand_type_end()) 1935 return emitOpError("initializer region cannot return void"); 1936 if (*ret.operand_type_begin() != getType()) 1937 return emitOpError("initializer region type ") 1938 << *ret.operand_type_begin() << " does not match global type " 1939 << getType(); 1940 1941 for (Operation &op : *b) { 1942 auto iface = dyn_cast<MemoryEffectOpInterface>(op); 1943 if (!iface || !iface.hasNoEffect()) 1944 return op.emitError() 1945 << "ops with side effects not allowed in global initializers"; 1946 } 1947 1948 if (getValueOrNull()) 1949 return emitOpError("cannot have both initializer value and region"); 1950 } 1951 1952 return success(); 1953 } 1954 1955 //===----------------------------------------------------------------------===// 1956 // LLVM::GlobalCtorsOp 1957 //===----------------------------------------------------------------------===// 1958 1959 LogicalResult 1960 GlobalCtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 1961 for (Attribute ctor : getCtors()) { 1962 if (failed(verifySymbolAttrUse(ctor.cast<FlatSymbolRefAttr>(), *this, 1963 symbolTable))) 1964 return failure(); 1965 } 1966 return success(); 1967 } 1968 1969 LogicalResult GlobalCtorsOp::verify() { 1970 if (getCtors().size() != getPriorities().size()) 1971 return emitError( 1972 "mismatch between the number of ctors and the number of priorities"); 1973 return success(); 1974 } 1975 1976 //===----------------------------------------------------------------------===// 1977 // LLVM::GlobalDtorsOp 1978 //===----------------------------------------------------------------------===// 1979 1980 LogicalResult 1981 GlobalDtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 1982 for (Attribute dtor : getDtors()) { 1983 if (failed(verifySymbolAttrUse(dtor.cast<FlatSymbolRefAttr>(), *this, 1984 symbolTable))) 1985 return failure(); 1986 } 1987 return success(); 1988 } 1989 1990 LogicalResult GlobalDtorsOp::verify() { 1991 if (getDtors().size() != getPriorities().size()) 1992 return emitError( 1993 "mismatch between the number of dtors and the number of priorities"); 1994 return success(); 1995 } 1996 1997 //===----------------------------------------------------------------------===// 1998 // Printing/parsing for LLVM::ShuffleVectorOp. 1999 //===----------------------------------------------------------------------===// 2000 // Expects vector to be of wrapped LLVM vector type and position to be of 2001 // wrapped LLVM i32 type. 2002 void LLVM::ShuffleVectorOp::build(OpBuilder &b, OperationState &result, 2003 Value v1, Value v2, ArrayAttr mask, 2004 ArrayRef<NamedAttribute> attrs) { 2005 auto containerType = v1.getType(); 2006 auto vType = LLVM::getVectorType( 2007 LLVM::getVectorElementType(containerType), mask.size(), 2008 containerType.cast<VectorType>().isScalable()); 2009 build(b, result, vType, v1, v2, mask); 2010 result.addAttributes(attrs); 2011 } 2012 2013 void ShuffleVectorOp::print(OpAsmPrinter &p) { 2014 p << ' ' << getV1() << ", " << getV2() << " " << getMask(); 2015 p.printOptionalAttrDict((*this)->getAttrs(), {"mask"}); 2016 p << " : " << getV1().getType() << ", " << getV2().getType(); 2017 } 2018 2019 // <operation> ::= `llvm.shufflevector` ssa-use `, ` ssa-use 2020 // `[` integer-literal (`,` integer-literal)* `]` 2021 // attribute-dict? `:` type 2022 ParseResult ShuffleVectorOp::parse(OpAsmParser &parser, 2023 OperationState &result) { 2024 SMLoc loc; 2025 OpAsmParser::UnresolvedOperand v1, v2; 2026 ArrayAttr maskAttr; 2027 Type typeV1, typeV2; 2028 if (parser.getCurrentLocation(&loc) || parser.parseOperand(v1) || 2029 parser.parseComma() || parser.parseOperand(v2) || 2030 parser.parseAttribute(maskAttr, "mask", result.attributes) || 2031 parser.parseOptionalAttrDict(result.attributes) || 2032 parser.parseColonType(typeV1) || parser.parseComma() || 2033 parser.parseType(typeV2) || 2034 parser.resolveOperand(v1, typeV1, result.operands) || 2035 parser.resolveOperand(v2, typeV2, result.operands)) 2036 return failure(); 2037 if (!LLVM::isCompatibleVectorType(typeV1)) 2038 return parser.emitError( 2039 loc, "expected LLVM IR dialect vector type for operand #1"); 2040 auto vType = 2041 LLVM::getVectorType(LLVM::getVectorElementType(typeV1), maskAttr.size(), 2042 typeV1.cast<VectorType>().isScalable()); 2043 result.addTypes(vType); 2044 return success(); 2045 } 2046 2047 LogicalResult ShuffleVectorOp::verify() { 2048 Type type1 = getV1().getType(); 2049 Type type2 = getV2().getType(); 2050 if (LLVM::getVectorElementType(type1) != LLVM::getVectorElementType(type2)) 2051 return emitOpError("expected matching LLVM IR Dialect element types"); 2052 if (LLVM::isScalableVectorType(type1)) 2053 if (llvm::any_of(getMask(), [](Attribute attr) { 2054 return attr.cast<IntegerAttr>().getInt() != 0; 2055 })) 2056 return emitOpError("expected a splat operation for scalable vectors"); 2057 return success(); 2058 } 2059 2060 //===----------------------------------------------------------------------===// 2061 // Implementations for LLVM::LLVMFuncOp. 2062 //===----------------------------------------------------------------------===// 2063 2064 // Add the entry block to the function. 2065 Block *LLVMFuncOp::addEntryBlock() { 2066 assert(empty() && "function already has an entry block"); 2067 assert(!isVarArg() && "unimplemented: non-external variadic functions"); 2068 2069 auto *entry = new Block; 2070 push_back(entry); 2071 2072 // FIXME: Allow passing in proper locations for the entry arguments. 2073 LLVMFunctionType type = getFunctionType(); 2074 for (unsigned i = 0, e = type.getNumParams(); i < e; ++i) 2075 entry->addArgument(type.getParamType(i), getLoc()); 2076 return entry; 2077 } 2078 2079 void LLVMFuncOp::build(OpBuilder &builder, OperationState &result, 2080 StringRef name, Type type, LLVM::Linkage linkage, 2081 bool dsoLocal, ArrayRef<NamedAttribute> attrs, 2082 ArrayRef<DictionaryAttr> argAttrs) { 2083 result.addRegion(); 2084 result.addAttribute(SymbolTable::getSymbolAttrName(), 2085 builder.getStringAttr(name)); 2086 result.addAttribute(getFunctionTypeAttrName(result.name), 2087 TypeAttr::get(type)); 2088 result.addAttribute(getLinkageAttrName(result.name), 2089 LinkageAttr::get(builder.getContext(), linkage)); 2090 result.attributes.append(attrs.begin(), attrs.end()); 2091 if (dsoLocal) 2092 result.addAttribute("dso_local", builder.getUnitAttr()); 2093 if (argAttrs.empty()) 2094 return; 2095 2096 assert(type.cast<LLVMFunctionType>().getNumParams() == argAttrs.size() && 2097 "expected as many argument attribute lists as arguments"); 2098 function_interface_impl::addArgAndResultAttrs(builder, result, argAttrs, 2099 /*resultAttrs=*/llvm::None); 2100 } 2101 2102 // Builds an LLVM function type from the given lists of input and output types. 2103 // Returns a null type if any of the types provided are non-LLVM types, or if 2104 // there is more than one output type. 2105 static Type 2106 buildLLVMFunctionType(OpAsmParser &parser, SMLoc loc, ArrayRef<Type> inputs, 2107 ArrayRef<Type> outputs, 2108 function_interface_impl::VariadicFlag variadicFlag) { 2109 Builder &b = parser.getBuilder(); 2110 if (outputs.size() > 1) { 2111 parser.emitError(loc, "failed to construct function type: expected zero or " 2112 "one function result"); 2113 return {}; 2114 } 2115 2116 // Convert inputs to LLVM types, exit early on error. 2117 SmallVector<Type, 4> llvmInputs; 2118 for (auto t : inputs) { 2119 if (!isCompatibleType(t)) { 2120 parser.emitError(loc, "failed to construct function type: expected LLVM " 2121 "type for function arguments"); 2122 return {}; 2123 } 2124 llvmInputs.push_back(t); 2125 } 2126 2127 // No output is denoted as "void" in LLVM type system. 2128 Type llvmOutput = 2129 outputs.empty() ? LLVMVoidType::get(b.getContext()) : outputs.front(); 2130 if (!isCompatibleType(llvmOutput)) { 2131 parser.emitError(loc, "failed to construct function type: expected LLVM " 2132 "type for function results") 2133 << llvmOutput; 2134 return {}; 2135 } 2136 return LLVMFunctionType::get(llvmOutput, llvmInputs, 2137 variadicFlag.isVariadic()); 2138 } 2139 2140 // Parses an LLVM function. 2141 // 2142 // operation ::= `llvm.func` linkage? function-signature function-attributes? 2143 // function-body 2144 // 2145 ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) { 2146 // Default to external linkage if no keyword is provided. 2147 result.addAttribute( 2148 getLinkageAttrName(result.name), 2149 LinkageAttr::get(parser.getContext(), 2150 parseOptionalLLVMKeyword<Linkage>( 2151 parser, result, LLVM::Linkage::External))); 2152 2153 StringAttr nameAttr; 2154 SmallVector<OpAsmParser::UnresolvedOperand> entryArgs; 2155 SmallVector<NamedAttrList> argAttrs; 2156 SmallVector<NamedAttrList> resultAttrs; 2157 SmallVector<Type> argTypes; 2158 SmallVector<Type> resultTypes; 2159 bool isVariadic; 2160 2161 auto signatureLocation = parser.getCurrentLocation(); 2162 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), 2163 result.attributes) || 2164 function_interface_impl::parseFunctionSignature( 2165 parser, /*allowVariadic=*/true, entryArgs, argTypes, argAttrs, 2166 isVariadic, resultTypes, resultAttrs)) 2167 return failure(); 2168 2169 auto type = 2170 buildLLVMFunctionType(parser, signatureLocation, argTypes, resultTypes, 2171 function_interface_impl::VariadicFlag(isVariadic)); 2172 if (!type) 2173 return failure(); 2174 result.addAttribute(FunctionOpInterface::getTypeAttrName(), 2175 TypeAttr::get(type)); 2176 2177 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) 2178 return failure(); 2179 function_interface_impl::addArgAndResultAttrs(parser.getBuilder(), result, 2180 argAttrs, resultAttrs); 2181 2182 auto *body = result.addRegion(); 2183 OptionalParseResult parseResult = parser.parseOptionalRegion( 2184 *body, entryArgs, entryArgs.empty() ? ArrayRef<Type>() : argTypes); 2185 return failure(parseResult.hasValue() && failed(*parseResult)); 2186 } 2187 2188 // Print the LLVMFuncOp. Collects argument and result types and passes them to 2189 // helper functions. Drops "void" result since it cannot be parsed back. Skips 2190 // the external linkage since it is the default value. 2191 void LLVMFuncOp::print(OpAsmPrinter &p) { 2192 p << ' '; 2193 if (getLinkage() != LLVM::Linkage::External) 2194 p << stringifyLinkage(getLinkage()) << ' '; 2195 p.printSymbolName(getName()); 2196 2197 LLVMFunctionType fnType = getFunctionType(); 2198 SmallVector<Type, 8> argTypes; 2199 SmallVector<Type, 1> resTypes; 2200 argTypes.reserve(fnType.getNumParams()); 2201 for (unsigned i = 0, e = fnType.getNumParams(); i < e; ++i) 2202 argTypes.push_back(fnType.getParamType(i)); 2203 2204 Type returnType = fnType.getReturnType(); 2205 if (!returnType.isa<LLVMVoidType>()) 2206 resTypes.push_back(returnType); 2207 2208 function_interface_impl::printFunctionSignature(p, *this, argTypes, 2209 isVarArg(), resTypes); 2210 function_interface_impl::printFunctionAttributes( 2211 p, *this, argTypes.size(), resTypes.size(), {getLinkageAttrName()}); 2212 2213 // Print the body if this is not an external function. 2214 Region &body = getBody(); 2215 if (!body.empty()) { 2216 p << ' '; 2217 p.printRegion(body, /*printEntryBlockArgs=*/false, 2218 /*printBlockTerminators=*/true); 2219 } 2220 } 2221 2222 // Verifies LLVM- and implementation-specific properties of the LLVM func Op: 2223 // - functions don't have 'common' linkage 2224 // - external functions have 'external' or 'extern_weak' linkage; 2225 // - vararg is (currently) only supported for external functions; 2226 LogicalResult LLVMFuncOp::verify() { 2227 if (getLinkage() == LLVM::Linkage::Common) 2228 return emitOpError() << "functions cannot have '" 2229 << stringifyLinkage(LLVM::Linkage::Common) 2230 << "' linkage"; 2231 2232 // Check to see if this function has a void return with a result attribute to 2233 // it. It isn't clear what semantics we would assign to that. 2234 if (getFunctionType().getReturnType().isa<LLVMVoidType>() && 2235 !getResultAttrs(0).empty()) { 2236 return emitOpError() 2237 << "cannot attach result attributes to functions with a void return"; 2238 } 2239 2240 if (isExternal()) { 2241 if (getLinkage() != LLVM::Linkage::External && 2242 getLinkage() != LLVM::Linkage::ExternWeak) 2243 return emitOpError() << "external functions must have '" 2244 << stringifyLinkage(LLVM::Linkage::External) 2245 << "' or '" 2246 << stringifyLinkage(LLVM::Linkage::ExternWeak) 2247 << "' linkage"; 2248 return success(); 2249 } 2250 2251 if (isVarArg()) 2252 return emitOpError("only external functions can be variadic"); 2253 2254 return success(); 2255 } 2256 2257 /// Verifies LLVM- and implementation-specific properties of the LLVM func Op: 2258 /// - entry block arguments are of LLVM types. 2259 LogicalResult LLVMFuncOp::verifyRegions() { 2260 if (isExternal()) 2261 return success(); 2262 2263 unsigned numArguments = getFunctionType().getNumParams(); 2264 Block &entryBlock = front(); 2265 for (unsigned i = 0; i < numArguments; ++i) { 2266 Type argType = entryBlock.getArgument(i).getType(); 2267 if (!isCompatibleType(argType)) 2268 return emitOpError("entry block argument #") 2269 << i << " is not of LLVM type"; 2270 } 2271 2272 return success(); 2273 } 2274 2275 //===----------------------------------------------------------------------===// 2276 // Verification for LLVM::ConstantOp. 2277 //===----------------------------------------------------------------------===// 2278 2279 LogicalResult LLVM::ConstantOp::verify() { 2280 if (StringAttr sAttr = getValue().dyn_cast<StringAttr>()) { 2281 auto arrayType = getType().dyn_cast<LLVMArrayType>(); 2282 if (!arrayType || arrayType.getNumElements() != sAttr.getValue().size() || 2283 !arrayType.getElementType().isInteger(8)) { 2284 return emitOpError() << "expected array type of " 2285 << sAttr.getValue().size() 2286 << " i8 elements for the string constant"; 2287 } 2288 return success(); 2289 } 2290 if (auto structType = getType().dyn_cast<LLVMStructType>()) { 2291 if (structType.getBody().size() != 2 || 2292 structType.getBody()[0] != structType.getBody()[1]) { 2293 return emitError() << "expected struct type with two elements of the " 2294 "same type, the type of a complex constant"; 2295 } 2296 2297 auto arrayAttr = getValue().dyn_cast<ArrayAttr>(); 2298 if (!arrayAttr || arrayAttr.size() != 2 || 2299 arrayAttr[0].getType() != arrayAttr[1].getType()) { 2300 return emitOpError() << "expected array attribute with two elements, " 2301 "representing a complex constant"; 2302 } 2303 2304 Type elementType = structType.getBody()[0]; 2305 if (!elementType 2306 .isa<IntegerType, Float16Type, Float32Type, Float64Type>()) { 2307 return emitError() 2308 << "expected struct element types to be floating point type or " 2309 "integer type"; 2310 } 2311 return success(); 2312 } 2313 if (!getValue().isa<IntegerAttr, ArrayAttr, FloatAttr, ElementsAttr>()) 2314 return emitOpError() 2315 << "only supports integer, float, string or elements attributes"; 2316 return success(); 2317 } 2318 2319 // Constant op constant-folds to its value. 2320 OpFoldResult LLVM::ConstantOp::fold(ArrayRef<Attribute>) { return getValue(); } 2321 2322 //===----------------------------------------------------------------------===// 2323 // Utility functions for parsing atomic ops 2324 //===----------------------------------------------------------------------===// 2325 2326 // Helper function to parse a keyword into the specified attribute named by 2327 // `attrName`. The keyword must match one of the string values defined by the 2328 // AtomicBinOp enum. The resulting I64 attribute is added to the `result` 2329 // state. 2330 static ParseResult parseAtomicBinOp(OpAsmParser &parser, OperationState &result, 2331 StringRef attrName) { 2332 SMLoc loc; 2333 StringRef keyword; 2334 if (parser.getCurrentLocation(&loc) || parser.parseKeyword(&keyword)) 2335 return failure(); 2336 2337 // Replace the keyword `keyword` with an integer attribute. 2338 auto kind = symbolizeAtomicBinOp(keyword); 2339 if (!kind) { 2340 return parser.emitError(loc) 2341 << "'" << keyword << "' is an incorrect value of the '" << attrName 2342 << "' attribute"; 2343 } 2344 2345 auto value = static_cast<int64_t>(kind.getValue()); 2346 auto attr = parser.getBuilder().getI64IntegerAttr(value); 2347 result.addAttribute(attrName, attr); 2348 2349 return success(); 2350 } 2351 2352 // Helper function to parse a keyword into the specified attribute named by 2353 // `attrName`. The keyword must match one of the string values defined by the 2354 // AtomicOrdering enum. The resulting I64 attribute is added to the `result` 2355 // state. 2356 static ParseResult parseAtomicOrdering(OpAsmParser &parser, 2357 OperationState &result, 2358 StringRef attrName) { 2359 SMLoc loc; 2360 StringRef ordering; 2361 if (parser.getCurrentLocation(&loc) || parser.parseKeyword(&ordering)) 2362 return failure(); 2363 2364 // Replace the keyword `ordering` with an integer attribute. 2365 auto kind = symbolizeAtomicOrdering(ordering); 2366 if (!kind) { 2367 return parser.emitError(loc) 2368 << "'" << ordering << "' is an incorrect value of the '" << attrName 2369 << "' attribute"; 2370 } 2371 2372 auto value = static_cast<int64_t>(kind.getValue()); 2373 auto attr = parser.getBuilder().getI64IntegerAttr(value); 2374 result.addAttribute(attrName, attr); 2375 2376 return success(); 2377 } 2378 2379 //===----------------------------------------------------------------------===// 2380 // Printer, parser and verifier for LLVM::AtomicRMWOp. 2381 //===----------------------------------------------------------------------===// 2382 2383 void AtomicRMWOp::print(OpAsmPrinter &p) { 2384 p << ' ' << stringifyAtomicBinOp(getBinOp()) << ' ' << getPtr() << ", " 2385 << getVal() << ' ' << stringifyAtomicOrdering(getOrdering()) << ' '; 2386 p.printOptionalAttrDict((*this)->getAttrs(), {"bin_op", "ordering"}); 2387 p << " : " << getRes().getType(); 2388 } 2389 2390 // <operation> ::= `llvm.atomicrmw` keyword ssa-use `,` ssa-use keyword 2391 // attribute-dict? `:` type 2392 ParseResult AtomicRMWOp::parse(OpAsmParser &parser, OperationState &result) { 2393 Type type; 2394 OpAsmParser::UnresolvedOperand ptr, val; 2395 if (parseAtomicBinOp(parser, result, "bin_op") || parser.parseOperand(ptr) || 2396 parser.parseComma() || parser.parseOperand(val) || 2397 parseAtomicOrdering(parser, result, "ordering") || 2398 parser.parseOptionalAttrDict(result.attributes) || 2399 parser.parseColonType(type) || 2400 parser.resolveOperand(ptr, LLVM::LLVMPointerType::get(type), 2401 result.operands) || 2402 parser.resolveOperand(val, type, result.operands)) 2403 return failure(); 2404 2405 result.addTypes(type); 2406 return success(); 2407 } 2408 2409 LogicalResult AtomicRMWOp::verify() { 2410 auto ptrType = getPtr().getType().cast<LLVM::LLVMPointerType>(); 2411 auto valType = getVal().getType(); 2412 if (valType != ptrType.getElementType()) 2413 return emitOpError("expected LLVM IR element type for operand #0 to " 2414 "match type for operand #1"); 2415 auto resType = getRes().getType(); 2416 if (resType != valType) 2417 return emitOpError( 2418 "expected LLVM IR result type to match type for operand #1"); 2419 if (getBinOp() == AtomicBinOp::fadd || getBinOp() == AtomicBinOp::fsub) { 2420 if (!mlir::LLVM::isCompatibleFloatingPointType(valType)) 2421 return emitOpError("expected LLVM IR floating point type"); 2422 } else if (getBinOp() == AtomicBinOp::xchg) { 2423 auto intType = valType.dyn_cast<IntegerType>(); 2424 unsigned intBitWidth = intType ? intType.getWidth() : 0; 2425 if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 && 2426 intBitWidth != 64 && !valType.isa<BFloat16Type>() && 2427 !valType.isa<Float16Type>() && !valType.isa<Float32Type>() && 2428 !valType.isa<Float64Type>()) 2429 return emitOpError("unexpected LLVM IR type for 'xchg' bin_op"); 2430 } else { 2431 auto intType = valType.dyn_cast<IntegerType>(); 2432 unsigned intBitWidth = intType ? intType.getWidth() : 0; 2433 if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 && 2434 intBitWidth != 64) 2435 return emitOpError("expected LLVM IR integer type"); 2436 } 2437 2438 if (static_cast<unsigned>(getOrdering()) < 2439 static_cast<unsigned>(AtomicOrdering::monotonic)) 2440 return emitOpError() << "expected at least '" 2441 << stringifyAtomicOrdering(AtomicOrdering::monotonic) 2442 << "' ordering"; 2443 2444 return success(); 2445 } 2446 2447 //===----------------------------------------------------------------------===// 2448 // Printer, parser and verifier for LLVM::AtomicCmpXchgOp. 2449 //===----------------------------------------------------------------------===// 2450 2451 void AtomicCmpXchgOp::print(OpAsmPrinter &p) { 2452 p << ' ' << getPtr() << ", " << getCmp() << ", " << getVal() << ' ' 2453 << stringifyAtomicOrdering(getSuccessOrdering()) << ' ' 2454 << stringifyAtomicOrdering(getFailureOrdering()); 2455 p.printOptionalAttrDict((*this)->getAttrs(), 2456 {"success_ordering", "failure_ordering"}); 2457 p << " : " << getVal().getType(); 2458 } 2459 2460 // <operation> ::= `llvm.cmpxchg` ssa-use `,` ssa-use `,` ssa-use 2461 // keyword keyword attribute-dict? `:` type 2462 ParseResult AtomicCmpXchgOp::parse(OpAsmParser &parser, 2463 OperationState &result) { 2464 auto &builder = parser.getBuilder(); 2465 Type type; 2466 OpAsmParser::UnresolvedOperand ptr, cmp, val; 2467 if (parser.parseOperand(ptr) || parser.parseComma() || 2468 parser.parseOperand(cmp) || parser.parseComma() || 2469 parser.parseOperand(val) || 2470 parseAtomicOrdering(parser, result, "success_ordering") || 2471 parseAtomicOrdering(parser, result, "failure_ordering") || 2472 parser.parseOptionalAttrDict(result.attributes) || 2473 parser.parseColonType(type) || 2474 parser.resolveOperand(ptr, LLVM::LLVMPointerType::get(type), 2475 result.operands) || 2476 parser.resolveOperand(cmp, type, result.operands) || 2477 parser.resolveOperand(val, type, result.operands)) 2478 return failure(); 2479 2480 auto boolType = IntegerType::get(builder.getContext(), 1); 2481 auto resultType = 2482 LLVMStructType::getLiteral(builder.getContext(), {type, boolType}); 2483 result.addTypes(resultType); 2484 2485 return success(); 2486 } 2487 2488 LogicalResult AtomicCmpXchgOp::verify() { 2489 auto ptrType = getPtr().getType().cast<LLVM::LLVMPointerType>(); 2490 if (!ptrType) 2491 return emitOpError("expected LLVM IR pointer type for operand #0"); 2492 auto cmpType = getCmp().getType(); 2493 auto valType = getVal().getType(); 2494 if (cmpType != ptrType.getElementType() || cmpType != valType) 2495 return emitOpError("expected LLVM IR element type for operand #0 to " 2496 "match type for all other operands"); 2497 auto intType = valType.dyn_cast<IntegerType>(); 2498 unsigned intBitWidth = intType ? intType.getWidth() : 0; 2499 if (!valType.isa<LLVMPointerType>() && intBitWidth != 8 && 2500 intBitWidth != 16 && intBitWidth != 32 && intBitWidth != 64 && 2501 !valType.isa<BFloat16Type>() && !valType.isa<Float16Type>() && 2502 !valType.isa<Float32Type>() && !valType.isa<Float64Type>()) 2503 return emitOpError("unexpected LLVM IR type"); 2504 if (getSuccessOrdering() < AtomicOrdering::monotonic || 2505 getFailureOrdering() < AtomicOrdering::monotonic) 2506 return emitOpError("ordering must be at least 'monotonic'"); 2507 if (getFailureOrdering() == AtomicOrdering::release || 2508 getFailureOrdering() == AtomicOrdering::acq_rel) 2509 return emitOpError("failure ordering cannot be 'release' or 'acq_rel'"); 2510 return success(); 2511 } 2512 2513 //===----------------------------------------------------------------------===// 2514 // Printer, parser and verifier for LLVM::FenceOp. 2515 //===----------------------------------------------------------------------===// 2516 2517 // <operation> ::= `llvm.fence` (`syncscope(`strAttr`)`)? keyword 2518 // attribute-dict? 2519 ParseResult FenceOp::parse(OpAsmParser &parser, OperationState &result) { 2520 StringAttr sScope; 2521 StringRef syncscopeKeyword = "syncscope"; 2522 if (!failed(parser.parseOptionalKeyword(syncscopeKeyword))) { 2523 if (parser.parseLParen() || 2524 parser.parseAttribute(sScope, syncscopeKeyword, result.attributes) || 2525 parser.parseRParen()) 2526 return failure(); 2527 } else { 2528 result.addAttribute(syncscopeKeyword, 2529 parser.getBuilder().getStringAttr("")); 2530 } 2531 if (parseAtomicOrdering(parser, result, "ordering") || 2532 parser.parseOptionalAttrDict(result.attributes)) 2533 return failure(); 2534 return success(); 2535 } 2536 2537 void FenceOp::print(OpAsmPrinter &p) { 2538 StringRef syncscopeKeyword = "syncscope"; 2539 p << ' '; 2540 if (!(*this)->getAttr(syncscopeKeyword).cast<StringAttr>().getValue().empty()) 2541 p << "syncscope(" << (*this)->getAttr(syncscopeKeyword) << ") "; 2542 p << stringifyAtomicOrdering(getOrdering()); 2543 } 2544 2545 LogicalResult FenceOp::verify() { 2546 if (getOrdering() == AtomicOrdering::not_atomic || 2547 getOrdering() == AtomicOrdering::unordered || 2548 getOrdering() == AtomicOrdering::monotonic) 2549 return emitOpError("can be given only acquire, release, acq_rel, " 2550 "and seq_cst orderings"); 2551 return success(); 2552 } 2553 2554 //===----------------------------------------------------------------------===// 2555 // Folder for LLVM::BitcastOp 2556 //===----------------------------------------------------------------------===// 2557 2558 OpFoldResult LLVM::BitcastOp::fold(ArrayRef<Attribute> operands) { 2559 // bitcast(x : T0, T0) -> x 2560 if (getArg().getType() == getType()) 2561 return getArg(); 2562 // bitcast(bitcast(x : T0, T1), T0) -> x 2563 if (auto prev = getArg().getDefiningOp<BitcastOp>()) 2564 if (prev.getArg().getType() == getType()) 2565 return prev.getArg(); 2566 return {}; 2567 } 2568 2569 //===----------------------------------------------------------------------===// 2570 // Folder for LLVM::AddrSpaceCastOp 2571 //===----------------------------------------------------------------------===// 2572 2573 OpFoldResult LLVM::AddrSpaceCastOp::fold(ArrayRef<Attribute> operands) { 2574 // addrcast(x : T0, T0) -> x 2575 if (getArg().getType() == getType()) 2576 return getArg(); 2577 // addrcast(addrcast(x : T0, T1), T0) -> x 2578 if (auto prev = getArg().getDefiningOp<AddrSpaceCastOp>()) 2579 if (prev.getArg().getType() == getType()) 2580 return prev.getArg(); 2581 return {}; 2582 } 2583 2584 //===----------------------------------------------------------------------===// 2585 // Folder for LLVM::GEPOp 2586 //===----------------------------------------------------------------------===// 2587 2588 OpFoldResult LLVM::GEPOp::fold(ArrayRef<Attribute> operands) { 2589 // gep %x:T, 0 -> %x 2590 if (getBase().getType() == getType() && getIndices().size() == 1 && 2591 matchPattern(getIndices()[0], m_Zero())) 2592 return getBase(); 2593 return {}; 2594 } 2595 2596 //===----------------------------------------------------------------------===// 2597 // LLVMDialect initialization, type parsing, and registration. 2598 //===----------------------------------------------------------------------===// 2599 2600 void LLVMDialect::initialize() { 2601 addAttributes<FMFAttr, LinkageAttr, LoopOptionsAttr>(); 2602 2603 // clang-format off 2604 addTypes<LLVMVoidType, 2605 LLVMPPCFP128Type, 2606 LLVMX86MMXType, 2607 LLVMTokenType, 2608 LLVMLabelType, 2609 LLVMMetadataType, 2610 LLVMFunctionType, 2611 LLVMPointerType, 2612 LLVMFixedVectorType, 2613 LLVMScalableVectorType, 2614 LLVMArrayType, 2615 LLVMStructType>(); 2616 // clang-format on 2617 addOperations< 2618 #define GET_OP_LIST 2619 #include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc" 2620 , 2621 #define GET_OP_LIST 2622 #include "mlir/Dialect/LLVMIR/LLVMIntrinsicOps.cpp.inc" 2623 >(); 2624 2625 // Support unknown operations because not all LLVM operations are registered. 2626 allowUnknownOperations(); 2627 } 2628 2629 #define GET_OP_CLASSES 2630 #include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc" 2631 2632 /// Parse a type registered to this dialect. 2633 Type LLVMDialect::parseType(DialectAsmParser &parser) const { 2634 return detail::parseType(parser); 2635 } 2636 2637 /// Print a type registered to this dialect. 2638 void LLVMDialect::printType(Type type, DialectAsmPrinter &os) const { 2639 return detail::printType(type, os); 2640 } 2641 2642 LogicalResult LLVMDialect::verifyDataLayoutString( 2643 StringRef descr, llvm::function_ref<void(const Twine &)> reportError) { 2644 llvm::Expected<llvm::DataLayout> maybeDataLayout = 2645 llvm::DataLayout::parse(descr); 2646 if (maybeDataLayout) 2647 return success(); 2648 2649 std::string message; 2650 llvm::raw_string_ostream messageStream(message); 2651 llvm::logAllUnhandledErrors(maybeDataLayout.takeError(), messageStream); 2652 reportError("invalid data layout descriptor: " + messageStream.str()); 2653 return failure(); 2654 } 2655 2656 /// Verify LLVM dialect attributes. 2657 LogicalResult LLVMDialect::verifyOperationAttribute(Operation *op, 2658 NamedAttribute attr) { 2659 // If the `llvm.loop` attribute is present, enforce the following structure, 2660 // which the module translation can assume. 2661 if (attr.getName() == LLVMDialect::getLoopAttrName()) { 2662 auto loopAttr = attr.getValue().dyn_cast<DictionaryAttr>(); 2663 if (!loopAttr) 2664 return op->emitOpError() << "expected '" << LLVMDialect::getLoopAttrName() 2665 << "' to be a dictionary attribute"; 2666 Optional<NamedAttribute> parallelAccessGroup = 2667 loopAttr.getNamed(LLVMDialect::getParallelAccessAttrName()); 2668 if (parallelAccessGroup.hasValue()) { 2669 auto accessGroups = parallelAccessGroup->getValue().dyn_cast<ArrayAttr>(); 2670 if (!accessGroups) 2671 return op->emitOpError() 2672 << "expected '" << LLVMDialect::getParallelAccessAttrName() 2673 << "' to be an array attribute"; 2674 for (Attribute attr : accessGroups) { 2675 auto accessGroupRef = attr.dyn_cast<SymbolRefAttr>(); 2676 if (!accessGroupRef) 2677 return op->emitOpError() 2678 << "expected '" << attr << "' to be a symbol reference"; 2679 StringAttr metadataName = accessGroupRef.getRootReference(); 2680 auto metadataOp = 2681 SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>( 2682 op->getParentOp(), metadataName); 2683 if (!metadataOp) 2684 return op->emitOpError() 2685 << "expected '" << attr << "' to reference a metadata op"; 2686 StringAttr accessGroupName = accessGroupRef.getLeafReference(); 2687 Operation *accessGroupOp = 2688 SymbolTable::lookupNearestSymbolFrom(metadataOp, accessGroupName); 2689 if (!accessGroupOp) 2690 return op->emitOpError() 2691 << "expected '" << attr << "' to reference an access_group op"; 2692 } 2693 } 2694 2695 Optional<NamedAttribute> loopOptions = 2696 loopAttr.getNamed(LLVMDialect::getLoopOptionsAttrName()); 2697 if (loopOptions.hasValue() && 2698 !loopOptions->getValue().isa<LoopOptionsAttr>()) 2699 return op->emitOpError() 2700 << "expected '" << LLVMDialect::getLoopOptionsAttrName() 2701 << "' to be a `loopopts` attribute"; 2702 } 2703 2704 if (attr.getName() == LLVMDialect::getStructAttrsAttrName()) { 2705 return op->emitOpError() 2706 << "'" << LLVM::LLVMDialect::getStructAttrsAttrName() 2707 << "' is permitted only in argument or result attributes"; 2708 } 2709 2710 // If the data layout attribute is present, it must use the LLVM data layout 2711 // syntax. Try parsing it and report errors in case of failure. Users of this 2712 // attribute may assume it is well-formed and can pass it to the (asserting) 2713 // llvm::DataLayout constructor. 2714 if (attr.getName() != LLVM::LLVMDialect::getDataLayoutAttrName()) 2715 return success(); 2716 if (auto stringAttr = attr.getValue().dyn_cast<StringAttr>()) 2717 return verifyDataLayoutString( 2718 stringAttr.getValue(), 2719 [op](const Twine &message) { op->emitOpError() << message.str(); }); 2720 2721 return op->emitOpError() << "expected '" 2722 << LLVM::LLVMDialect::getDataLayoutAttrName() 2723 << "' to be a string attributes"; 2724 } 2725 2726 LogicalResult LLVMDialect::verifyStructAttr(Operation *op, Attribute attr, 2727 Type annotatedType) { 2728 auto structType = annotatedType.dyn_cast<LLVMStructType>(); 2729 if (!structType) { 2730 const auto emitIncorrectAnnotatedType = [&op]() { 2731 return op->emitError() 2732 << "expected '" << LLVMDialect::getStructAttrsAttrName() 2733 << "' to annotate '!llvm.struct' or '!llvm.ptr<struct<...>>'"; 2734 }; 2735 const auto ptrType = annotatedType.dyn_cast<LLVMPointerType>(); 2736 if (!ptrType) 2737 return emitIncorrectAnnotatedType(); 2738 structType = ptrType.getElementType().dyn_cast<LLVMStructType>(); 2739 if (!structType) 2740 return emitIncorrectAnnotatedType(); 2741 } 2742 2743 const auto arrAttrs = attr.dyn_cast<ArrayAttr>(); 2744 if (!arrAttrs) 2745 return op->emitError() << "expected '" 2746 << LLVMDialect::getStructAttrsAttrName() 2747 << "' to be an array attribute"; 2748 2749 if (structType.getBody().size() != arrAttrs.size()) 2750 return op->emitError() 2751 << "size of '" << LLVMDialect::getStructAttrsAttrName() 2752 << "' must match the size of the annotated '!llvm.struct'"; 2753 return success(); 2754 } 2755 2756 static LogicalResult verifyFuncOpInterfaceStructAttr( 2757 Operation *op, Attribute attr, 2758 const std::function<Type(FunctionOpInterface)> &getAnnotatedType) { 2759 if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) 2760 return LLVMDialect::verifyStructAttr(op, attr, getAnnotatedType(funcOp)); 2761 return op->emitError() << "expected '" 2762 << LLVMDialect::getStructAttrsAttrName() 2763 << "' to be used on function-like operations"; 2764 } 2765 2766 /// Verify LLVMIR function argument attributes. 2767 LogicalResult LLVMDialect::verifyRegionArgAttribute(Operation *op, 2768 unsigned regionIdx, 2769 unsigned argIdx, 2770 NamedAttribute argAttr) { 2771 // Check that llvm.noalias is a unit attribute. 2772 if (argAttr.getName() == LLVMDialect::getNoAliasAttrName() && 2773 !argAttr.getValue().isa<UnitAttr>()) 2774 return op->emitError() 2775 << "expected llvm.noalias argument attribute to be a unit attribute"; 2776 // Check that llvm.align is an integer attribute. 2777 if (argAttr.getName() == LLVMDialect::getAlignAttrName() && 2778 !argAttr.getValue().isa<IntegerAttr>()) 2779 return op->emitError() 2780 << "llvm.align argument attribute of non integer type"; 2781 if (argAttr.getName() == LLVMDialect::getStructAttrsAttrName()) { 2782 return verifyFuncOpInterfaceStructAttr( 2783 op, argAttr.getValue(), [argIdx](FunctionOpInterface funcOp) { 2784 return funcOp.getArgumentTypes()[argIdx]; 2785 }); 2786 } 2787 return success(); 2788 } 2789 2790 LogicalResult LLVMDialect::verifyRegionResultAttribute(Operation *op, 2791 unsigned regionIdx, 2792 unsigned resIdx, 2793 NamedAttribute resAttr) { 2794 if (resAttr.getName() == LLVMDialect::getStructAttrsAttrName()) { 2795 return verifyFuncOpInterfaceStructAttr( 2796 op, resAttr.getValue(), [resIdx](FunctionOpInterface funcOp) { 2797 return funcOp.getResultTypes()[resIdx]; 2798 }); 2799 } 2800 return success(); 2801 } 2802 2803 //===----------------------------------------------------------------------===// 2804 // Utility functions. 2805 //===----------------------------------------------------------------------===// 2806 2807 Value mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder, 2808 StringRef name, StringRef value, 2809 LLVM::Linkage linkage) { 2810 assert(builder.getInsertionBlock() && 2811 builder.getInsertionBlock()->getParentOp() && 2812 "expected builder to point to a block constrained in an op"); 2813 auto module = 2814 builder.getInsertionBlock()->getParentOp()->getParentOfType<ModuleOp>(); 2815 assert(module && "builder points to an op outside of a module"); 2816 2817 // Create the global at the entry of the module. 2818 OpBuilder moduleBuilder(module.getBodyRegion(), builder.getListener()); 2819 MLIRContext *ctx = builder.getContext(); 2820 auto type = LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), value.size()); 2821 auto global = moduleBuilder.create<LLVM::GlobalOp>( 2822 loc, type, /*isConstant=*/true, linkage, name, 2823 builder.getStringAttr(value), /*alignment=*/0); 2824 2825 // Get the pointer to the first character in the global string. 2826 Value globalPtr = builder.create<LLVM::AddressOfOp>(loc, global); 2827 Value cst0 = builder.create<LLVM::ConstantOp>( 2828 loc, IntegerType::get(ctx, 64), 2829 builder.getIntegerAttr(builder.getIndexType(), 0)); 2830 return builder.create<LLVM::GEPOp>( 2831 loc, LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)), globalPtr, 2832 ValueRange{cst0, cst0}); 2833 } 2834 2835 bool mlir::LLVM::satisfiesLLVMModule(Operation *op) { 2836 return op->hasTrait<OpTrait::SymbolTable>() && 2837 op->hasTrait<OpTrait::IsIsolatedFromAbove>(); 2838 } 2839 2840 static constexpr const FastmathFlags fastmathFlagsList[] = { 2841 // clang-format off 2842 FastmathFlags::nnan, 2843 FastmathFlags::ninf, 2844 FastmathFlags::nsz, 2845 FastmathFlags::arcp, 2846 FastmathFlags::contract, 2847 FastmathFlags::afn, 2848 FastmathFlags::reassoc, 2849 FastmathFlags::fast, 2850 // clang-format on 2851 }; 2852 2853 void FMFAttr::print(AsmPrinter &printer) const { 2854 printer << "<"; 2855 auto flags = llvm::make_filter_range(fastmathFlagsList, [&](auto flag) { 2856 return bitEnumContains(this->getFlags(), flag); 2857 }); 2858 llvm::interleaveComma(flags, printer, 2859 [&](auto flag) { printer << stringifyEnum(flag); }); 2860 printer << ">"; 2861 } 2862 2863 Attribute FMFAttr::parse(AsmParser &parser, Type type) { 2864 if (failed(parser.parseLess())) 2865 return {}; 2866 2867 FastmathFlags flags = {}; 2868 if (failed(parser.parseOptionalGreater())) { 2869 do { 2870 StringRef elemName; 2871 if (failed(parser.parseKeyword(&elemName))) 2872 return {}; 2873 2874 auto elem = symbolizeFastmathFlags(elemName); 2875 if (!elem) { 2876 parser.emitError(parser.getNameLoc(), "Unknown fastmath flag: ") 2877 << elemName; 2878 return {}; 2879 } 2880 2881 flags = flags | *elem; 2882 } while (succeeded(parser.parseOptionalComma())); 2883 2884 if (failed(parser.parseGreater())) 2885 return {}; 2886 } 2887 2888 return FMFAttr::get(parser.getContext(), flags); 2889 } 2890 2891 void LinkageAttr::print(AsmPrinter &printer) const { 2892 printer << "<"; 2893 if (static_cast<uint64_t>(getLinkage()) <= getMaxEnumValForLinkage()) 2894 printer << stringifyEnum(getLinkage()); 2895 else 2896 printer << static_cast<uint64_t>(getLinkage()); 2897 printer << ">"; 2898 } 2899 2900 Attribute LinkageAttr::parse(AsmParser &parser, Type type) { 2901 StringRef elemName; 2902 if (parser.parseLess() || parser.parseKeyword(&elemName) || 2903 parser.parseGreater()) 2904 return {}; 2905 auto elem = linkage::symbolizeLinkage(elemName); 2906 if (!elem) { 2907 parser.emitError(parser.getNameLoc(), "Unknown linkage: ") << elemName; 2908 return {}; 2909 } 2910 Linkage linkage = *elem; 2911 return LinkageAttr::get(parser.getContext(), linkage); 2912 } 2913 2914 LoopOptionsAttrBuilder::LoopOptionsAttrBuilder(LoopOptionsAttr attr) 2915 : options(attr.getOptions().begin(), attr.getOptions().end()) {} 2916 2917 template <typename T> 2918 LoopOptionsAttrBuilder &LoopOptionsAttrBuilder::setOption(LoopOptionCase tag, 2919 Optional<T> value) { 2920 auto option = llvm::find_if( 2921 options, [tag](auto option) { return option.first == tag; }); 2922 if (option != options.end()) { 2923 if (value.hasValue()) 2924 option->second = *value; 2925 else 2926 options.erase(option); 2927 } else { 2928 options.push_back(LoopOptionsAttr::OptionValuePair(tag, *value)); 2929 } 2930 return *this; 2931 } 2932 2933 LoopOptionsAttrBuilder & 2934 LoopOptionsAttrBuilder::setDisableLICM(Optional<bool> value) { 2935 return setOption(LoopOptionCase::disable_licm, value); 2936 } 2937 2938 /// Set the `interleave_count` option to the provided value. If no value 2939 /// is provided the option is deleted. 2940 LoopOptionsAttrBuilder & 2941 LoopOptionsAttrBuilder::setInterleaveCount(Optional<uint64_t> count) { 2942 return setOption(LoopOptionCase::interleave_count, count); 2943 } 2944 2945 /// Set the `disable_unroll` option to the provided value. If no value 2946 /// is provided the option is deleted. 2947 LoopOptionsAttrBuilder & 2948 LoopOptionsAttrBuilder::setDisableUnroll(Optional<bool> value) { 2949 return setOption(LoopOptionCase::disable_unroll, value); 2950 } 2951 2952 /// Set the `disable_pipeline` option to the provided value. If no value 2953 /// is provided the option is deleted. 2954 LoopOptionsAttrBuilder & 2955 LoopOptionsAttrBuilder::setDisablePipeline(Optional<bool> value) { 2956 return setOption(LoopOptionCase::disable_pipeline, value); 2957 } 2958 2959 /// Set the `pipeline_initiation_interval` option to the provided value. 2960 /// If no value is provided the option is deleted. 2961 LoopOptionsAttrBuilder &LoopOptionsAttrBuilder::setPipelineInitiationInterval( 2962 Optional<uint64_t> count) { 2963 return setOption(LoopOptionCase::pipeline_initiation_interval, count); 2964 } 2965 2966 template <typename T> 2967 static Optional<T> 2968 getOption(ArrayRef<std::pair<LoopOptionCase, int64_t>> options, 2969 LoopOptionCase option) { 2970 auto it = 2971 lower_bound(options, option, [](auto optionPair, LoopOptionCase option) { 2972 return optionPair.first < option; 2973 }); 2974 if (it == options.end()) 2975 return {}; 2976 return static_cast<T>(it->second); 2977 } 2978 2979 Optional<bool> LoopOptionsAttr::disableUnroll() { 2980 return getOption<bool>(getOptions(), LoopOptionCase::disable_unroll); 2981 } 2982 2983 Optional<bool> LoopOptionsAttr::disableLICM() { 2984 return getOption<bool>(getOptions(), LoopOptionCase::disable_licm); 2985 } 2986 2987 Optional<int64_t> LoopOptionsAttr::interleaveCount() { 2988 return getOption<int64_t>(getOptions(), LoopOptionCase::interleave_count); 2989 } 2990 2991 /// Build the LoopOptions Attribute from a sorted array of individual options. 2992 LoopOptionsAttr LoopOptionsAttr::get( 2993 MLIRContext *context, 2994 ArrayRef<std::pair<LoopOptionCase, int64_t>> sortedOptions) { 2995 assert(llvm::is_sorted(sortedOptions, llvm::less_first()) && 2996 "LoopOptionsAttr ctor expects a sorted options array"); 2997 return Base::get(context, sortedOptions); 2998 } 2999 3000 /// Build the LoopOptions Attribute from a sorted array of individual options. 3001 LoopOptionsAttr LoopOptionsAttr::get(MLIRContext *context, 3002 LoopOptionsAttrBuilder &optionBuilders) { 3003 llvm::sort(optionBuilders.options, llvm::less_first()); 3004 return Base::get(context, optionBuilders.options); 3005 } 3006 3007 void LoopOptionsAttr::print(AsmPrinter &printer) const { 3008 printer << "<"; 3009 llvm::interleaveComma(getOptions(), printer, [&](auto option) { 3010 printer << stringifyEnum(option.first) << " = "; 3011 switch (option.first) { 3012 case LoopOptionCase::disable_licm: 3013 case LoopOptionCase::disable_unroll: 3014 case LoopOptionCase::disable_pipeline: 3015 printer << (option.second ? "true" : "false"); 3016 break; 3017 case LoopOptionCase::interleave_count: 3018 case LoopOptionCase::pipeline_initiation_interval: 3019 printer << option.second; 3020 break; 3021 } 3022 }); 3023 printer << ">"; 3024 } 3025 3026 Attribute LoopOptionsAttr::parse(AsmParser &parser, Type type) { 3027 if (failed(parser.parseLess())) 3028 return {}; 3029 3030 SmallVector<std::pair<LoopOptionCase, int64_t>> options; 3031 llvm::SmallDenseSet<LoopOptionCase> seenOptions; 3032 do { 3033 StringRef optionName; 3034 if (parser.parseKeyword(&optionName)) 3035 return {}; 3036 3037 auto option = symbolizeLoopOptionCase(optionName); 3038 if (!option) { 3039 parser.emitError(parser.getNameLoc(), "unknown loop option: ") 3040 << optionName; 3041 return {}; 3042 } 3043 if (!seenOptions.insert(*option).second) { 3044 parser.emitError(parser.getNameLoc(), "loop option present twice"); 3045 return {}; 3046 } 3047 if (failed(parser.parseEqual())) 3048 return {}; 3049 3050 int64_t value; 3051 switch (*option) { 3052 case LoopOptionCase::disable_licm: 3053 case LoopOptionCase::disable_unroll: 3054 case LoopOptionCase::disable_pipeline: 3055 if (succeeded(parser.parseOptionalKeyword("true"))) 3056 value = 1; 3057 else if (succeeded(parser.parseOptionalKeyword("false"))) 3058 value = 0; 3059 else { 3060 parser.emitError(parser.getNameLoc(), 3061 "expected boolean value 'true' or 'false'"); 3062 return {}; 3063 } 3064 break; 3065 case LoopOptionCase::interleave_count: 3066 case LoopOptionCase::pipeline_initiation_interval: 3067 if (failed(parser.parseInteger(value))) { 3068 parser.emitError(parser.getNameLoc(), "expected integer value"); 3069 return {}; 3070 } 3071 break; 3072 } 3073 options.push_back(std::make_pair(*option, value)); 3074 } while (succeeded(parser.parseOptionalComma())); 3075 if (failed(parser.parseGreater())) 3076 return {}; 3077 3078 llvm::sort(options, llvm::less_first()); 3079 return get(parser.getContext(), options); 3080 } 3081