//===- LLVMDialect.cpp - LLVM IR Ops and Dialect registration -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the types and operation details for the LLVM IR dialect in
// MLIR, and the LLVM IR dialect. It also registers the dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "TypeDetail.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"

#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/Bitcode/BitcodeReader.h"
#include "llvm/Bitcode/BitcodeWriter.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Mutex.h"
#include "llvm/Support/SourceMgr.h"

#include <numeric>

using namespace mlir;
using namespace mlir::LLVM;
using mlir::LLVM::linkage::getMaxEnumValForLinkage;

#include "mlir/Dialect/LLVMIR/LLVMOpsDialect.cpp.inc"

// Attribute names shared by several op implementations in this file.

// Unit attribute present on load/store ops that are volatile.
static constexpr const char kVolatileAttrName[] = "volatile_";
// Unit attribute present on load/store ops that are non-temporal.
static constexpr const char kNonTemporalAttrName[] = "nontemporal";
// Type attribute that carries the element type when an opaque pointer type is
// used (see verifyOpaquePtr).
static constexpr const char kElemTypeAttrName[] = "elem_type";

#include "mlir/Dialect/LLVMIR/LLVMOpsEnums.cpp.inc"
#include "mlir/Dialect/LLVMIR/LLVMOpsInterfaces.cpp.inc"
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/LLVMIR/LLVMOpsAttrDefs.cpp.inc"
51 52 static auto processFMFAttr(ArrayRef<NamedAttribute> attrs) { 53 SmallVector<NamedAttribute, 8> filteredAttrs( 54 llvm::make_filter_range(attrs, [&](NamedAttribute attr) { 55 if (attr.getName() == "fastmathFlags") { 56 auto defAttr = FMFAttr::get(attr.getValue().getContext(), {}); 57 return defAttr != attr.getValue(); 58 } 59 return true; 60 })); 61 return filteredAttrs; 62 } 63 64 static ParseResult parseLLVMOpAttrs(OpAsmParser &parser, 65 NamedAttrList &result) { 66 return parser.parseOptionalAttrDict(result); 67 } 68 69 static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op, 70 DictionaryAttr attrs) { 71 printer.printOptionalAttrDict(processFMFAttr(attrs.getValue())); 72 } 73 74 /// Verifies `symbol`'s use in `op` to ensure the symbol is a valid and 75 /// fully defined llvm.func. 76 static LogicalResult verifySymbolAttrUse(FlatSymbolRefAttr symbol, 77 Operation *op, 78 SymbolTableCollection &symbolTable) { 79 StringRef name = symbol.getValue(); 80 auto func = 81 symbolTable.lookupNearestSymbolFrom<LLVMFuncOp>(op, symbol.getAttr()); 82 if (!func) 83 return op->emitOpError("'") 84 << name << "' does not reference a valid LLVM function"; 85 if (func.isExternal()) 86 return op->emitOpError("'") << name << "' does not have a definition"; 87 return success(); 88 } 89 90 //===----------------------------------------------------------------------===// 91 // Printing/parsing for LLVM::CmpOp. 
92 //===----------------------------------------------------------------------===// 93 94 void ICmpOp::print(OpAsmPrinter &p) { 95 p << " \"" << stringifyICmpPredicate(getPredicate()) << "\" " << getOperand(0) 96 << ", " << getOperand(1); 97 p.printOptionalAttrDict((*this)->getAttrs(), {"predicate"}); 98 p << " : " << getLhs().getType(); 99 } 100 101 void FCmpOp::print(OpAsmPrinter &p) { 102 p << " \"" << stringifyFCmpPredicate(getPredicate()) << "\" " << getOperand(0) 103 << ", " << getOperand(1); 104 p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()), {"predicate"}); 105 p << " : " << getLhs().getType(); 106 } 107 108 // <operation> ::= `llvm.icmp` string-literal ssa-use `,` ssa-use 109 // attribute-dict? `:` type 110 // <operation> ::= `llvm.fcmp` string-literal ssa-use `,` ssa-use 111 // attribute-dict? `:` type 112 template <typename CmpPredicateType> 113 static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) { 114 Builder &builder = parser.getBuilder(); 115 116 StringAttr predicateAttr; 117 OpAsmParser::UnresolvedOperand lhs, rhs; 118 Type type; 119 SMLoc predicateLoc, trailingTypeLoc; 120 if (parser.getCurrentLocation(&predicateLoc) || 121 parser.parseAttribute(predicateAttr, "predicate", result.attributes) || 122 parser.parseOperand(lhs) || parser.parseComma() || 123 parser.parseOperand(rhs) || 124 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 125 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) || 126 parser.resolveOperand(lhs, type, result.operands) || 127 parser.resolveOperand(rhs, type, result.operands)) 128 return failure(); 129 130 // Replace the string attribute `predicate` with an integer attribute. 
131 int64_t predicateValue = 0; 132 if (std::is_same<CmpPredicateType, ICmpPredicate>()) { 133 Optional<ICmpPredicate> predicate = 134 symbolizeICmpPredicate(predicateAttr.getValue()); 135 if (!predicate) 136 return parser.emitError(predicateLoc) 137 << "'" << predicateAttr.getValue() 138 << "' is an incorrect value of the 'predicate' attribute"; 139 predicateValue = static_cast<int64_t>(predicate.getValue()); 140 } else { 141 Optional<FCmpPredicate> predicate = 142 symbolizeFCmpPredicate(predicateAttr.getValue()); 143 if (!predicate) 144 return parser.emitError(predicateLoc) 145 << "'" << predicateAttr.getValue() 146 << "' is an incorrect value of the 'predicate' attribute"; 147 predicateValue = static_cast<int64_t>(predicate.getValue()); 148 } 149 150 result.attributes.set("predicate", 151 parser.getBuilder().getI64IntegerAttr(predicateValue)); 152 153 // The result type is either i1 or a vector type <? x i1> if the inputs are 154 // vectors. 155 Type resultType = IntegerType::get(builder.getContext(), 1); 156 if (!isCompatibleType(type)) 157 return parser.emitError(trailingTypeLoc, 158 "expected LLVM dialect-compatible type"); 159 if (LLVM::isCompatibleVectorType(type)) { 160 if (LLVM::isScalableVectorType(type)) { 161 resultType = LLVM::getVectorType( 162 resultType, LLVM::getVectorNumElements(type).getKnownMinValue(), 163 /*isScalable=*/true); 164 } else { 165 resultType = LLVM::getVectorType( 166 resultType, LLVM::getVectorNumElements(type).getFixedValue(), 167 /*isScalable=*/false); 168 } 169 } 170 171 result.addTypes({resultType}); 172 return success(); 173 } 174 175 ParseResult ICmpOp::parse(OpAsmParser &parser, OperationState &result) { 176 return parseCmpOp<ICmpPredicate>(parser, result); 177 } 178 179 ParseResult FCmpOp::parse(OpAsmParser &parser, OperationState &result) { 180 return parseCmpOp<FCmpPredicate>(parser, result); 181 } 182 183 //===----------------------------------------------------------------------===// 184 // Printing, parsing and 
verification for LLVM::AllocaOp. 185 //===----------------------------------------------------------------------===// 186 187 void AllocaOp::print(OpAsmPrinter &p) { 188 Type elemTy = getType().cast<LLVM::LLVMPointerType>().getElementType(); 189 if (!elemTy) 190 elemTy = *getElemType(); 191 192 auto funcTy = 193 FunctionType::get(getContext(), {getArraySize().getType()}, {getType()}); 194 195 p << ' ' << getArraySize() << " x " << elemTy; 196 if (getAlignment().hasValue() && *getAlignment() != 0) 197 p.printOptionalAttrDict((*this)->getAttrs(), {kElemTypeAttrName}); 198 else 199 p.printOptionalAttrDict((*this)->getAttrs(), 200 {"alignment", kElemTypeAttrName}); 201 p << " : " << funcTy; 202 } 203 204 // <operation> ::= `llvm.alloca` ssa-use `x` type attribute-dict? 205 // `:` type `,` type 206 ParseResult AllocaOp::parse(OpAsmParser &parser, OperationState &result) { 207 OpAsmParser::UnresolvedOperand arraySize; 208 Type type, elemType; 209 SMLoc trailingTypeLoc; 210 if (parser.parseOperand(arraySize) || parser.parseKeyword("x") || 211 parser.parseType(elemType) || 212 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 213 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type)) 214 return failure(); 215 216 Optional<NamedAttribute> alignmentAttr = 217 result.attributes.getNamed("alignment"); 218 if (alignmentAttr.hasValue()) { 219 auto alignmentInt = 220 alignmentAttr.getValue().getValue().dyn_cast<IntegerAttr>(); 221 if (!alignmentInt) 222 return parser.emitError(parser.getNameLoc(), 223 "expected integer alignment"); 224 if (alignmentInt.getValue().isNullValue()) 225 result.attributes.erase("alignment"); 226 } 227 228 // Extract the result type from the trailing function type. 
229 auto funcType = type.dyn_cast<FunctionType>(); 230 if (!funcType || funcType.getNumInputs() != 1 || 231 funcType.getNumResults() != 1) 232 return parser.emitError( 233 trailingTypeLoc, 234 "expected trailing function type with one argument and one result"); 235 236 if (parser.resolveOperand(arraySize, funcType.getInput(0), result.operands)) 237 return failure(); 238 239 Type resultType = funcType.getResult(0); 240 if (auto ptrResultType = resultType.dyn_cast<LLVMPointerType>()) { 241 if (ptrResultType.isOpaque()) 242 result.addAttribute(kElemTypeAttrName, TypeAttr::get(elemType)); 243 } 244 245 result.addTypes({funcType.getResult(0)}); 246 return success(); 247 } 248 249 /// Checks that the elemental type is present in either the pointer type or 250 /// the attribute, but not both. 251 static LogicalResult verifyOpaquePtr(Operation *op, LLVMPointerType ptrType, 252 Optional<Type> ptrElementType) { 253 if (ptrType.isOpaque() && !ptrElementType.hasValue()) { 254 return op->emitOpError() << "expected '" << kElemTypeAttrName 255 << "' attribute if opaque pointer type is used"; 256 } 257 if (!ptrType.isOpaque() && ptrElementType.hasValue()) { 258 return op->emitOpError() 259 << "unexpected '" << kElemTypeAttrName 260 << "' attribute when non-opaque pointer type is used"; 261 } 262 return success(); 263 } 264 265 LogicalResult AllocaOp::verify() { 266 return verifyOpaquePtr(getOperation(), getType().cast<LLVMPointerType>(), 267 getElemType()); 268 } 269 270 //===----------------------------------------------------------------------===// 271 // LLVM::BrOp 272 //===----------------------------------------------------------------------===// 273 274 SuccessorOperands BrOp::getSuccessorOperands(unsigned index) { 275 assert(index == 0 && "invalid successor index"); 276 return SuccessorOperands(getDestOperandsMutable()); 277 } 278 279 //===----------------------------------------------------------------------===// 280 // LLVM::CondBrOp 281 
//===----------------------------------------------------------------------===// 282 283 SuccessorOperands CondBrOp::getSuccessorOperands(unsigned index) { 284 assert(index < getNumSuccessors() && "invalid successor index"); 285 return SuccessorOperands(index == 0 ? getTrueDestOperandsMutable() 286 : getFalseDestOperandsMutable()); 287 } 288 289 //===----------------------------------------------------------------------===// 290 // LLVM::SwitchOp 291 //===----------------------------------------------------------------------===// 292 293 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value, 294 Block *defaultDestination, ValueRange defaultOperands, 295 ArrayRef<int32_t> caseValues, BlockRange caseDestinations, 296 ArrayRef<ValueRange> caseOperands, 297 ArrayRef<int32_t> branchWeights) { 298 ElementsAttr caseValuesAttr; 299 if (!caseValues.empty()) 300 caseValuesAttr = builder.getI32VectorAttr(caseValues); 301 302 ElementsAttr weightsAttr; 303 if (!branchWeights.empty()) 304 weightsAttr = builder.getI32VectorAttr(llvm::to_vector<4>(branchWeights)); 305 306 build(builder, result, value, defaultOperands, caseOperands, caseValuesAttr, 307 weightsAttr, defaultDestination, caseDestinations); 308 } 309 310 /// <cases> ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)? 311 /// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )? 
312 static ParseResult parseSwitchOpCases( 313 OpAsmParser &parser, Type flagType, ElementsAttr &caseValues, 314 SmallVectorImpl<Block *> &caseDestinations, 315 SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> &caseOperands, 316 SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) { 317 SmallVector<APInt> values; 318 unsigned bitWidth = flagType.getIntOrFloatBitWidth(); 319 do { 320 int64_t value = 0; 321 OptionalParseResult integerParseResult = parser.parseOptionalInteger(value); 322 if (values.empty() && !integerParseResult.hasValue()) 323 return success(); 324 325 if (!integerParseResult.hasValue() || integerParseResult.getValue()) 326 return failure(); 327 values.push_back(APInt(bitWidth, value)); 328 329 Block *destination; 330 SmallVector<OpAsmParser::UnresolvedOperand> operands; 331 SmallVector<Type> operandTypes; 332 if (parser.parseColon() || parser.parseSuccessor(destination)) 333 return failure(); 334 if (!parser.parseOptionalLParen()) { 335 if (parser.parseOperandList(operands, OpAsmParser::Delimiter::None, 336 /*allowResultNumber=*/false) || 337 parser.parseColonTypeList(operandTypes) || parser.parseRParen()) 338 return failure(); 339 } 340 caseDestinations.push_back(destination); 341 caseOperands.emplace_back(operands); 342 caseOperandTypes.emplace_back(operandTypes); 343 } while (!parser.parseOptionalComma()); 344 345 ShapedType caseValueType = 346 VectorType::get(static_cast<int64_t>(values.size()), flagType); 347 caseValues = DenseIntElementsAttr::get(caseValueType, values); 348 return success(); 349 } 350 351 static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type flagType, 352 ElementsAttr caseValues, 353 SuccessorRange caseDestinations, 354 OperandRangeRange caseOperands, 355 const TypeRangeRange &caseOperandTypes) { 356 if (!caseValues) 357 return; 358 359 size_t index = 0; 360 llvm::interleave( 361 llvm::zip(caseValues.cast<DenseIntElementsAttr>(), caseDestinations), 362 [&](auto i) { 363 p << " "; 364 p << 
std::get<0>(i).getLimitedValue(); 365 p << ": "; 366 p.printSuccessorAndUseList(std::get<1>(i), caseOperands[index++]); 367 }, 368 [&] { 369 p << ','; 370 p.printNewline(); 371 }); 372 p.printNewline(); 373 } 374 375 LogicalResult SwitchOp::verify() { 376 if ((!getCaseValues() && !getCaseDestinations().empty()) || 377 (getCaseValues() && 378 getCaseValues()->size() != 379 static_cast<int64_t>(getCaseDestinations().size()))) 380 return emitOpError("expects number of case values to match number of " 381 "case destinations"); 382 if (getBranchWeights() && getBranchWeights()->size() != getNumSuccessors()) 383 return emitError("expects number of branch weights to match number of " 384 "successors: ") 385 << getBranchWeights()->size() << " vs " << getNumSuccessors(); 386 return success(); 387 } 388 389 SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) { 390 assert(index < getNumSuccessors() && "invalid successor index"); 391 return SuccessorOperands(index == 0 ? getDefaultOperandsMutable() 392 : getCaseOperandsMutable(index - 1)); 393 } 394 395 //===----------------------------------------------------------------------===// 396 // Code for LLVM::GEPOp. 397 //===----------------------------------------------------------------------===// 398 399 constexpr int GEPOp::kDynamicIndex; 400 401 /// Populates `indices` with positions of GEP indices that would correspond to 402 /// LLVMStructTypes potentially nested in the given type. The type currently 403 /// visited gets `currentIndex` and LLVM container types are visited 404 /// recursively. The recursion is bounded and takes care of recursive types by 405 /// means of the `visited` set. 
406 static void recordStructIndices(Type type, unsigned currentIndex, 407 SmallVectorImpl<unsigned> &indices, 408 SmallVectorImpl<unsigned> *structSizes, 409 SmallPtrSet<Type, 4> &visited) { 410 if (visited.contains(type)) 411 return; 412 413 visited.insert(type); 414 415 llvm::TypeSwitch<Type>(type) 416 .Case<LLVMStructType>([&](LLVMStructType structType) { 417 indices.push_back(currentIndex); 418 if (structSizes) 419 structSizes->push_back(structType.getBody().size()); 420 for (Type elementType : structType.getBody()) 421 recordStructIndices(elementType, currentIndex + 1, indices, 422 structSizes, visited); 423 }) 424 .Case<VectorType, LLVMScalableVectorType, LLVMFixedVectorType, 425 LLVMArrayType>([&](auto containerType) { 426 recordStructIndices(containerType.getElementType(), currentIndex + 1, 427 indices, structSizes, visited); 428 }); 429 } 430 431 /// Populates `indices` with positions of GEP indices that correspond to 432 /// LLVMStructTypes potentially nested in the given `baseGEPType`, which must 433 /// be either an LLVMPointer type or a vector thereof. If `structSizes` is 434 /// provided, it is populated with sizes of the indexed structs for bounds 435 /// verification purposes. 436 void GEPOp::findKnownStructIndices(Type sourceElementType, 437 SmallVectorImpl<unsigned> &indices, 438 SmallVectorImpl<unsigned> *structSizes) { 439 SmallPtrSet<Type, 4> visited; 440 recordStructIndices(sourceElementType, /*currentIndex=*/1, indices, 441 structSizes, visited); 442 } 443 444 void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType, 445 Value basePtr, ValueRange operands, 446 ArrayRef<NamedAttribute> attributes) { 447 build(builder, result, resultType, basePtr, operands, 448 SmallVector<int32_t>(operands.size(), LLVM::GEPOp::kDynamicIndex), 449 attributes); 450 } 451 452 /// Returns the elemental type of any LLVM-compatible vector type or self. 
static Type extractVectorElementType(Type type) {
  // Builtin vectors and both LLVM dialect vector flavors are unwrapped one
  // level; any other type is returned unchanged.
  if (auto vectorType = type.dyn_cast<VectorType>())
    return vectorType.getElementType();
  if (auto scalableVectorType = type.dyn_cast<LLVMScalableVectorType>())
    return scalableVectorType.getElementType();
  if (auto fixedVectorType = type.dyn_cast<LLVMFixedVectorType>())
    return fixedVectorType.getElementType();
  return type;
}

/// Builds a GEP whose element type is the pointee of `basePtr`'s pointer type
/// (or of the element pointer type when `basePtr` is a vector of pointers).
/// Only valid for non-opaque pointers; with opaque pointers the overload that
/// takes `elementType` explicitly must be used.
void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
                  Value basePtr, ValueRange indices,
                  ArrayRef<int32_t> structIndices,
                  ArrayRef<NamedAttribute> attributes) {
  auto ptrType =
      extractVectorElementType(basePtr.getType()).cast<LLVMPointerType>();
  assert(!ptrType.isOpaque() &&
         "expected non-opaque pointer, provide elementType explicitly when "
         "opaque pointers are used");
  build(builder, result, resultType, ptrType.getElementType(), basePtr, indices,
        structIndices, attributes);
}

/// Builds a GEP that indexes into `elementType`.
///
/// `structIndices` has one entry per GEP index: either a static index or
/// `kDynamicIndex`, in which case the value comes from `indices` (dynamic
/// entries consume `indices` in order). Dynamic entries at positions that
/// index into a struct must be produced by integer constants; their values are
/// folded into the static list and the corresponding operands are dropped.
/// `elementType` is recorded in the `elem_type` attribute only when the base
/// pointer is opaque.
void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
                  Type elementType, Value basePtr, ValueRange indices,
                  ArrayRef<int32_t> structIndices,
                  ArrayRef<NamedAttribute> attributes) {
  SmallVector<Value> remainingIndices;
  SmallVector<int32_t> updatedStructIndices(structIndices.begin(),
                                            structIndices.end());
  SmallVector<unsigned> structRelatedPositions;
  findKnownStructIndices(elementType, structRelatedPositions);

  SmallVector<unsigned> operandsToErase;
  for (unsigned pos : structRelatedPositions) {
    // GEP may not be indexing as deep as some structs are located.
    if (pos >= structIndices.size())
      continue;

    // If the index is already static, it's fine.
    if (structIndices[pos] != kDynamicIndex)
      continue;

    // Find the corresponding operand: it follows one operand for every
    // dynamic entry that precedes position `pos`.
    unsigned operandPos =
        std::count(structIndices.begin(), std::next(structIndices.begin(), pos),
                   kDynamicIndex);

    // Extract the constant value from the operand and put it into the attribute
    // instead.
    APInt staticIndexValue;
    bool matched =
        matchPattern(indices[operandPos], m_ConstantInt(&staticIndexValue));
    // `matched` is only read by the assert below; silence unused warnings in
    // builds where asserts are compiled out.
    (void)matched;
    assert(matched && "index into a struct must be a constant");
    assert(staticIndexValue.sge(APInt::getSignedMinValue(/*numBits=*/32)) &&
           "struct index underflows 32-bit integer");
    assert(staticIndexValue.sle(APInt::getSignedMaxValue(/*numBits=*/32)) &&
           "struct index overflows 32-bit integer");
    auto staticIndex = static_cast<int32_t>(staticIndexValue.getSExtValue());
    updatedStructIndices[pos] = staticIndex;
    operandsToErase.push_back(operandPos);
  }

  // Keep only the index operands that were not folded into static indices.
  for (unsigned i = 0, e = indices.size(); i < e; ++i) {
    if (!llvm::is_contained(operandsToErase, i))
      remainingIndices.push_back(indices[i]);
  }

  assert(remainingIndices.size() == static_cast<size_t>(llvm::count(
                                        updatedStructIndices, kDynamicIndex)) &&
         "expected as many index operands as dynamic index attr elements");

  result.addTypes(resultType);
  result.addAttributes(attributes);
  result.addAttribute("structIndices",
                      builder.getI32TensorAttr(updatedStructIndices));
  if (extractVectorElementType(basePtr.getType())
          .cast<LLVMPointerType>()
          .isOpaque())
    result.addAttribute(kElemTypeAttrName, TypeAttr::get(elementType));
  result.addOperands(basePtr);
  result.addOperands(remainingIndices);
}

/// Parses a comma-separated list of GEP indices. Each index is either an
/// integer literal (a static index) or an SSA operand (a dynamic index, which
/// is recorded as `kDynamicIndex` in `structIndices` and appended to
/// `indices`).
static ParseResult
parseGEPIndices(OpAsmParser &parser,
                SmallVectorImpl<OpAsmParser::UnresolvedOperand> &indices,
                DenseIntElementsAttr &structIndices) {
  SmallVector<int32_t> constantIndices;
  do {
    int32_t constantIndex;
    OptionalParseResult parsedInteger =
        parser.parseOptionalInteger(constantIndex);
    if (parsedInteger.hasValue()) {
      if (failed(parsedInteger.getValue()))
        return failure();
      constantIndices.push_back(constantIndex);
      continue;
    }

    // Not an integer literal: it must be an SSA operand.
    constantIndices.push_back(LLVM::GEPOp::kDynamicIndex);
    if (failed(parser.parseOperand(indices.emplace_back())))
      return failure();
  } while (succeeded(parser.parseOptionalComma()));

  structIndices = parser.getBuilder().getI32TensorAttr(constantIndices);
  return success();
}

/// Prints the index list accepted by parseGEPIndices: static entries as
/// integers, `kDynamicIndex` entries as the next operand from `indices`.
/// `gepOp` is not used.
static void printGEPIndices(OpAsmPrinter &printer, LLVM::GEPOp gepOp,
                            OperandRange indices,
                            DenseIntElementsAttr structIndices) {
  unsigned operandIdx = 0;
  llvm::interleaveComma(structIndices.getValues<int32_t>(), printer,
                        [&](int32_t cst) {
                          if (cst == LLVM::GEPOp::kDynamicIndex)
                            printer.printOperand(indices[operandIdx++]);
                          else
                            printer << cst;
                        });
}

/// Verifies the opaque-pointer / `elem_type` consistency, and that every index
/// that reaches into a struct is static and within the struct's bounds.
LogicalResult LLVM::GEPOp::verify() {
  // NOTE(review): this inspects the *result* pointer type, whereas build() and
  // getSourceElementType() look at the base pointer type; the two presumably
  // agree for valid GEPs — confirm.
  if (failed(verifyOpaquePtr(
          getOperation(),
          extractVectorElementType(getType()).cast<LLVMPointerType>(),
          getElemType())))
    return failure();

  SmallVector<unsigned> indices;
  SmallVector<unsigned> structSizes;
  findKnownStructIndices(getSourceElementType(), indices, &structSizes);
  DenseIntElementsAttr structIndices = getStructIndices();
  for (unsigned i : llvm::seq<unsigned>(0, indices.size())) {
    unsigned index = indices[i];
    // GEP may not be indexing as deep as some structs nested in the type.
    if (index >= structIndices.getNumElements())
      continue;

    int32_t staticIndex = structIndices.getValues<int32_t>()[index];
    if (staticIndex == LLVM::GEPOp::kDynamicIndex)
      return emitOpError() << "expected index " << index
                           << " indexing a struct to be constant";
    if (staticIndex < 0 || static_cast<unsigned>(staticIndex) >= structSizes[i])
      return emitOpError() << "index " << index
                           << " indexing a struct is out of bounds";
  }
  return success();
}

/// Returns the type the GEP indexes into: the `elem_type` attribute when
/// present, otherwise the pointee of the (possibly vectorized) base pointer.
Type LLVM::GEPOp::getSourceElementType() {
  if (Optional<Type> elemType = getElemType())
    return *elemType;

  return extractVectorElementType(getBase().getType())
      .cast<LLVMPointerType>()
      .getElementType();
}

//===----------------------------------------------------------------------===//
// Builder, printer and parser for LLVM::LoadOp.
//===----------------------------------------------------------------------===//

/// Verifies the symbol references stored in the array attribute named
/// `attributeName` on `op`, if that attribute is present. Each reference must
/// be nested (`@metadata::@symbol`). `@metadata` must resolve to an
/// LLVM::MetadataOp reachable from `op`'s parent, and `@symbol` must resolve
/// inside that op. The resolved op is then checked by `verifySymbolType`.
LogicalResult verifySymbolAttribute(
    Operation *op, StringRef attributeName,
    llvm::function_ref<LogicalResult(Operation *, SymbolRefAttr)>
        verifySymbolType) {
  if (Attribute attribute = op->getAttr(attributeName)) {
    // The attribute is already verified to be a symbol ref array attribute via
    // a constraint in the operation definition.
    for (SymbolRefAttr symbolRef :
         attribute.cast<ArrayAttr>().getAsRange<SymbolRefAttr>()) {
      StringAttr metadataName = symbolRef.getRootReference();
      StringAttr symbolName = symbolRef.getLeafReference();
      // We want @metadata::@symbol, not just @symbol
      // NOTE(review): a nested reference whose root and leaf names coincide
      // (e.g. `@x::@x`) is rejected by this check as well.
      if (metadataName == symbolName) {
        return op->emitOpError() << "expected '" << symbolRef
                                 << "' to specify a fully qualified reference";
      }
      auto metadataOp = SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>(
          op->getParentOp(), metadataName);
      if (!metadataOp)
        return op->emitOpError()
               << "expected '" << symbolRef << "' to reference a metadata op";
      Operation *symbolOp =
          SymbolTable::lookupNearestSymbolFrom(metadataOp, symbolName);
      if (!symbolOp)
        return op->emitOpError()
               << "expected '" << symbolRef << "' to be a valid reference";
      if (failed(verifySymbolType(symbolOp, symbolRef))) {
        return failure();
      }
    }
  }
  return success();
}

// Verifies that metadata ops are wired up properly.
652 template <typename OpTy> 653 static LogicalResult verifyOpMetadata(Operation *op, StringRef attributeName) { 654 auto verifySymbolType = [op](Operation *symbolOp, 655 SymbolRefAttr symbolRef) -> LogicalResult { 656 if (!isa<OpTy>(symbolOp)) { 657 return op->emitOpError() 658 << "expected '" << symbolRef << "' to resolve to a " 659 << OpTy::getOperationName(); 660 } 661 return success(); 662 }; 663 664 return verifySymbolAttribute(op, attributeName, verifySymbolType); 665 } 666 667 static LogicalResult verifyMemoryOpMetadata(Operation *op) { 668 // access_groups 669 if (failed(verifyOpMetadata<LLVM::AccessGroupMetadataOp>( 670 op, LLVMDialect::getAccessGroupsAttrName()))) 671 return failure(); 672 673 // alias_scopes 674 if (failed(verifyOpMetadata<LLVM::AliasScopeMetadataOp>( 675 op, LLVMDialect::getAliasScopesAttrName()))) 676 return failure(); 677 678 // noalias_scopes 679 if (failed(verifyOpMetadata<LLVM::AliasScopeMetadataOp>( 680 op, LLVMDialect::getNoAliasScopesAttrName()))) 681 return failure(); 682 683 return success(); 684 } 685 686 LogicalResult LoadOp::verify() { return verifyMemoryOpMetadata(*this); } 687 688 void LoadOp::build(OpBuilder &builder, OperationState &result, Type t, 689 Value addr, unsigned alignment, bool isVolatile, 690 bool isNonTemporal) { 691 result.addOperands(addr); 692 result.addTypes(t); 693 if (isVolatile) 694 result.addAttribute(kVolatileAttrName, builder.getUnitAttr()); 695 if (isNonTemporal) 696 result.addAttribute(kNonTemporalAttrName, builder.getUnitAttr()); 697 if (alignment != 0) 698 result.addAttribute("alignment", builder.getI64IntegerAttr(alignment)); 699 } 700 701 void LoadOp::print(OpAsmPrinter &p) { 702 p << ' '; 703 if (getVolatile_()) 704 p << "volatile "; 705 p << getAddr(); 706 p.printOptionalAttrDict((*this)->getAttrs(), 707 {kVolatileAttrName, kElemTypeAttrName}); 708 p << " : " << getAddr().getType(); 709 if (getAddr().getType().cast<LLVMPointerType>().isOpaque()) 710 p << " -> " << getType(); 711 } 712 713 
// Extract the pointee type from the LLVM pointer type wrapped in MLIR. Return 714 // the resulting type if any, null type if opaque pointers are used, and None 715 // if the given type is not the pointer type. 716 static Optional<Type> getLoadStoreElementType(OpAsmParser &parser, Type type, 717 SMLoc trailingTypeLoc) { 718 auto llvmTy = type.dyn_cast<LLVM::LLVMPointerType>(); 719 if (!llvmTy) { 720 parser.emitError(trailingTypeLoc, "expected LLVM pointer type"); 721 return llvm::None; 722 } 723 return llvmTy.getElementType(); 724 } 725 726 // <operation> ::= `llvm.load` `volatile` ssa-use attribute-dict? `:` type 727 // (`->` type)? 728 ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) { 729 OpAsmParser::UnresolvedOperand addr; 730 Type type; 731 SMLoc trailingTypeLoc; 732 733 if (succeeded(parser.parseOptionalKeyword("volatile"))) 734 result.addAttribute(kVolatileAttrName, parser.getBuilder().getUnitAttr()); 735 736 if (parser.parseOperand(addr) || 737 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 738 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) || 739 parser.resolveOperand(addr, type, result.operands)) 740 return failure(); 741 742 Optional<Type> elemTy = 743 getLoadStoreElementType(parser, type, trailingTypeLoc); 744 if (!elemTy) 745 return failure(); 746 if (*elemTy) { 747 result.addTypes(*elemTy); 748 return success(); 749 } 750 751 Type trailingType; 752 if (parser.parseArrow() || parser.parseType(trailingType)) 753 return failure(); 754 result.addTypes(trailingType); 755 return success(); 756 } 757 758 //===----------------------------------------------------------------------===// 759 // Builder, printer and parser for LLVM::StoreOp. 
760 //===----------------------------------------------------------------------===// 761 762 LogicalResult StoreOp::verify() { return verifyMemoryOpMetadata(*this); } 763 764 void StoreOp::build(OpBuilder &builder, OperationState &result, Value value, 765 Value addr, unsigned alignment, bool isVolatile, 766 bool isNonTemporal) { 767 result.addOperands({value, addr}); 768 result.addTypes({}); 769 if (isVolatile) 770 result.addAttribute(kVolatileAttrName, builder.getUnitAttr()); 771 if (isNonTemporal) 772 result.addAttribute(kNonTemporalAttrName, builder.getUnitAttr()); 773 if (alignment != 0) 774 result.addAttribute("alignment", builder.getI64IntegerAttr(alignment)); 775 } 776 777 void StoreOp::print(OpAsmPrinter &p) { 778 p << ' '; 779 if (getVolatile_()) 780 p << "volatile "; 781 p << getValue() << ", " << getAddr(); 782 p.printOptionalAttrDict((*this)->getAttrs(), {kVolatileAttrName}); 783 p << " : "; 784 if (getAddr().getType().cast<LLVMPointerType>().isOpaque()) 785 p << getValue().getType() << ", "; 786 p << getAddr().getType(); 787 } 788 789 // <operation> ::= `llvm.store` `volatile` ssa-use `,` ssa-use 790 // attribute-dict? `:` type (`,` type)? 
791 ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) { 792 OpAsmParser::UnresolvedOperand addr, value; 793 Type type; 794 SMLoc trailingTypeLoc; 795 796 if (succeeded(parser.parseOptionalKeyword("volatile"))) 797 result.addAttribute(kVolatileAttrName, parser.getBuilder().getUnitAttr()); 798 799 if (parser.parseOperand(value) || parser.parseComma() || 800 parser.parseOperand(addr) || 801 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 802 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type)) 803 return failure(); 804 805 Type operandType; 806 if (succeeded(parser.parseOptionalComma())) { 807 operandType = type; 808 if (parser.parseType(type)) 809 return failure(); 810 } else { 811 Optional<Type> maybeOperandType = 812 getLoadStoreElementType(parser, type, trailingTypeLoc); 813 if (!maybeOperandType) 814 return failure(); 815 operandType = *maybeOperandType; 816 } 817 818 if (parser.resolveOperand(value, operandType, result.operands) || 819 parser.resolveOperand(addr, type, result.operands)) 820 return failure(); 821 822 return success(); 823 } 824 825 ///===---------------------------------------------------------------------===// 826 /// LLVM::InvokeOp 827 ///===---------------------------------------------------------------------===// 828 829 SuccessorOperands InvokeOp::getSuccessorOperands(unsigned index) { 830 assert(index < getNumSuccessors() && "invalid successor index"); 831 return SuccessorOperands(index == 0 ? 
getNormalDestOperandsMutable() 832 : getUnwindDestOperandsMutable()); 833 } 834 835 LogicalResult InvokeOp::verify() { 836 if (getNumResults() > 1) 837 return emitOpError("must have 0 or 1 result"); 838 839 Block *unwindDest = getUnwindDest(); 840 if (unwindDest->empty()) 841 return emitError("must have at least one operation in unwind destination"); 842 843 // In unwind destination, first operation must be LandingpadOp 844 if (!isa<LandingpadOp>(unwindDest->front())) 845 return emitError("first operation in unwind destination should be a " 846 "llvm.landingpad operation"); 847 848 return success(); 849 } 850 851 void InvokeOp::print(OpAsmPrinter &p) { 852 auto callee = getCallee(); 853 bool isDirect = callee.hasValue(); 854 855 p << ' '; 856 857 // Either function name or pointer 858 if (isDirect) 859 p.printSymbolName(callee.getValue()); 860 else 861 p << getOperand(0); 862 863 p << '(' << getOperands().drop_front(isDirect ? 0 : 1) << ')'; 864 p << " to "; 865 p.printSuccessorAndUseList(getNormalDest(), getNormalDestOperands()); 866 p << " unwind "; 867 p.printSuccessorAndUseList(getUnwindDest(), getUnwindDestOperands()); 868 869 p.printOptionalAttrDict((*this)->getAttrs(), 870 {InvokeOp::getOperandSegmentSizeAttr(), "callee"}); 871 p << " : "; 872 p.printFunctionalType(llvm::drop_begin(getOperandTypes(), isDirect ? 0 : 1), 873 getResultTypes()); 874 } 875 876 /// <operation> ::= `llvm.invoke` (function-id | ssa-use) `(` ssa-use-list `)` 877 /// `to` bb-id (`[` ssa-use-and-type-list `]`)? 878 /// `unwind` bb-id (`[` ssa-use-and-type-list `]`)? 879 /// attribute-dict? 
`:` function-type 880 ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) { 881 SmallVector<OpAsmParser::UnresolvedOperand, 8> operands; 882 FunctionType funcType; 883 SymbolRefAttr funcAttr; 884 SMLoc trailingTypeLoc; 885 Block *normalDest, *unwindDest; 886 SmallVector<Value, 4> normalOperands, unwindOperands; 887 Builder &builder = parser.getBuilder(); 888 889 // Parse an operand list that will, in practice, contain 0 or 1 operand. In 890 // case of an indirect call, there will be 1 operand before `(`. In case of a 891 // direct call, there will be no operands and the parser will stop at the 892 // function identifier without complaining. 893 if (parser.parseOperandList(operands)) 894 return failure(); 895 bool isDirect = operands.empty(); 896 897 // Optionally parse a function identifier. 898 if (isDirect && parser.parseAttribute(funcAttr, "callee", result.attributes)) 899 return failure(); 900 901 if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || 902 parser.parseKeyword("to") || 903 parser.parseSuccessorAndUseList(normalDest, normalOperands) || 904 parser.parseKeyword("unwind") || 905 parser.parseSuccessorAndUseList(unwindDest, unwindOperands) || 906 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 907 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(funcType)) 908 return failure(); 909 910 if (isDirect) { 911 // Make sure types match. 912 if (parser.resolveOperands(operands, funcType.getInputs(), 913 parser.getNameLoc(), result.operands)) 914 return failure(); 915 result.addTypes(funcType.getResults()); 916 } else { 917 // Construct the LLVM IR Dialect function type that the first operand 918 // should match. 
919 if (funcType.getNumResults() > 1) 920 return parser.emitError(trailingTypeLoc, 921 "expected function with 0 or 1 result"); 922 923 Type llvmResultType; 924 if (funcType.getNumResults() == 0) { 925 llvmResultType = LLVM::LLVMVoidType::get(builder.getContext()); 926 } else { 927 llvmResultType = funcType.getResult(0); 928 if (!isCompatibleType(llvmResultType)) 929 return parser.emitError(trailingTypeLoc, 930 "expected result to have LLVM type"); 931 } 932 933 SmallVector<Type, 8> argTypes; 934 argTypes.reserve(funcType.getNumInputs()); 935 for (Type ty : funcType.getInputs()) { 936 if (isCompatibleType(ty)) 937 argTypes.push_back(ty); 938 else 939 return parser.emitError(trailingTypeLoc, 940 "expected LLVM types as inputs"); 941 } 942 943 auto llvmFuncType = LLVM::LLVMFunctionType::get(llvmResultType, argTypes); 944 auto wrappedFuncType = LLVM::LLVMPointerType::get(llvmFuncType); 945 946 auto funcArguments = llvm::makeArrayRef(operands).drop_front(); 947 948 // Make sure that the first operand (indirect callee) matches the wrapped 949 // LLVM IR function type, and that the types of the other call operands 950 // match the types of the function arguments. 951 if (parser.resolveOperand(operands[0], wrappedFuncType, result.operands) || 952 parser.resolveOperands(funcArguments, funcType.getInputs(), 953 parser.getNameLoc(), result.operands)) 954 return failure(); 955 956 result.addTypes(llvmResultType); 957 } 958 result.addSuccessors({normalDest, unwindDest}); 959 result.addOperands(normalOperands); 960 result.addOperands(unwindOperands); 961 962 result.addAttribute( 963 InvokeOp::getOperandSegmentSizeAttr(), 964 builder.getI32VectorAttr({static_cast<int32_t>(operands.size()), 965 static_cast<int32_t>(normalOperands.size()), 966 static_cast<int32_t>(unwindOperands.size())})); 967 return success(); 968 } 969 970 ///===----------------------------------------------------------------------===// 971 /// Verifying/Printing/Parsing for LLVM::LandingpadOp. 
972 ///===----------------------------------------------------------------------===// 973 974 LogicalResult LandingpadOp::verify() { 975 Value value; 976 if (LLVMFuncOp func = (*this)->getParentOfType<LLVMFuncOp>()) { 977 if (!func.getPersonality().hasValue()) 978 return emitError( 979 "llvm.landingpad needs to be in a function with a personality"); 980 } 981 982 if (!getCleanup() && getOperands().empty()) 983 return emitError("landingpad instruction expects at least one clause or " 984 "cleanup attribute"); 985 986 for (unsigned idx = 0, ie = getNumOperands(); idx < ie; idx++) { 987 value = getOperand(idx); 988 bool isFilter = value.getType().isa<LLVMArrayType>(); 989 if (isFilter) { 990 // FIXME: Verify filter clauses when arrays are appropriately handled 991 } else { 992 // catch - global addresses only. 993 // Bitcast ops should have global addresses as their args. 994 if (auto bcOp = value.getDefiningOp<BitcastOp>()) { 995 if (auto addrOp = bcOp.getArg().getDefiningOp<AddressOfOp>()) 996 continue; 997 return emitError("constant clauses expected").attachNote(bcOp.getLoc()) 998 << "global addresses expected as operand to " 999 "bitcast used in clauses for landingpad"; 1000 } 1001 // NullOp and AddressOfOp allowed 1002 if (value.getDefiningOp<NullOp>()) 1003 continue; 1004 if (value.getDefiningOp<AddressOfOp>()) 1005 continue; 1006 return emitError("clause #") 1007 << idx << " is not a known constant - null, addressof, bitcast"; 1008 } 1009 } 1010 return success(); 1011 } 1012 1013 void LandingpadOp::print(OpAsmPrinter &p) { 1014 p << (getCleanup() ? " cleanup " : " "); 1015 1016 // Clauses 1017 for (auto value : getOperands()) { 1018 // Similar to llvm - if clause is an array type then it is filter 1019 // clause else catch clause 1020 bool isArrayTy = value.getType().isa<LLVMArrayType>(); 1021 p << '(' << (isArrayTy ? 
"filter " : "catch ") << value << " : " 1022 << value.getType() << ") "; 1023 } 1024 1025 p.printOptionalAttrDict((*this)->getAttrs(), {"cleanup"}); 1026 1027 p << ": " << getType(); 1028 } 1029 1030 /// <operation> ::= `llvm.landingpad` `cleanup`? 1031 /// ((`catch` | `filter`) operand-type ssa-use)* attribute-dict? 1032 ParseResult LandingpadOp::parse(OpAsmParser &parser, OperationState &result) { 1033 // Check for cleanup 1034 if (succeeded(parser.parseOptionalKeyword("cleanup"))) 1035 result.addAttribute("cleanup", parser.getBuilder().getUnitAttr()); 1036 1037 // Parse clauses with types 1038 while (succeeded(parser.parseOptionalLParen()) && 1039 (succeeded(parser.parseOptionalKeyword("filter")) || 1040 succeeded(parser.parseOptionalKeyword("catch")))) { 1041 OpAsmParser::UnresolvedOperand operand; 1042 Type ty; 1043 if (parser.parseOperand(operand) || parser.parseColon() || 1044 parser.parseType(ty) || 1045 parser.resolveOperand(operand, ty, result.operands) || 1046 parser.parseRParen()) 1047 return failure(); 1048 } 1049 1050 Type type; 1051 if (parser.parseColon() || parser.parseType(type)) 1052 return failure(); 1053 1054 result.addTypes(type); 1055 return success(); 1056 } 1057 1058 //===----------------------------------------------------------------------===// 1059 // Verifying/Printing/parsing for LLVM::CallOp. 1060 //===----------------------------------------------------------------------===// 1061 1062 LogicalResult CallOp::verify() { 1063 if (getNumResults() > 1) 1064 return emitOpError("must have 0 or 1 result"); 1065 1066 // Type for the callee, we'll get it differently depending if it is a direct 1067 // or indirect call. 1068 Type fnType; 1069 1070 bool isIndirect = false; 1071 1072 // If this is an indirect call, the callee attribute is missing. 
1073 FlatSymbolRefAttr calleeName = getCalleeAttr(); 1074 if (!calleeName) { 1075 isIndirect = true; 1076 if (!getNumOperands()) 1077 return emitOpError( 1078 "must have either a `callee` attribute or at least an operand"); 1079 auto ptrType = getOperand(0).getType().dyn_cast<LLVMPointerType>(); 1080 if (!ptrType) 1081 return emitOpError("indirect call expects a pointer as callee: ") 1082 << ptrType; 1083 fnType = ptrType.getElementType(); 1084 } else { 1085 Operation *callee = 1086 SymbolTable::lookupNearestSymbolFrom(*this, calleeName.getAttr()); 1087 if (!callee) 1088 return emitOpError() 1089 << "'" << calleeName.getValue() 1090 << "' does not reference a symbol in the current scope"; 1091 auto fn = dyn_cast<LLVMFuncOp>(callee); 1092 if (!fn) 1093 return emitOpError() << "'" << calleeName.getValue() 1094 << "' does not reference a valid LLVM function"; 1095 1096 fnType = fn.getFunctionType(); 1097 } 1098 1099 LLVMFunctionType funcType = fnType.dyn_cast<LLVMFunctionType>(); 1100 if (!funcType) 1101 return emitOpError("callee does not have a functional type: ") << fnType; 1102 1103 // Verify that the operand and result types match the callee. 
1104 1105 if (!funcType.isVarArg() && 1106 funcType.getNumParams() != (getNumOperands() - isIndirect)) 1107 return emitOpError() << "incorrect number of operands (" 1108 << (getNumOperands() - isIndirect) 1109 << ") for callee (expecting: " 1110 << funcType.getNumParams() << ")"; 1111 1112 if (funcType.getNumParams() > (getNumOperands() - isIndirect)) 1113 return emitOpError() << "incorrect number of operands (" 1114 << (getNumOperands() - isIndirect) 1115 << ") for varargs callee (expecting at least: " 1116 << funcType.getNumParams() << ")"; 1117 1118 for (unsigned i = 0, e = funcType.getNumParams(); i != e; ++i) 1119 if (getOperand(i + isIndirect).getType() != funcType.getParamType(i)) 1120 return emitOpError() << "operand type mismatch for operand " << i << ": " 1121 << getOperand(i + isIndirect).getType() 1122 << " != " << funcType.getParamType(i); 1123 1124 if (getNumResults() == 0 && 1125 !funcType.getReturnType().isa<LLVM::LLVMVoidType>()) 1126 return emitOpError() << "expected function call to produce a value"; 1127 1128 if (getNumResults() != 0 && 1129 funcType.getReturnType().isa<LLVM::LLVMVoidType>()) 1130 return emitOpError() 1131 << "calling function with void result must not produce values"; 1132 1133 if (getNumResults() > 1) 1134 return emitOpError() 1135 << "expected LLVM function call to produce 0 or 1 result"; 1136 1137 if (getNumResults() && getResult(0).getType() != funcType.getReturnType()) 1138 return emitOpError() << "result type mismatch: " << getResult(0).getType() 1139 << " != " << funcType.getReturnType(); 1140 1141 return success(); 1142 } 1143 1144 void CallOp::print(OpAsmPrinter &p) { 1145 auto callee = getCallee(); 1146 bool isDirect = callee.hasValue(); 1147 1148 // Print the direct callee if present as a function attribute, or an indirect 1149 // callee (first operand) otherwise. 
1150 p << ' '; 1151 if (isDirect) 1152 p.printSymbolName(callee.getValue()); 1153 else 1154 p << getOperand(0); 1155 1156 auto args = getOperands().drop_front(isDirect ? 0 : 1); 1157 p << '(' << args << ')'; 1158 p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()), {"callee"}); 1159 1160 // Reconstruct the function MLIR function type from operand and result types. 1161 p << " : "; 1162 p.printFunctionalType(args.getTypes(), getResultTypes()); 1163 } 1164 1165 // <operation> ::= `llvm.call` (function-id | ssa-use) `(` ssa-use-list `)` 1166 // attribute-dict? `:` function-type 1167 ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) { 1168 SmallVector<OpAsmParser::UnresolvedOperand, 8> operands; 1169 Type type; 1170 SymbolRefAttr funcAttr; 1171 SMLoc trailingTypeLoc; 1172 1173 // Parse an operand list that will, in practice, contain 0 or 1 operand. In 1174 // case of an indirect call, there will be 1 operand before `(`. In case of a 1175 // direct call, there will be no operands and the parser will stop at the 1176 // function identifier without complaining. 1177 if (parser.parseOperandList(operands)) 1178 return failure(); 1179 bool isDirect = operands.empty(); 1180 1181 // Optionally parse a function identifier. 1182 if (isDirect) 1183 if (parser.parseAttribute(funcAttr, "callee", result.attributes)) 1184 return failure(); 1185 1186 if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || 1187 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 1188 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type)) 1189 return failure(); 1190 1191 auto funcType = type.dyn_cast<FunctionType>(); 1192 if (!funcType) 1193 return parser.emitError(trailingTypeLoc, "expected function type"); 1194 if (funcType.getNumResults() > 1) 1195 return parser.emitError(trailingTypeLoc, 1196 "expected function with 0 or 1 result"); 1197 if (isDirect) { 1198 // Make sure types match. 
1199 if (parser.resolveOperands(operands, funcType.getInputs(), 1200 parser.getNameLoc(), result.operands)) 1201 return failure(); 1202 if (funcType.getNumResults() != 0 && 1203 !funcType.getResult(0).isa<LLVM::LLVMVoidType>()) 1204 result.addTypes(funcType.getResults()); 1205 } else { 1206 Builder &builder = parser.getBuilder(); 1207 Type llvmResultType; 1208 if (funcType.getNumResults() == 0) { 1209 llvmResultType = LLVM::LLVMVoidType::get(builder.getContext()); 1210 } else { 1211 llvmResultType = funcType.getResult(0); 1212 if (!isCompatibleType(llvmResultType)) 1213 return parser.emitError(trailingTypeLoc, 1214 "expected result to have LLVM type"); 1215 } 1216 1217 SmallVector<Type, 8> argTypes; 1218 argTypes.reserve(funcType.getNumInputs()); 1219 for (int i = 0, e = funcType.getNumInputs(); i < e; ++i) { 1220 auto argType = funcType.getInput(i); 1221 if (!isCompatibleType(argType)) 1222 return parser.emitError(trailingTypeLoc, 1223 "expected LLVM types as inputs"); 1224 argTypes.push_back(argType); 1225 } 1226 auto llvmFuncType = LLVM::LLVMFunctionType::get(llvmResultType, argTypes); 1227 auto wrappedFuncType = LLVM::LLVMPointerType::get(llvmFuncType); 1228 1229 auto funcArguments = 1230 ArrayRef<OpAsmParser::UnresolvedOperand>(operands).drop_front(); 1231 1232 // Make sure that the first operand (indirect callee) matches the wrapped 1233 // LLVM IR function type, and that the types of the other call operands 1234 // match the types of the function arguments. 1235 if (parser.resolveOperand(operands[0], wrappedFuncType, result.operands) || 1236 parser.resolveOperands(funcArguments, funcType.getInputs(), 1237 parser.getNameLoc(), result.operands)) 1238 return failure(); 1239 1240 if (!llvmResultType.isa<LLVM::LLVMVoidType>()) 1241 result.addTypes(llvmResultType); 1242 } 1243 1244 return success(); 1245 } 1246 1247 //===----------------------------------------------------------------------===// 1248 // Printing/parsing for LLVM::ExtractElementOp. 
1249 //===----------------------------------------------------------------------===// 1250 // Expects vector to be of wrapped LLVM vector type and position to be of 1251 // wrapped LLVM i32 type. 1252 void LLVM::ExtractElementOp::build(OpBuilder &b, OperationState &result, 1253 Value vector, Value position, 1254 ArrayRef<NamedAttribute> attrs) { 1255 auto vectorType = vector.getType(); 1256 auto llvmType = LLVM::getVectorElementType(vectorType); 1257 build(b, result, llvmType, vector, position); 1258 result.addAttributes(attrs); 1259 } 1260 1261 void ExtractElementOp::print(OpAsmPrinter &p) { 1262 p << ' ' << getVector() << "[" << getPosition() << " : " 1263 << getPosition().getType() << "]"; 1264 p.printOptionalAttrDict((*this)->getAttrs()); 1265 p << " : " << getVector().getType(); 1266 } 1267 1268 // <operation> ::= `llvm.extractelement` ssa-use `, ` ssa-use 1269 // attribute-dict? `:` type 1270 ParseResult ExtractElementOp::parse(OpAsmParser &parser, 1271 OperationState &result) { 1272 SMLoc loc; 1273 OpAsmParser::UnresolvedOperand vector, position; 1274 Type type, positionType; 1275 if (parser.getCurrentLocation(&loc) || parser.parseOperand(vector) || 1276 parser.parseLSquare() || parser.parseOperand(position) || 1277 parser.parseColonType(positionType) || parser.parseRSquare() || 1278 parser.parseOptionalAttrDict(result.attributes) || 1279 parser.parseColonType(type) || 1280 parser.resolveOperand(vector, type, result.operands) || 1281 parser.resolveOperand(position, positionType, result.operands)) 1282 return failure(); 1283 if (!LLVM::isCompatibleVectorType(type)) 1284 return parser.emitError( 1285 loc, "expected LLVM dialect-compatible vector type for operand #1"); 1286 result.addTypes(LLVM::getVectorElementType(type)); 1287 return success(); 1288 } 1289 1290 LogicalResult ExtractElementOp::verify() { 1291 Type vectorType = getVector().getType(); 1292 if (!LLVM::isCompatibleVectorType(vectorType)) 1293 return emitOpError("expected LLVM dialect-compatible 
vector type for " 1294 "operand #1, got") 1295 << vectorType; 1296 Type valueType = LLVM::getVectorElementType(vectorType); 1297 if (valueType != getRes().getType()) 1298 return emitOpError() << "Type mismatch: extracting from " << vectorType 1299 << " should produce " << valueType 1300 << " but this op returns " << getRes().getType(); 1301 return success(); 1302 } 1303 1304 //===----------------------------------------------------------------------===// 1305 // Printing/parsing for LLVM::ExtractValueOp. 1306 //===----------------------------------------------------------------------===// 1307 1308 void ExtractValueOp::print(OpAsmPrinter &p) { 1309 p << ' ' << getContainer() << getPosition(); 1310 p.printOptionalAttrDict((*this)->getAttrs(), {"position"}); 1311 p << " : " << getContainer().getType(); 1312 } 1313 1314 // Extract the type at `position` in the wrapped LLVM IR aggregate type 1315 // `containerType`. Position is an integer array attribute where each value 1316 // is a zero-based position of the element in the aggregate type. Return the 1317 // resulting type wrapped in MLIR, or nullptr on error. 1318 static Type getInsertExtractValueElementType(OpAsmParser &parser, 1319 Type containerType, 1320 ArrayAttr positionAttr, 1321 SMLoc attributeLoc, 1322 SMLoc typeLoc) { 1323 Type llvmType = containerType; 1324 if (!isCompatibleType(containerType)) 1325 return parser.emitError(typeLoc, "expected LLVM IR Dialect type"), nullptr; 1326 1327 // Infer the element type from the structure type: iteratively step inside the 1328 // type by taking the element type, indexed by the position attribute for 1329 // structures. Check the position index before accessing, it is supposed to 1330 // be in bounds. 
1331 for (Attribute subAttr : positionAttr) { 1332 auto positionElementAttr = subAttr.dyn_cast<IntegerAttr>(); 1333 if (!positionElementAttr) 1334 return parser.emitError(attributeLoc, 1335 "expected an array of integer literals"), 1336 nullptr; 1337 int position = positionElementAttr.getInt(); 1338 if (auto arrayType = llvmType.dyn_cast<LLVMArrayType>()) { 1339 if (position < 0 || 1340 static_cast<unsigned>(position) >= arrayType.getNumElements()) 1341 return parser.emitError(attributeLoc, "position out of bounds"), 1342 nullptr; 1343 llvmType = arrayType.getElementType(); 1344 } else if (auto structType = llvmType.dyn_cast<LLVMStructType>()) { 1345 if (position < 0 || 1346 static_cast<unsigned>(position) >= structType.getBody().size()) 1347 return parser.emitError(attributeLoc, "position out of bounds"), 1348 nullptr; 1349 llvmType = structType.getBody()[position]; 1350 } else { 1351 return parser.emitError(typeLoc, "expected LLVM IR structure/array type"), 1352 nullptr; 1353 } 1354 } 1355 return llvmType; 1356 } 1357 1358 // Extract the type at `position` in the wrapped LLVM IR aggregate type 1359 // `containerType`. Returns null on failure. 1360 static Type getInsertExtractValueElementType(Type containerType, 1361 ArrayAttr positionAttr, 1362 Operation *op) { 1363 Type llvmType = containerType; 1364 if (!isCompatibleType(containerType)) { 1365 op->emitError("expected LLVM IR Dialect type, got ") << containerType; 1366 return {}; 1367 } 1368 1369 // Infer the element type from the structure type: iteratively step inside the 1370 // type by taking the element type, indexed by the position attribute for 1371 // structures. Check the position index before accessing, it is supposed to 1372 // be in bounds. 
1373 for (Attribute subAttr : positionAttr) { 1374 auto positionElementAttr = subAttr.dyn_cast<IntegerAttr>(); 1375 if (!positionElementAttr) { 1376 op->emitOpError("expected an array of integer literals, got: ") 1377 << subAttr; 1378 return {}; 1379 } 1380 int position = positionElementAttr.getInt(); 1381 if (auto arrayType = llvmType.dyn_cast<LLVMArrayType>()) { 1382 if (position < 0 || 1383 static_cast<unsigned>(position) >= arrayType.getNumElements()) { 1384 op->emitOpError("position out of bounds: ") << position; 1385 return {}; 1386 } 1387 llvmType = arrayType.getElementType(); 1388 } else if (auto structType = llvmType.dyn_cast<LLVMStructType>()) { 1389 if (position < 0 || 1390 static_cast<unsigned>(position) >= structType.getBody().size()) { 1391 op->emitOpError("position out of bounds") << position; 1392 return {}; 1393 } 1394 llvmType = structType.getBody()[position]; 1395 } else { 1396 op->emitOpError("expected LLVM IR structure/array type, got: ") 1397 << llvmType; 1398 return {}; 1399 } 1400 } 1401 return llvmType; 1402 } 1403 1404 // <operation> ::= `llvm.extractvalue` ssa-use 1405 // `[` integer-literal (`,` integer-literal)* `]` 1406 // attribute-dict? 
`:` type 1407 ParseResult ExtractValueOp::parse(OpAsmParser &parser, OperationState &result) { 1408 OpAsmParser::UnresolvedOperand container; 1409 Type containerType; 1410 ArrayAttr positionAttr; 1411 SMLoc attributeLoc, trailingTypeLoc; 1412 1413 if (parser.parseOperand(container) || 1414 parser.getCurrentLocation(&attributeLoc) || 1415 parser.parseAttribute(positionAttr, "position", result.attributes) || 1416 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 1417 parser.getCurrentLocation(&trailingTypeLoc) || 1418 parser.parseType(containerType) || 1419 parser.resolveOperand(container, containerType, result.operands)) 1420 return failure(); 1421 1422 auto elementType = getInsertExtractValueElementType( 1423 parser, containerType, positionAttr, attributeLoc, trailingTypeLoc); 1424 if (!elementType) 1425 return failure(); 1426 1427 result.addTypes(elementType); 1428 return success(); 1429 } 1430 1431 OpFoldResult LLVM::ExtractValueOp::fold(ArrayRef<Attribute> operands) { 1432 auto insertValueOp = getContainer().getDefiningOp<InsertValueOp>(); 1433 OpFoldResult result = {}; 1434 while (insertValueOp) { 1435 if (getPosition() == insertValueOp.getPosition()) 1436 return insertValueOp.getValue(); 1437 unsigned min = 1438 std::min(getPosition().size(), insertValueOp.getPosition().size()); 1439 // If one is fully prefix of the other, stop propagating back as it will 1440 // miss dependencies. 
For instance, %3 should not fold to %f0 in the 1441 // following example: 1442 // ``` 1443 // %1 = llvm.insertvalue %f0, %0[0, 0] : 1444 // !llvm.array<4 x !llvm.array<4xf32>> 1445 // %2 = llvm.insertvalue %arr, %1[0] : 1446 // !llvm.array<4 x !llvm.array<4xf32>> 1447 // %3 = llvm.extractvalue %2[0, 0] : !llvm.array<4 x !llvm.array<4xf32>> 1448 // ``` 1449 if (getPosition().getValue().take_front(min) == 1450 insertValueOp.getPosition().getValue().take_front(min)) 1451 return result; 1452 1453 // If neither a prefix, nor the exact position, we can extract out of the 1454 // value being inserted into. Moreover, we can try again if that operand 1455 // is itself an insertvalue expression. 1456 getContainerMutable().assign(insertValueOp.getContainer()); 1457 result = getResult(); 1458 insertValueOp = insertValueOp.getContainer().getDefiningOp<InsertValueOp>(); 1459 } 1460 return result; 1461 } 1462 1463 LogicalResult ExtractValueOp::verify() { 1464 Type valueType = getInsertExtractValueElementType(getContainer().getType(), 1465 getPositionAttr(), *this); 1466 if (!valueType) 1467 return failure(); 1468 1469 if (getRes().getType() != valueType) 1470 return emitOpError() << "Type mismatch: extracting from " 1471 << getContainer().getType() << " should produce " 1472 << valueType << " but this op returns " 1473 << getRes().getType(); 1474 return success(); 1475 } 1476 1477 //===----------------------------------------------------------------------===// 1478 // Printing/parsing for LLVM::InsertElementOp. 1479 //===----------------------------------------------------------------------===// 1480 1481 void InsertElementOp::print(OpAsmPrinter &p) { 1482 p << ' ' << getValue() << ", " << getVector() << "[" << getPosition() << " : " 1483 << getPosition().getType() << "]"; 1484 p.printOptionalAttrDict((*this)->getAttrs()); 1485 p << " : " << getVector().getType(); 1486 } 1487 1488 // <operation> ::= `llvm.insertelement` ssa-use `,` ssa-use `,` ssa-use 1489 // attribute-dict? 
`:` type 1490 ParseResult InsertElementOp::parse(OpAsmParser &parser, 1491 OperationState &result) { 1492 SMLoc loc; 1493 OpAsmParser::UnresolvedOperand vector, value, position; 1494 Type vectorType, positionType; 1495 if (parser.getCurrentLocation(&loc) || parser.parseOperand(value) || 1496 parser.parseComma() || parser.parseOperand(vector) || 1497 parser.parseLSquare() || parser.parseOperand(position) || 1498 parser.parseColonType(positionType) || parser.parseRSquare() || 1499 parser.parseOptionalAttrDict(result.attributes) || 1500 parser.parseColonType(vectorType)) 1501 return failure(); 1502 1503 if (!LLVM::isCompatibleVectorType(vectorType)) 1504 return parser.emitError( 1505 loc, "expected LLVM dialect-compatible vector type for operand #1"); 1506 Type valueType = LLVM::getVectorElementType(vectorType); 1507 if (!valueType) 1508 return failure(); 1509 1510 if (parser.resolveOperand(vector, vectorType, result.operands) || 1511 parser.resolveOperand(value, valueType, result.operands) || 1512 parser.resolveOperand(position, positionType, result.operands)) 1513 return failure(); 1514 1515 result.addTypes(vectorType); 1516 return success(); 1517 } 1518 1519 LogicalResult InsertElementOp::verify() { 1520 Type valueType = LLVM::getVectorElementType(getVector().getType()); 1521 if (valueType != getValue().getType()) 1522 return emitOpError() << "Type mismatch: cannot insert " 1523 << getValue().getType() << " into " 1524 << getVector().getType(); 1525 return success(); 1526 } 1527 1528 //===----------------------------------------------------------------------===// 1529 // Printing/parsing for LLVM::InsertValueOp. 
1530 //===----------------------------------------------------------------------===// 1531 1532 void InsertValueOp::print(OpAsmPrinter &p) { 1533 p << ' ' << getValue() << ", " << getContainer() << getPosition(); 1534 p.printOptionalAttrDict((*this)->getAttrs(), {"position"}); 1535 p << " : " << getContainer().getType(); 1536 } 1537 1538 // <operation> ::= `llvm.insertvaluevalue` ssa-use `,` ssa-use 1539 // `[` integer-literal (`,` integer-literal)* `]` 1540 // attribute-dict? `:` type 1541 ParseResult InsertValueOp::parse(OpAsmParser &parser, OperationState &result) { 1542 OpAsmParser::UnresolvedOperand container, value; 1543 Type containerType; 1544 ArrayAttr positionAttr; 1545 SMLoc attributeLoc, trailingTypeLoc; 1546 1547 if (parser.parseOperand(value) || parser.parseComma() || 1548 parser.parseOperand(container) || 1549 parser.getCurrentLocation(&attributeLoc) || 1550 parser.parseAttribute(positionAttr, "position", result.attributes) || 1551 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 1552 parser.getCurrentLocation(&trailingTypeLoc) || 1553 parser.parseType(containerType)) 1554 return failure(); 1555 1556 auto valueType = getInsertExtractValueElementType( 1557 parser, containerType, positionAttr, attributeLoc, trailingTypeLoc); 1558 if (!valueType) 1559 return failure(); 1560 1561 if (parser.resolveOperand(container, containerType, result.operands) || 1562 parser.resolveOperand(value, valueType, result.operands)) 1563 return failure(); 1564 1565 result.addTypes(containerType); 1566 return success(); 1567 } 1568 1569 LogicalResult InsertValueOp::verify() { 1570 Type valueType = getInsertExtractValueElementType(getContainer().getType(), 1571 getPositionAttr(), *this); 1572 if (!valueType) 1573 return failure(); 1574 1575 if (getValue().getType() != valueType) 1576 return emitOpError() << "Type mismatch: cannot insert " 1577 << getValue().getType() << " into " 1578 << getContainer().getType(); 1579 1580 return success(); 1581 } 1582 
1583 //===----------------------------------------------------------------------===// 1584 // Printing, parsing and verification for LLVM::ReturnOp. 1585 //===----------------------------------------------------------------------===// 1586 1587 LogicalResult ReturnOp::verify() { 1588 if (getNumOperands() > 1) 1589 return emitOpError("expected at most 1 operand"); 1590 1591 if (auto parent = (*this)->getParentOfType<LLVMFuncOp>()) { 1592 Type expectedType = parent.getFunctionType().getReturnType(); 1593 if (expectedType.isa<LLVMVoidType>()) { 1594 if (getNumOperands() == 0) 1595 return success(); 1596 InFlightDiagnostic diag = emitOpError("expected no operands"); 1597 diag.attachNote(parent->getLoc()) << "when returning from function"; 1598 return diag; 1599 } 1600 if (getNumOperands() == 0) { 1601 if (expectedType.isa<LLVMVoidType>()) 1602 return success(); 1603 InFlightDiagnostic diag = emitOpError("expected 1 operand"); 1604 diag.attachNote(parent->getLoc()) << "when returning from function"; 1605 return diag; 1606 } 1607 if (expectedType != getOperand(0).getType()) { 1608 InFlightDiagnostic diag = emitOpError("mismatching result types"); 1609 diag.attachNote(parent->getLoc()) << "when returning from function"; 1610 return diag; 1611 } 1612 } 1613 return success(); 1614 } 1615 1616 //===----------------------------------------------------------------------===// 1617 // ResumeOp 1618 //===----------------------------------------------------------------------===// 1619 1620 LogicalResult ResumeOp::verify() { 1621 if (!getValue().getDefiningOp<LandingpadOp>()) 1622 return emitOpError("expects landingpad value as operand"); 1623 // No check for personality of function - landingpad op verifies it. 1624 return success(); 1625 } 1626 1627 //===----------------------------------------------------------------------===// 1628 // Verifier for LLVM::AddressOfOp. 
1629 //===----------------------------------------------------------------------===// 1630 1631 template <typename OpTy> 1632 static OpTy lookupSymbolInModule(Operation *parent, StringRef name) { 1633 Operation *module = parent; 1634 while (module && !satisfiesLLVMModule(module)) 1635 module = module->getParentOp(); 1636 assert(module && "unexpected operation outside of a module"); 1637 return dyn_cast_or_null<OpTy>( 1638 mlir::SymbolTable::lookupSymbolIn(module, name)); 1639 } 1640 1641 GlobalOp AddressOfOp::getGlobal() { 1642 return lookupSymbolInModule<LLVM::GlobalOp>((*this)->getParentOp(), 1643 getGlobalName()); 1644 } 1645 1646 LLVMFuncOp AddressOfOp::getFunction() { 1647 return lookupSymbolInModule<LLVM::LLVMFuncOp>((*this)->getParentOp(), 1648 getGlobalName()); 1649 } 1650 1651 LogicalResult AddressOfOp::verify() { 1652 auto global = getGlobal(); 1653 auto function = getFunction(); 1654 if (!global && !function) 1655 return emitOpError( 1656 "must reference a global defined by 'llvm.mlir.global' or 'llvm.func'"); 1657 1658 LLVMPointerType type = getType(); 1659 if (global && global.getAddrSpace() != type.getAddressSpace()) 1660 return emitOpError("pointer address space must match address space of the " 1661 "referenced global"); 1662 1663 if (type.isOpaque()) 1664 return success(); 1665 1666 if (global && type.getElementType() != global.getType()) 1667 return emitOpError( 1668 "the type must be a pointer to the type of the referenced global"); 1669 1670 if (function && type.getElementType() != function.getFunctionType()) 1671 return emitOpError( 1672 "the type must be a pointer to the type of the referenced function"); 1673 1674 return success(); 1675 } 1676 1677 //===----------------------------------------------------------------------===// 1678 // Builder, printer and verifier for LLVM::GlobalOp. 
//===----------------------------------------------------------------------===//

/// Builds an `llvm.mlir.global`. The symbol name, global type and linkage are
/// always recorded. The remaining pieces are only materialized when set:
/// - `constant`, `dso_local` and `thread_local` unit attributes when their
///   flags are true;
/// - the initializer value when `value` is non-null;
/// - alignment when non-zero;
/// - address space when non-zero.
/// `attrs` are appended after all of these, and an (empty) initializer region
/// is always added.
void GlobalOp::build(OpBuilder &builder, OperationState &result, Type type,
                     bool isConstant, Linkage linkage, StringRef name,
                     Attribute value, uint64_t alignment, unsigned addrSpace,
                     bool dsoLocal, bool threadLocal,
                     ArrayRef<NamedAttribute> attrs) {
  result.addAttribute(getSymNameAttrName(result.name),
                      builder.getStringAttr(name));
  result.addAttribute(getGlobalTypeAttrName(result.name), TypeAttr::get(type));
  if (isConstant)
    result.addAttribute(getConstantAttrName(result.name),
                        builder.getUnitAttr());
  if (value)
    result.addAttribute(getValueAttrName(result.name), value);
  if (dsoLocal)
    result.addAttribute(getDsoLocalAttrName(result.name),
                        builder.getUnitAttr());
  if (threadLocal)
    result.addAttribute(getThreadLocal_AttrName(result.name),
                        builder.getUnitAttr());

  // Only add an alignment attribute if the "alignment" input
  // is different from 0. The value must also be a power of two, but
  // this is tested in GlobalOp::verify, not here.
  if (alignment != 0)
    result.addAttribute(getAlignmentAttrName(result.name),
                        builder.getI64IntegerAttr(alignment));

  result.addAttribute(getLinkageAttrName(result.name),
                      LinkageAttr::get(builder.getContext(), linkage));
  // Address space 0 is represented by the absence of the attribute.
  if (addrSpace != 0)
    result.addAttribute(getAddrSpaceAttrName(result.name),
                        builder.getI32IntegerAttr(addrSpace));
  result.attributes.append(attrs.begin(), attrs.end());
  result.addRegion();
}

/// Prints the global as:
///   linkage unnamed_addr? `thread_local`? `constant`? @name `(` value? `)`
///   attr-dict (`:` type region?)?
/// The linkage is always printed, including the default. The trailing type
/// and the initializer region are omitted for string-valued globals, whose
/// type the parser re-derives from the string.
void GlobalOp::print(OpAsmPrinter &p) {
  p << ' ' << stringifyLinkage(getLinkage()) << ' ';
  if (auto unnamedAddr = getUnnamedAddr()) {
    StringRef str = stringifyUnnamedAddr(*unnamedAddr);
    // Values that stringify to "" have no keyword form and are skipped.
    if (!str.empty())
      p << str << ' ';
  }
  if (getThreadLocal_())
    p << "thread_local ";
  if (getConstant())
    p << "constant ";
  p.printSymbolName(getSymName());
  p << '(';
  if (auto value = getValueOrNull())
    p.printAttribute(value);
  p << ')';
  // Note that the alignment attribute is printed using the
  // default syntax here, even though it is an inherent attribute
  // (as defined in https://mlir.llvm.org/docs/LangRef/#attributes)
  p.printOptionalAttrDict(
      (*this)->getAttrs(),
      {SymbolTable::getSymbolAttrName(), getGlobalTypeAttrName(),
       getConstantAttrName(), getValueAttrName(), getLinkageAttrName(),
       getUnnamedAddrAttrName(), getThreadLocal_AttrName()});

  // Print the trailing type unless it's a string global.
  if (getValueOrNull().dyn_cast_or_null<StringAttr>())
    return;
  p << " : " << getType();

  Region &initializer = getInitializerRegion();
  if (!initializer.empty()) {
    p << ' ';
    p.printRegion(initializer, /*printEntryBlockArgs=*/false);
  }
}

// Parses one of the keywords provided in the list `keywords` and returns the
// position of the parsed keyword in the list. If none of the keywords from the
// list is parsed, returns -1.
static int parseOptionalKeywordAlternative(OpAsmParser &parser,
                                           ArrayRef<StringRef> keywords) {
  // Candidates are tried in list order; the first match is consumed.
  for (const auto &en : llvm::enumerate(keywords)) {
    if (succeeded(parser.parseOptionalKeyword(en.value())))
      return en.index();
  }
  return -1;
}

namespace {
// Maps an enum type to its generated `stringify<Enum>` and
// `getMaxEnumValFor<Enum>` helpers so that parseOptionalLLVMKeyword can be
// written once for every enum registered below.
template <typename Ty>
struct EnumTraits {};

#define REGISTER_ENUM_TYPE(Ty)                                                 \
  template <>                                                                  \
  struct EnumTraits<Ty> {                                                      \
    static StringRef stringify(Ty value) { return stringify##Ty(value); }      \
    static unsigned getMaxEnumVal() { return getMaxEnumValFor##Ty(); }         \
  }

REGISTER_ENUM_TYPE(Linkage);
REGISTER_ENUM_TYPE(UnnamedAddr);
} // namespace

/// Parse an enum from the keyword, or default to the provided default value.
/// The return type is the enum type by default, unless overriden with the
/// second template argument.
///
/// The candidate keywords are the stringified values 0..getMaxEnumVal(), and
/// the matched keyword's position in that list is cast back to the enum. This
/// relies on the enum's values being dense from 0. `result` is not used.
template <typename EnumTy, typename RetTy = EnumTy>
static RetTy parseOptionalLLVMKeyword(OpAsmParser &parser,
                                      OperationState &result,
                                      EnumTy defaultValue) {
  SmallVector<StringRef, 10> names;
  for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i)
    names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i)));

  int index = parseOptionalKeywordAlternative(parser, names);
  if (index == -1)
    return static_cast<RetTy>(defaultValue);
  return static_cast<RetTy>(index);
}

// operation ::= `llvm.mlir.global` linkage? `thread_local`? unnamed-addr?
//               `constant`? `@` identifier `(` attribute? `)`
//               attribute-list? (`:` type region?)?
//
// There is no dedicated `align` keyword: an alignment, if any, is carried in
// the attribute list.
//
// The type can be omitted for string attributes, in which case it will be
// inferred from the value of the string as [strlen(value) x i8]. An
// initializer region is only parsed when an explicit type is given.
ParseResult GlobalOp::parse(OpAsmParser &parser, OperationState &result) {
  MLIRContext *ctx = parser.getContext();
  // Parse optional linkage, default to External.
  result.addAttribute(getLinkageAttrName(result.name),
                      LLVM::LinkageAttr::get(
                          ctx, parseOptionalLLVMKeyword<Linkage>(
                                   parser, result, LLVM::Linkage::External)));

  if (succeeded(parser.parseOptionalKeyword("thread_local")))
    result.addAttribute(getThreadLocal_AttrName(result.name),
                        parser.getBuilder().getUnitAttr());

  // Parse optional UnnamedAddr, default to None.
  result.addAttribute(getUnnamedAddrAttrName(result.name),
                      parser.getBuilder().getI64IntegerAttr(
                          parseOptionalLLVMKeyword<UnnamedAddr, int64_t>(
                              parser, result, LLVM::UnnamedAddr::None)));

  if (succeeded(parser.parseOptionalKeyword("constant")))
    result.addAttribute(getConstantAttrName(result.name),
                        parser.getBuilder().getUnitAttr());

  StringAttr name;
  if (parser.parseSymbolName(name, getSymNameAttrName(result.name),
                             result.attributes) ||
      parser.parseLParen())
    return failure();

  // Unless the parentheses are empty, an initializer value attribute follows.
  Attribute value;
  if (parser.parseOptionalRParen()) {
    if (parser.parseAttribute(value, getValueAttrName(result.name),
                              result.attributes) ||
        parser.parseRParen())
      return failure();
  }

  SmallVector<Type, 1> types;
  if (parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseOptionalColonTypeList(types))
    return failure();

  if (types.size() > 1)
    return parser.emitError(parser.getNameLoc(), "expected zero or one type");

  Region &initRegion = *result.addRegion();
  if (types.empty()) {
    // No explicit type: only allowed for strings, typed as an i8 array.
    if (auto strAttr = value.dyn_cast_or_null<StringAttr>()) {
      MLIRContext *context = parser.getContext();
      auto arrayType = LLVM::LLVMArrayType::get(IntegerType::get(context, 8),
                                                strAttr.getValue().size());
      types.push_back(arrayType);
    } else {
      return parser.emitError(parser.getNameLoc(),
                              "type can only be omitted for string globals");
    }
  } else {
    OptionalParseResult parseResult =
        parser.parseOptionalRegion(initRegion, /*arguments=*/{},
                                   /*argTypes=*/{});
    if (parseResult.hasValue() && failed(*parseResult))
      return failure();
  }

  result.addAttribute(getGlobalTypeAttrName(result.name),
                      TypeAttr::get(types[0]));
  return success();
}

// Returns true if `value` is an integer or float zero, or a splat/elements/
// array attribute all of whose elements are (recursively) zero. Any other
// attribute kind is reported as non-zero.
static bool isZeroAttribute(Attribute value) {
  if (auto intValue = value.dyn_cast<IntegerAttr>())
    return intValue.getValue().isNullValue();
  if (auto fpValue = value.dyn_cast<FloatAttr>())
    return fpValue.getValue().isZero();
  if (auto splatValue = value.dyn_cast<SplatElementsAttr>())
    return isZeroAttribute(splatValue.getSplatValue<Attribute>());
  if (auto elementsValue = value.dyn_cast<ElementsAttr>())
    return llvm::all_of(elementsValue.getValues<Attribute>(), isZeroAttribute);
  if (auto arrayValue = value.dyn_cast<ArrayAttr>())
    return llvm::all_of(arrayValue.getValue(), isZeroAttribute);
  return false;
}

// Verifies the global:
// - its type must be a valid LLVM pointer element type;
// - if it has a parent, the parent must satisfy `satisfiesLLVMModule`;
// - a string value requires an i8 array type of the string's length;
// - 'common' linkage requires a zero (or absent) value;
// - 'appending' linkage requires an array type;
// - an alignment, if present, must be a power of two.
LogicalResult GlobalOp::verify() {
  if (!LLVMPointerType::isValidElementType(getType()))
    return emitOpError(
        "expects type to be a valid element type for an LLVM pointer");
  if ((*this)->getParentOp() && !satisfiesLLVMModule((*this)->getParentOp()))
    return emitOpError("must appear at the module level");

  if (auto strAttr = getValueOrNull().dyn_cast_or_null<StringAttr>()) {
    auto type = getType().dyn_cast<LLVMArrayType>();
    IntegerType elementType =
        type ? type.getElementType().dyn_cast<IntegerType>() : nullptr;
    if (!elementType || elementType.getWidth() != 8 ||
        type.getNumElements() != strAttr.getValue().size())
      return emitOpError(
          "requires an i8 array type of the length equal to that of the string "
          "attribute");
  }

  if (getLinkage() == Linkage::Common) {
    if (Attribute value = getValueOrNull()) {
      if (!isZeroAttribute(value)) {
        return emitOpError()
               << "expected zero value for '"
               << stringifyLinkage(Linkage::Common) << "' linkage";
      }
    }
  }

  if (getLinkage() == Linkage::Appending) {
    if (!getType().isa<LLVMArrayType>()) {
      return emitOpError() << "expected array type for '"
                           << stringifyLinkage(Linkage::Appending)
                           << "' linkage";
    }
  }

  // NOTE(review): unlike the checks above, this uses emitError rather than
  // emitOpError, so the message lacks the op-name prefix.
  Optional<uint64_t> alignAttr = getAlignment();
  if (alignAttr.hasValue()) {
    uint64_t value = alignAttr.getValue();
    if (!llvm::isPowerOf2_64(value))
      return emitError() << "alignment attribute is not a power of 2";
  }

  return success();
}

// Verifies the initializer region, if there is one. Its terminator must
// return exactly one value whose type is the global's type. Every op in the
// block must implement MemoryEffectOpInterface and report no effects. The
// region may not coexist with an initializer value attribute.
LogicalResult GlobalOp::verifyRegions() {
  if (Block *b = getInitializerBlock()) {
    // `cast` asserts if the terminator is not a ReturnOp; presumably the op's
    // region constraints guarantee it - TODO confirm against the ODS def.
    ReturnOp ret = cast<ReturnOp>(b->getTerminator());
    if (ret.operand_type_begin() == ret.operand_type_end())
      return emitOpError("initializer region cannot return void");
    if (*ret.operand_type_begin() != getType())
      return emitOpError("initializer region type ")
             << *ret.operand_type_begin() << " does not match global type "
             << getType();

    for (Operation &op : *b) {
      auto iface = dyn_cast<MemoryEffectOpInterface>(op);
      if (!iface || !iface.hasNoEffect())
        return op.emitError()
               << "ops with side effects not allowed in global initializers";
    }

    if (getValueOrNull())
      return emitOpError("cannot have both initializer value and region");
  }

  return success();
}

//===----------------------------------------------------------------------===// 1957 // LLVM::GlobalCtorsOp 1958 //===----------------------------------------------------------------------===// 1959 1960 LogicalResult 1961 GlobalCtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 1962 for (Attribute ctor : getCtors()) { 1963 if (failed(verifySymbolAttrUse(ctor.cast<FlatSymbolRefAttr>(), *this, 1964 symbolTable))) 1965 return failure(); 1966 } 1967 return success(); 1968 } 1969 1970 LogicalResult GlobalCtorsOp::verify() { 1971 if (getCtors().size() != getPriorities().size()) 1972 return emitError( 1973 "mismatch between the number of ctors and the number of priorities"); 1974 return success(); 1975 } 1976 1977 //===----------------------------------------------------------------------===// 1978 // LLVM::GlobalDtorsOp 1979 //===----------------------------------------------------------------------===// 1980 1981 LogicalResult 1982 GlobalDtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 1983 for (Attribute dtor : getDtors()) { 1984 if (failed(verifySymbolAttrUse(dtor.cast<FlatSymbolRefAttr>(), *this, 1985 symbolTable))) 1986 return failure(); 1987 } 1988 return success(); 1989 } 1990 1991 LogicalResult GlobalDtorsOp::verify() { 1992 if (getDtors().size() != getPriorities().size()) 1993 return emitError( 1994 "mismatch between the number of dtors and the number of priorities"); 1995 return success(); 1996 } 1997 1998 //===----------------------------------------------------------------------===// 1999 // Printing/parsing for LLVM::ShuffleVectorOp. 2000 //===----------------------------------------------------------------------===// 2001 // Expects vector to be of wrapped LLVM vector type and position to be of 2002 // wrapped LLVM i32 type. 
2003 void LLVM::ShuffleVectorOp::build(OpBuilder &b, OperationState &result, 2004 Value v1, Value v2, ArrayAttr mask, 2005 ArrayRef<NamedAttribute> attrs) { 2006 auto containerType = v1.getType(); 2007 auto vType = LLVM::getVectorType( 2008 LLVM::getVectorElementType(containerType), mask.size(), 2009 containerType.cast<VectorType>().isScalable()); 2010 build(b, result, vType, v1, v2, mask); 2011 result.addAttributes(attrs); 2012 } 2013 2014 void ShuffleVectorOp::print(OpAsmPrinter &p) { 2015 p << ' ' << getV1() << ", " << getV2() << " " << getMask(); 2016 p.printOptionalAttrDict((*this)->getAttrs(), {"mask"}); 2017 p << " : " << getV1().getType() << ", " << getV2().getType(); 2018 } 2019 2020 // <operation> ::= `llvm.shufflevector` ssa-use `, ` ssa-use 2021 // `[` integer-literal (`,` integer-literal)* `]` 2022 // attribute-dict? `:` type 2023 ParseResult ShuffleVectorOp::parse(OpAsmParser &parser, 2024 OperationState &result) { 2025 SMLoc loc; 2026 OpAsmParser::UnresolvedOperand v1, v2; 2027 ArrayAttr maskAttr; 2028 Type typeV1, typeV2; 2029 if (parser.getCurrentLocation(&loc) || parser.parseOperand(v1) || 2030 parser.parseComma() || parser.parseOperand(v2) || 2031 parser.parseAttribute(maskAttr, "mask", result.attributes) || 2032 parser.parseOptionalAttrDict(result.attributes) || 2033 parser.parseColonType(typeV1) || parser.parseComma() || 2034 parser.parseType(typeV2) || 2035 parser.resolveOperand(v1, typeV1, result.operands) || 2036 parser.resolveOperand(v2, typeV2, result.operands)) 2037 return failure(); 2038 if (!LLVM::isCompatibleVectorType(typeV1)) 2039 return parser.emitError( 2040 loc, "expected LLVM IR dialect vector type for operand #1"); 2041 auto vType = 2042 LLVM::getVectorType(LLVM::getVectorElementType(typeV1), maskAttr.size(), 2043 typeV1.cast<VectorType>().isScalable()); 2044 result.addTypes(vType); 2045 return success(); 2046 } 2047 2048 LogicalResult ShuffleVectorOp::verify() { 2049 Type type1 = getV1().getType(); 2050 Type type2 = 
getV2().getType(); 2051 if (LLVM::getVectorElementType(type1) != LLVM::getVectorElementType(type2)) 2052 return emitOpError("expected matching LLVM IR Dialect element types"); 2053 if (LLVM::isScalableVectorType(type1)) 2054 if (llvm::any_of(getMask(), [](Attribute attr) { 2055 return attr.cast<IntegerAttr>().getInt() != 0; 2056 })) 2057 return emitOpError("expected a splat operation for scalable vectors"); 2058 return success(); 2059 } 2060 2061 //===----------------------------------------------------------------------===// 2062 // Implementations for LLVM::LLVMFuncOp. 2063 //===----------------------------------------------------------------------===// 2064 2065 // Add the entry block to the function. 2066 Block *LLVMFuncOp::addEntryBlock() { 2067 assert(empty() && "function already has an entry block"); 2068 assert(!isVarArg() && "unimplemented: non-external variadic functions"); 2069 2070 auto *entry = new Block; 2071 push_back(entry); 2072 2073 // FIXME: Allow passing in proper locations for the entry arguments. 
2074 LLVMFunctionType type = getFunctionType(); 2075 for (unsigned i = 0, e = type.getNumParams(); i < e; ++i) 2076 entry->addArgument(type.getParamType(i), getLoc()); 2077 return entry; 2078 } 2079 2080 void LLVMFuncOp::build(OpBuilder &builder, OperationState &result, 2081 StringRef name, Type type, LLVM::Linkage linkage, 2082 bool dsoLocal, ArrayRef<NamedAttribute> attrs, 2083 ArrayRef<DictionaryAttr> argAttrs) { 2084 result.addRegion(); 2085 result.addAttribute(SymbolTable::getSymbolAttrName(), 2086 builder.getStringAttr(name)); 2087 result.addAttribute(getFunctionTypeAttrName(result.name), 2088 TypeAttr::get(type)); 2089 result.addAttribute(getLinkageAttrName(result.name), 2090 LinkageAttr::get(builder.getContext(), linkage)); 2091 result.attributes.append(attrs.begin(), attrs.end()); 2092 if (dsoLocal) 2093 result.addAttribute("dso_local", builder.getUnitAttr()); 2094 if (argAttrs.empty()) 2095 return; 2096 2097 assert(type.cast<LLVMFunctionType>().getNumParams() == argAttrs.size() && 2098 "expected as many argument attribute lists as arguments"); 2099 function_interface_impl::addArgAndResultAttrs(builder, result, argAttrs, 2100 /*resultAttrs=*/llvm::None); 2101 } 2102 2103 // Builds an LLVM function type from the given lists of input and output types. 2104 // Returns a null type if any of the types provided are non-LLVM types, or if 2105 // there is more than one output type. 2106 static Type 2107 buildLLVMFunctionType(OpAsmParser &parser, SMLoc loc, ArrayRef<Type> inputs, 2108 ArrayRef<Type> outputs, 2109 function_interface_impl::VariadicFlag variadicFlag) { 2110 Builder &b = parser.getBuilder(); 2111 if (outputs.size() > 1) { 2112 parser.emitError(loc, "failed to construct function type: expected zero or " 2113 "one function result"); 2114 return {}; 2115 } 2116 2117 // Convert inputs to LLVM types, exit early on error. 
2118 SmallVector<Type, 4> llvmInputs; 2119 for (auto t : inputs) { 2120 if (!isCompatibleType(t)) { 2121 parser.emitError(loc, "failed to construct function type: expected LLVM " 2122 "type for function arguments"); 2123 return {}; 2124 } 2125 llvmInputs.push_back(t); 2126 } 2127 2128 // No output is denoted as "void" in LLVM type system. 2129 Type llvmOutput = 2130 outputs.empty() ? LLVMVoidType::get(b.getContext()) : outputs.front(); 2131 if (!isCompatibleType(llvmOutput)) { 2132 parser.emitError(loc, "failed to construct function type: expected LLVM " 2133 "type for function results") 2134 << llvmOutput; 2135 return {}; 2136 } 2137 return LLVMFunctionType::get(llvmOutput, llvmInputs, 2138 variadicFlag.isVariadic()); 2139 } 2140 2141 // Parses an LLVM function. 2142 // 2143 // operation ::= `llvm.func` linkage? function-signature function-attributes? 2144 // function-body 2145 // 2146 ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) { 2147 // Default to external linkage if no keyword is provided. 
2148 result.addAttribute( 2149 getLinkageAttrName(result.name), 2150 LinkageAttr::get(parser.getContext(), 2151 parseOptionalLLVMKeyword<Linkage>( 2152 parser, result, LLVM::Linkage::External))); 2153 2154 StringAttr nameAttr; 2155 SmallVector<OpAsmParser::Argument> entryArgs; 2156 SmallVector<DictionaryAttr> resultAttrs; 2157 SmallVector<Type> resultTypes; 2158 bool isVariadic; 2159 2160 auto signatureLocation = parser.getCurrentLocation(); 2161 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), 2162 result.attributes) || 2163 function_interface_impl::parseFunctionSignature( 2164 parser, /*allowVariadic=*/true, entryArgs, isVariadic, resultTypes, 2165 resultAttrs)) 2166 return failure(); 2167 2168 SmallVector<Type> argTypes; 2169 for (auto &arg : entryArgs) 2170 argTypes.push_back(arg.type); 2171 auto type = 2172 buildLLVMFunctionType(parser, signatureLocation, argTypes, resultTypes, 2173 function_interface_impl::VariadicFlag(isVariadic)); 2174 if (!type) 2175 return failure(); 2176 result.addAttribute(FunctionOpInterface::getTypeAttrName(), 2177 TypeAttr::get(type)); 2178 2179 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) 2180 return failure(); 2181 function_interface_impl::addArgAndResultAttrs(parser.getBuilder(), result, 2182 entryArgs, resultAttrs); 2183 2184 auto *body = result.addRegion(); 2185 OptionalParseResult parseResult = 2186 parser.parseOptionalRegion(*body, entryArgs); 2187 return failure(parseResult.hasValue() && failed(*parseResult)); 2188 } 2189 2190 // Print the LLVMFuncOp. Collects argument and result types and passes them to 2191 // helper functions. Drops "void" result since it cannot be parsed back. Skips 2192 // the external linkage since it is the default value. 
2193 void LLVMFuncOp::print(OpAsmPrinter &p) { 2194 p << ' '; 2195 if (getLinkage() != LLVM::Linkage::External) 2196 p << stringifyLinkage(getLinkage()) << ' '; 2197 p.printSymbolName(getName()); 2198 2199 LLVMFunctionType fnType = getFunctionType(); 2200 SmallVector<Type, 8> argTypes; 2201 SmallVector<Type, 1> resTypes; 2202 argTypes.reserve(fnType.getNumParams()); 2203 for (unsigned i = 0, e = fnType.getNumParams(); i < e; ++i) 2204 argTypes.push_back(fnType.getParamType(i)); 2205 2206 Type returnType = fnType.getReturnType(); 2207 if (!returnType.isa<LLVMVoidType>()) 2208 resTypes.push_back(returnType); 2209 2210 function_interface_impl::printFunctionSignature(p, *this, argTypes, 2211 isVarArg(), resTypes); 2212 function_interface_impl::printFunctionAttributes( 2213 p, *this, argTypes.size(), resTypes.size(), {getLinkageAttrName()}); 2214 2215 // Print the body if this is not an external function. 2216 Region &body = getBody(); 2217 if (!body.empty()) { 2218 p << ' '; 2219 p.printRegion(body, /*printEntryBlockArgs=*/false, 2220 /*printBlockTerminators=*/true); 2221 } 2222 } 2223 2224 // Verifies LLVM- and implementation-specific properties of the LLVM func Op: 2225 // - functions don't have 'common' linkage 2226 // - external functions have 'external' or 'extern_weak' linkage; 2227 // - vararg is (currently) only supported for external functions; 2228 LogicalResult LLVMFuncOp::verify() { 2229 if (getLinkage() == LLVM::Linkage::Common) 2230 return emitOpError() << "functions cannot have '" 2231 << stringifyLinkage(LLVM::Linkage::Common) 2232 << "' linkage"; 2233 2234 // Check to see if this function has a void return with a result attribute to 2235 // it. It isn't clear what semantics we would assign to that. 
2236 if (getFunctionType().getReturnType().isa<LLVMVoidType>() && 2237 !getResultAttrs(0).empty()) { 2238 return emitOpError() 2239 << "cannot attach result attributes to functions with a void return"; 2240 } 2241 2242 if (isExternal()) { 2243 if (getLinkage() != LLVM::Linkage::External && 2244 getLinkage() != LLVM::Linkage::ExternWeak) 2245 return emitOpError() << "external functions must have '" 2246 << stringifyLinkage(LLVM::Linkage::External) 2247 << "' or '" 2248 << stringifyLinkage(LLVM::Linkage::ExternWeak) 2249 << "' linkage"; 2250 return success(); 2251 } 2252 2253 if (isVarArg()) 2254 return emitOpError("only external functions can be variadic"); 2255 2256 return success(); 2257 } 2258 2259 /// Verifies LLVM- and implementation-specific properties of the LLVM func Op: 2260 /// - entry block arguments are of LLVM types. 2261 LogicalResult LLVMFuncOp::verifyRegions() { 2262 if (isExternal()) 2263 return success(); 2264 2265 unsigned numArguments = getFunctionType().getNumParams(); 2266 Block &entryBlock = front(); 2267 for (unsigned i = 0; i < numArguments; ++i) { 2268 Type argType = entryBlock.getArgument(i).getType(); 2269 if (!isCompatibleType(argType)) 2270 return emitOpError("entry block argument #") 2271 << i << " is not of LLVM type"; 2272 } 2273 2274 return success(); 2275 } 2276 2277 //===----------------------------------------------------------------------===// 2278 // Verification for LLVM::ConstantOp. 
2279 //===----------------------------------------------------------------------===// 2280 2281 LogicalResult LLVM::ConstantOp::verify() { 2282 if (StringAttr sAttr = getValue().dyn_cast<StringAttr>()) { 2283 auto arrayType = getType().dyn_cast<LLVMArrayType>(); 2284 if (!arrayType || arrayType.getNumElements() != sAttr.getValue().size() || 2285 !arrayType.getElementType().isInteger(8)) { 2286 return emitOpError() << "expected array type of " 2287 << sAttr.getValue().size() 2288 << " i8 elements for the string constant"; 2289 } 2290 return success(); 2291 } 2292 if (auto structType = getType().dyn_cast<LLVMStructType>()) { 2293 if (structType.getBody().size() != 2 || 2294 structType.getBody()[0] != structType.getBody()[1]) { 2295 return emitError() << "expected struct type with two elements of the " 2296 "same type, the type of a complex constant"; 2297 } 2298 2299 auto arrayAttr = getValue().dyn_cast<ArrayAttr>(); 2300 if (!arrayAttr || arrayAttr.size() != 2 || 2301 arrayAttr[0].getType() != arrayAttr[1].getType()) { 2302 return emitOpError() << "expected array attribute with two elements, " 2303 "representing a complex constant"; 2304 } 2305 2306 Type elementType = structType.getBody()[0]; 2307 if (!elementType 2308 .isa<IntegerType, Float16Type, Float32Type, Float64Type>()) { 2309 return emitError() 2310 << "expected struct element types to be floating point type or " 2311 "integer type"; 2312 } 2313 return success(); 2314 } 2315 if (!getValue().isa<IntegerAttr, ArrayAttr, FloatAttr, ElementsAttr>()) 2316 return emitOpError() 2317 << "only supports integer, float, string or elements attributes"; 2318 return success(); 2319 } 2320 2321 // Constant op constant-folds to its value. 
2322 OpFoldResult LLVM::ConstantOp::fold(ArrayRef<Attribute>) { return getValue(); } 2323 2324 //===----------------------------------------------------------------------===// 2325 // Utility functions for parsing atomic ops 2326 //===----------------------------------------------------------------------===// 2327 2328 // Helper function to parse a keyword into the specified attribute named by 2329 // `attrName`. The keyword must match one of the string values defined by the 2330 // AtomicBinOp enum. The resulting I64 attribute is added to the `result` 2331 // state. 2332 static ParseResult parseAtomicBinOp(OpAsmParser &parser, OperationState &result, 2333 StringRef attrName) { 2334 SMLoc loc; 2335 StringRef keyword; 2336 if (parser.getCurrentLocation(&loc) || parser.parseKeyword(&keyword)) 2337 return failure(); 2338 2339 // Replace the keyword `keyword` with an integer attribute. 2340 auto kind = symbolizeAtomicBinOp(keyword); 2341 if (!kind) { 2342 return parser.emitError(loc) 2343 << "'" << keyword << "' is an incorrect value of the '" << attrName 2344 << "' attribute"; 2345 } 2346 2347 auto value = static_cast<int64_t>(kind.getValue()); 2348 auto attr = parser.getBuilder().getI64IntegerAttr(value); 2349 result.addAttribute(attrName, attr); 2350 2351 return success(); 2352 } 2353 2354 // Helper function to parse a keyword into the specified attribute named by 2355 // `attrName`. The keyword must match one of the string values defined by the 2356 // AtomicOrdering enum. The resulting I64 attribute is added to the `result` 2357 // state. 2358 static ParseResult parseAtomicOrdering(OpAsmParser &parser, 2359 OperationState &result, 2360 StringRef attrName) { 2361 SMLoc loc; 2362 StringRef ordering; 2363 if (parser.getCurrentLocation(&loc) || parser.parseKeyword(&ordering)) 2364 return failure(); 2365 2366 // Replace the keyword `ordering` with an integer attribute. 
2367 auto kind = symbolizeAtomicOrdering(ordering); 2368 if (!kind) { 2369 return parser.emitError(loc) 2370 << "'" << ordering << "' is an incorrect value of the '" << attrName 2371 << "' attribute"; 2372 } 2373 2374 auto value = static_cast<int64_t>(kind.getValue()); 2375 auto attr = parser.getBuilder().getI64IntegerAttr(value); 2376 result.addAttribute(attrName, attr); 2377 2378 return success(); 2379 } 2380 2381 //===----------------------------------------------------------------------===// 2382 // Printer, parser and verifier for LLVM::AtomicRMWOp. 2383 //===----------------------------------------------------------------------===// 2384 2385 void AtomicRMWOp::print(OpAsmPrinter &p) { 2386 p << ' ' << stringifyAtomicBinOp(getBinOp()) << ' ' << getPtr() << ", " 2387 << getVal() << ' ' << stringifyAtomicOrdering(getOrdering()) << ' '; 2388 p.printOptionalAttrDict((*this)->getAttrs(), {"bin_op", "ordering"}); 2389 p << " : " << getRes().getType(); 2390 } 2391 2392 // <operation> ::= `llvm.atomicrmw` keyword ssa-use `,` ssa-use keyword 2393 // attribute-dict? 
`:` type 2394 ParseResult AtomicRMWOp::parse(OpAsmParser &parser, OperationState &result) { 2395 Type type; 2396 OpAsmParser::UnresolvedOperand ptr, val; 2397 if (parseAtomicBinOp(parser, result, "bin_op") || parser.parseOperand(ptr) || 2398 parser.parseComma() || parser.parseOperand(val) || 2399 parseAtomicOrdering(parser, result, "ordering") || 2400 parser.parseOptionalAttrDict(result.attributes) || 2401 parser.parseColonType(type) || 2402 parser.resolveOperand(ptr, LLVM::LLVMPointerType::get(type), 2403 result.operands) || 2404 parser.resolveOperand(val, type, result.operands)) 2405 return failure(); 2406 2407 result.addTypes(type); 2408 return success(); 2409 } 2410 2411 LogicalResult AtomicRMWOp::verify() { 2412 auto ptrType = getPtr().getType().cast<LLVM::LLVMPointerType>(); 2413 auto valType = getVal().getType(); 2414 if (valType != ptrType.getElementType()) 2415 return emitOpError("expected LLVM IR element type for operand #0 to " 2416 "match type for operand #1"); 2417 auto resType = getRes().getType(); 2418 if (resType != valType) 2419 return emitOpError( 2420 "expected LLVM IR result type to match type for operand #1"); 2421 if (getBinOp() == AtomicBinOp::fadd || getBinOp() == AtomicBinOp::fsub) { 2422 if (!mlir::LLVM::isCompatibleFloatingPointType(valType)) 2423 return emitOpError("expected LLVM IR floating point type"); 2424 } else if (getBinOp() == AtomicBinOp::xchg) { 2425 auto intType = valType.dyn_cast<IntegerType>(); 2426 unsigned intBitWidth = intType ? intType.getWidth() : 0; 2427 if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 && 2428 intBitWidth != 64 && !valType.isa<BFloat16Type>() && 2429 !valType.isa<Float16Type>() && !valType.isa<Float32Type>() && 2430 !valType.isa<Float64Type>()) 2431 return emitOpError("unexpected LLVM IR type for 'xchg' bin_op"); 2432 } else { 2433 auto intType = valType.dyn_cast<IntegerType>(); 2434 unsigned intBitWidth = intType ? 
intType.getWidth() : 0; 2435 if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 && 2436 intBitWidth != 64) 2437 return emitOpError("expected LLVM IR integer type"); 2438 } 2439 2440 if (static_cast<unsigned>(getOrdering()) < 2441 static_cast<unsigned>(AtomicOrdering::monotonic)) 2442 return emitOpError() << "expected at least '" 2443 << stringifyAtomicOrdering(AtomicOrdering::monotonic) 2444 << "' ordering"; 2445 2446 return success(); 2447 } 2448 2449 //===----------------------------------------------------------------------===// 2450 // Printer, parser and verifier for LLVM::AtomicCmpXchgOp. 2451 //===----------------------------------------------------------------------===// 2452 2453 void AtomicCmpXchgOp::print(OpAsmPrinter &p) { 2454 p << ' ' << getPtr() << ", " << getCmp() << ", " << getVal() << ' ' 2455 << stringifyAtomicOrdering(getSuccessOrdering()) << ' ' 2456 << stringifyAtomicOrdering(getFailureOrdering()); 2457 p.printOptionalAttrDict((*this)->getAttrs(), 2458 {"success_ordering", "failure_ordering"}); 2459 p << " : " << getVal().getType(); 2460 } 2461 2462 // <operation> ::= `llvm.cmpxchg` ssa-use `,` ssa-use `,` ssa-use 2463 // keyword keyword attribute-dict? 
`:` type 2464 ParseResult AtomicCmpXchgOp::parse(OpAsmParser &parser, 2465 OperationState &result) { 2466 auto &builder = parser.getBuilder(); 2467 Type type; 2468 OpAsmParser::UnresolvedOperand ptr, cmp, val; 2469 if (parser.parseOperand(ptr) || parser.parseComma() || 2470 parser.parseOperand(cmp) || parser.parseComma() || 2471 parser.parseOperand(val) || 2472 parseAtomicOrdering(parser, result, "success_ordering") || 2473 parseAtomicOrdering(parser, result, "failure_ordering") || 2474 parser.parseOptionalAttrDict(result.attributes) || 2475 parser.parseColonType(type) || 2476 parser.resolveOperand(ptr, LLVM::LLVMPointerType::get(type), 2477 result.operands) || 2478 parser.resolveOperand(cmp, type, result.operands) || 2479 parser.resolveOperand(val, type, result.operands)) 2480 return failure(); 2481 2482 auto boolType = IntegerType::get(builder.getContext(), 1); 2483 auto resultType = 2484 LLVMStructType::getLiteral(builder.getContext(), {type, boolType}); 2485 result.addTypes(resultType); 2486 2487 return success(); 2488 } 2489 2490 LogicalResult AtomicCmpXchgOp::verify() { 2491 auto ptrType = getPtr().getType().cast<LLVM::LLVMPointerType>(); 2492 if (!ptrType) 2493 return emitOpError("expected LLVM IR pointer type for operand #0"); 2494 auto cmpType = getCmp().getType(); 2495 auto valType = getVal().getType(); 2496 if (cmpType != ptrType.getElementType() || cmpType != valType) 2497 return emitOpError("expected LLVM IR element type for operand #0 to " 2498 "match type for all other operands"); 2499 auto intType = valType.dyn_cast<IntegerType>(); 2500 unsigned intBitWidth = intType ? 
intType.getWidth() : 0; 2501 if (!valType.isa<LLVMPointerType>() && intBitWidth != 8 && 2502 intBitWidth != 16 && intBitWidth != 32 && intBitWidth != 64 && 2503 !valType.isa<BFloat16Type>() && !valType.isa<Float16Type>() && 2504 !valType.isa<Float32Type>() && !valType.isa<Float64Type>()) 2505 return emitOpError("unexpected LLVM IR type"); 2506 if (getSuccessOrdering() < AtomicOrdering::monotonic || 2507 getFailureOrdering() < AtomicOrdering::monotonic) 2508 return emitOpError("ordering must be at least 'monotonic'"); 2509 if (getFailureOrdering() == AtomicOrdering::release || 2510 getFailureOrdering() == AtomicOrdering::acq_rel) 2511 return emitOpError("failure ordering cannot be 'release' or 'acq_rel'"); 2512 return success(); 2513 } 2514 2515 //===----------------------------------------------------------------------===// 2516 // Printer, parser and verifier for LLVM::FenceOp. 2517 //===----------------------------------------------------------------------===// 2518 2519 // <operation> ::= `llvm.fence` (`syncscope(`strAttr`)`)? keyword 2520 // attribute-dict? 
2521 ParseResult FenceOp::parse(OpAsmParser &parser, OperationState &result) { 2522 StringAttr sScope; 2523 StringRef syncscopeKeyword = "syncscope"; 2524 if (!failed(parser.parseOptionalKeyword(syncscopeKeyword))) { 2525 if (parser.parseLParen() || 2526 parser.parseAttribute(sScope, syncscopeKeyword, result.attributes) || 2527 parser.parseRParen()) 2528 return failure(); 2529 } else { 2530 result.addAttribute(syncscopeKeyword, 2531 parser.getBuilder().getStringAttr("")); 2532 } 2533 if (parseAtomicOrdering(parser, result, "ordering") || 2534 parser.parseOptionalAttrDict(result.attributes)) 2535 return failure(); 2536 return success(); 2537 } 2538 2539 void FenceOp::print(OpAsmPrinter &p) { 2540 StringRef syncscopeKeyword = "syncscope"; 2541 p << ' '; 2542 if (!(*this)->getAttr(syncscopeKeyword).cast<StringAttr>().getValue().empty()) 2543 p << "syncscope(" << (*this)->getAttr(syncscopeKeyword) << ") "; 2544 p << stringifyAtomicOrdering(getOrdering()); 2545 } 2546 2547 LogicalResult FenceOp::verify() { 2548 if (getOrdering() == AtomicOrdering::not_atomic || 2549 getOrdering() == AtomicOrdering::unordered || 2550 getOrdering() == AtomicOrdering::monotonic) 2551 return emitOpError("can be given only acquire, release, acq_rel, " 2552 "and seq_cst orderings"); 2553 return success(); 2554 } 2555 2556 //===----------------------------------------------------------------------===// 2557 // Folder for LLVM::BitcastOp 2558 //===----------------------------------------------------------------------===// 2559 2560 OpFoldResult LLVM::BitcastOp::fold(ArrayRef<Attribute> operands) { 2561 // bitcast(x : T0, T0) -> x 2562 if (getArg().getType() == getType()) 2563 return getArg(); 2564 // bitcast(bitcast(x : T0, T1), T0) -> x 2565 if (auto prev = getArg().getDefiningOp<BitcastOp>()) 2566 if (prev.getArg().getType() == getType()) 2567 return prev.getArg(); 2568 return {}; 2569 } 2570 2571 //===----------------------------------------------------------------------===// 2572 // Folder 
for LLVM::AddrSpaceCastOp 2573 //===----------------------------------------------------------------------===// 2574 2575 OpFoldResult LLVM::AddrSpaceCastOp::fold(ArrayRef<Attribute> operands) { 2576 // addrcast(x : T0, T0) -> x 2577 if (getArg().getType() == getType()) 2578 return getArg(); 2579 // addrcast(addrcast(x : T0, T1), T0) -> x 2580 if (auto prev = getArg().getDefiningOp<AddrSpaceCastOp>()) 2581 if (prev.getArg().getType() == getType()) 2582 return prev.getArg(); 2583 return {}; 2584 } 2585 2586 //===----------------------------------------------------------------------===// 2587 // Folder for LLVM::GEPOp 2588 //===----------------------------------------------------------------------===// 2589 2590 OpFoldResult LLVM::GEPOp::fold(ArrayRef<Attribute> operands) { 2591 // gep %x:T, 0 -> %x 2592 if (getBase().getType() == getType() && getIndices().size() == 1 && 2593 matchPattern(getIndices()[0], m_Zero())) 2594 return getBase(); 2595 return {}; 2596 } 2597 2598 //===----------------------------------------------------------------------===// 2599 // LLVMDialect initialization, type parsing, and registration. 2600 //===----------------------------------------------------------------------===// 2601 2602 void LLVMDialect::initialize() { 2603 addAttributes<FMFAttr, LinkageAttr, LoopOptionsAttr>(); 2604 2605 // clang-format off 2606 addTypes<LLVMVoidType, 2607 LLVMPPCFP128Type, 2608 LLVMX86MMXType, 2609 LLVMTokenType, 2610 LLVMLabelType, 2611 LLVMMetadataType, 2612 LLVMFunctionType, 2613 LLVMPointerType, 2614 LLVMFixedVectorType, 2615 LLVMScalableVectorType, 2616 LLVMArrayType, 2617 LLVMStructType>(); 2618 // clang-format on 2619 addOperations< 2620 #define GET_OP_LIST 2621 #include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc" 2622 , 2623 #define GET_OP_LIST 2624 #include "mlir/Dialect/LLVMIR/LLVMIntrinsicOps.cpp.inc" 2625 >(); 2626 2627 // Support unknown operations because not all LLVM operations are registered. 
2628 allowUnknownOperations(); 2629 } 2630 2631 #define GET_OP_CLASSES 2632 #include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc" 2633 2634 /// Parse a type registered to this dialect. 2635 Type LLVMDialect::parseType(DialectAsmParser &parser) const { 2636 return detail::parseType(parser); 2637 } 2638 2639 /// Print a type registered to this dialect. 2640 void LLVMDialect::printType(Type type, DialectAsmPrinter &os) const { 2641 return detail::printType(type, os); 2642 } 2643 2644 LogicalResult LLVMDialect::verifyDataLayoutString( 2645 StringRef descr, llvm::function_ref<void(const Twine &)> reportError) { 2646 llvm::Expected<llvm::DataLayout> maybeDataLayout = 2647 llvm::DataLayout::parse(descr); 2648 if (maybeDataLayout) 2649 return success(); 2650 2651 std::string message; 2652 llvm::raw_string_ostream messageStream(message); 2653 llvm::logAllUnhandledErrors(maybeDataLayout.takeError(), messageStream); 2654 reportError("invalid data layout descriptor: " + messageStream.str()); 2655 return failure(); 2656 } 2657 2658 /// Verify LLVM dialect attributes. 2659 LogicalResult LLVMDialect::verifyOperationAttribute(Operation *op, 2660 NamedAttribute attr) { 2661 // If the `llvm.loop` attribute is present, enforce the following structure, 2662 // which the module translation can assume. 
2663 if (attr.getName() == LLVMDialect::getLoopAttrName()) { 2664 auto loopAttr = attr.getValue().dyn_cast<DictionaryAttr>(); 2665 if (!loopAttr) 2666 return op->emitOpError() << "expected '" << LLVMDialect::getLoopAttrName() 2667 << "' to be a dictionary attribute"; 2668 Optional<NamedAttribute> parallelAccessGroup = 2669 loopAttr.getNamed(LLVMDialect::getParallelAccessAttrName()); 2670 if (parallelAccessGroup.hasValue()) { 2671 auto accessGroups = parallelAccessGroup->getValue().dyn_cast<ArrayAttr>(); 2672 if (!accessGroups) 2673 return op->emitOpError() 2674 << "expected '" << LLVMDialect::getParallelAccessAttrName() 2675 << "' to be an array attribute"; 2676 for (Attribute attr : accessGroups) { 2677 auto accessGroupRef = attr.dyn_cast<SymbolRefAttr>(); 2678 if (!accessGroupRef) 2679 return op->emitOpError() 2680 << "expected '" << attr << "' to be a symbol reference"; 2681 StringAttr metadataName = accessGroupRef.getRootReference(); 2682 auto metadataOp = 2683 SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>( 2684 op->getParentOp(), metadataName); 2685 if (!metadataOp) 2686 return op->emitOpError() 2687 << "expected '" << attr << "' to reference a metadata op"; 2688 StringAttr accessGroupName = accessGroupRef.getLeafReference(); 2689 Operation *accessGroupOp = 2690 SymbolTable::lookupNearestSymbolFrom(metadataOp, accessGroupName); 2691 if (!accessGroupOp) 2692 return op->emitOpError() 2693 << "expected '" << attr << "' to reference an access_group op"; 2694 } 2695 } 2696 2697 Optional<NamedAttribute> loopOptions = 2698 loopAttr.getNamed(LLVMDialect::getLoopOptionsAttrName()); 2699 if (loopOptions.hasValue() && 2700 !loopOptions->getValue().isa<LoopOptionsAttr>()) 2701 return op->emitOpError() 2702 << "expected '" << LLVMDialect::getLoopOptionsAttrName() 2703 << "' to be a `loopopts` attribute"; 2704 } 2705 2706 if (attr.getName() == LLVMDialect::getStructAttrsAttrName()) { 2707 return op->emitOpError() 2708 << "'" << 
LLVM::LLVMDialect::getStructAttrsAttrName() 2709 << "' is permitted only in argument or result attributes"; 2710 } 2711 2712 // If the data layout attribute is present, it must use the LLVM data layout 2713 // syntax. Try parsing it and report errors in case of failure. Users of this 2714 // attribute may assume it is well-formed and can pass it to the (asserting) 2715 // llvm::DataLayout constructor. 2716 if (attr.getName() != LLVM::LLVMDialect::getDataLayoutAttrName()) 2717 return success(); 2718 if (auto stringAttr = attr.getValue().dyn_cast<StringAttr>()) 2719 return verifyDataLayoutString( 2720 stringAttr.getValue(), 2721 [op](const Twine &message) { op->emitOpError() << message.str(); }); 2722 2723 return op->emitOpError() << "expected '" 2724 << LLVM::LLVMDialect::getDataLayoutAttrName() 2725 << "' to be a string attributes"; 2726 } 2727 2728 LogicalResult LLVMDialect::verifyStructAttr(Operation *op, Attribute attr, 2729 Type annotatedType) { 2730 auto structType = annotatedType.dyn_cast<LLVMStructType>(); 2731 if (!structType) { 2732 const auto emitIncorrectAnnotatedType = [&op]() { 2733 return op->emitError() 2734 << "expected '" << LLVMDialect::getStructAttrsAttrName() 2735 << "' to annotate '!llvm.struct' or '!llvm.ptr<struct<...>>'"; 2736 }; 2737 const auto ptrType = annotatedType.dyn_cast<LLVMPointerType>(); 2738 if (!ptrType) 2739 return emitIncorrectAnnotatedType(); 2740 structType = ptrType.getElementType().dyn_cast<LLVMStructType>(); 2741 if (!structType) 2742 return emitIncorrectAnnotatedType(); 2743 } 2744 2745 const auto arrAttrs = attr.dyn_cast<ArrayAttr>(); 2746 if (!arrAttrs) 2747 return op->emitError() << "expected '" 2748 << LLVMDialect::getStructAttrsAttrName() 2749 << "' to be an array attribute"; 2750 2751 if (structType.getBody().size() != arrAttrs.size()) 2752 return op->emitError() 2753 << "size of '" << LLVMDialect::getStructAttrsAttrName() 2754 << "' must match the size of the annotated '!llvm.struct'"; 2755 return success(); 2756 } 
2757 2758 static LogicalResult verifyFuncOpInterfaceStructAttr( 2759 Operation *op, Attribute attr, 2760 const std::function<Type(FunctionOpInterface)> &getAnnotatedType) { 2761 if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) 2762 return LLVMDialect::verifyStructAttr(op, attr, getAnnotatedType(funcOp)); 2763 return op->emitError() << "expected '" 2764 << LLVMDialect::getStructAttrsAttrName() 2765 << "' to be used on function-like operations"; 2766 } 2767 2768 /// Verify LLVMIR function argument attributes. 2769 LogicalResult LLVMDialect::verifyRegionArgAttribute(Operation *op, 2770 unsigned regionIdx, 2771 unsigned argIdx, 2772 NamedAttribute argAttr) { 2773 // Check that llvm.noalias is a unit attribute. 2774 if (argAttr.getName() == LLVMDialect::getNoAliasAttrName() && 2775 !argAttr.getValue().isa<UnitAttr>()) 2776 return op->emitError() 2777 << "expected llvm.noalias argument attribute to be a unit attribute"; 2778 // Check that llvm.align is an integer attribute. 2779 if (argAttr.getName() == LLVMDialect::getAlignAttrName() && 2780 !argAttr.getValue().isa<IntegerAttr>()) 2781 return op->emitError() 2782 << "llvm.align argument attribute of non integer type"; 2783 if (argAttr.getName() == LLVMDialect::getStructAttrsAttrName()) { 2784 return verifyFuncOpInterfaceStructAttr( 2785 op, argAttr.getValue(), [argIdx](FunctionOpInterface funcOp) { 2786 return funcOp.getArgumentTypes()[argIdx]; 2787 }); 2788 } 2789 return success(); 2790 } 2791 2792 LogicalResult LLVMDialect::verifyRegionResultAttribute(Operation *op, 2793 unsigned regionIdx, 2794 unsigned resIdx, 2795 NamedAttribute resAttr) { 2796 if (resAttr.getName() == LLVMDialect::getStructAttrsAttrName()) { 2797 return verifyFuncOpInterfaceStructAttr( 2798 op, resAttr.getValue(), [resIdx](FunctionOpInterface funcOp) { 2799 return funcOp.getResultTypes()[resIdx]; 2800 }); 2801 } 2802 return success(); 2803 } 2804 2805 //===----------------------------------------------------------------------===// 2806 // 
Utility functions. 2807 //===----------------------------------------------------------------------===// 2808 2809 Value mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder, 2810 StringRef name, StringRef value, 2811 LLVM::Linkage linkage) { 2812 assert(builder.getInsertionBlock() && 2813 builder.getInsertionBlock()->getParentOp() && 2814 "expected builder to point to a block constrained in an op"); 2815 auto module = 2816 builder.getInsertionBlock()->getParentOp()->getParentOfType<ModuleOp>(); 2817 assert(module && "builder points to an op outside of a module"); 2818 2819 // Create the global at the entry of the module. 2820 OpBuilder moduleBuilder(module.getBodyRegion(), builder.getListener()); 2821 MLIRContext *ctx = builder.getContext(); 2822 auto type = LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), value.size()); 2823 auto global = moduleBuilder.create<LLVM::GlobalOp>( 2824 loc, type, /*isConstant=*/true, linkage, name, 2825 builder.getStringAttr(value), /*alignment=*/0); 2826 2827 // Get the pointer to the first character in the global string. 
2828 Value globalPtr = builder.create<LLVM::AddressOfOp>(loc, global); 2829 Value cst0 = builder.create<LLVM::ConstantOp>( 2830 loc, IntegerType::get(ctx, 64), 2831 builder.getIntegerAttr(builder.getIndexType(), 0)); 2832 return builder.create<LLVM::GEPOp>( 2833 loc, LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)), globalPtr, 2834 ValueRange{cst0, cst0}); 2835 } 2836 2837 bool mlir::LLVM::satisfiesLLVMModule(Operation *op) { 2838 return op->hasTrait<OpTrait::SymbolTable>() && 2839 op->hasTrait<OpTrait::IsIsolatedFromAbove>(); 2840 } 2841 2842 static constexpr const FastmathFlags fastmathFlagsList[] = { 2843 // clang-format off 2844 FastmathFlags::nnan, 2845 FastmathFlags::ninf, 2846 FastmathFlags::nsz, 2847 FastmathFlags::arcp, 2848 FastmathFlags::contract, 2849 FastmathFlags::afn, 2850 FastmathFlags::reassoc, 2851 FastmathFlags::fast, 2852 // clang-format on 2853 }; 2854 2855 void FMFAttr::print(AsmPrinter &printer) const { 2856 printer << "<"; 2857 auto flags = llvm::make_filter_range(fastmathFlagsList, [&](auto flag) { 2858 return bitEnumContains(this->getFlags(), flag); 2859 }); 2860 llvm::interleaveComma(flags, printer, 2861 [&](auto flag) { printer << stringifyEnum(flag); }); 2862 printer << ">"; 2863 } 2864 2865 Attribute FMFAttr::parse(AsmParser &parser, Type type) { 2866 if (failed(parser.parseLess())) 2867 return {}; 2868 2869 FastmathFlags flags = {}; 2870 if (failed(parser.parseOptionalGreater())) { 2871 do { 2872 StringRef elemName; 2873 if (failed(parser.parseKeyword(&elemName))) 2874 return {}; 2875 2876 auto elem = symbolizeFastmathFlags(elemName); 2877 if (!elem) { 2878 parser.emitError(parser.getNameLoc(), "Unknown fastmath flag: ") 2879 << elemName; 2880 return {}; 2881 } 2882 2883 flags = flags | *elem; 2884 } while (succeeded(parser.parseOptionalComma())); 2885 2886 if (failed(parser.parseGreater())) 2887 return {}; 2888 } 2889 2890 return FMFAttr::get(parser.getContext(), flags); 2891 } 2892 2893 void LinkageAttr::print(AsmPrinter 
&printer) const { 2894 printer << "<"; 2895 if (static_cast<uint64_t>(getLinkage()) <= getMaxEnumValForLinkage()) 2896 printer << stringifyEnum(getLinkage()); 2897 else 2898 printer << static_cast<uint64_t>(getLinkage()); 2899 printer << ">"; 2900 } 2901 2902 Attribute LinkageAttr::parse(AsmParser &parser, Type type) { 2903 StringRef elemName; 2904 if (parser.parseLess() || parser.parseKeyword(&elemName) || 2905 parser.parseGreater()) 2906 return {}; 2907 auto elem = linkage::symbolizeLinkage(elemName); 2908 if (!elem) { 2909 parser.emitError(parser.getNameLoc(), "Unknown linkage: ") << elemName; 2910 return {}; 2911 } 2912 Linkage linkage = *elem; 2913 return LinkageAttr::get(parser.getContext(), linkage); 2914 } 2915 2916 LoopOptionsAttrBuilder::LoopOptionsAttrBuilder(LoopOptionsAttr attr) 2917 : options(attr.getOptions().begin(), attr.getOptions().end()) {} 2918 2919 template <typename T> 2920 LoopOptionsAttrBuilder &LoopOptionsAttrBuilder::setOption(LoopOptionCase tag, 2921 Optional<T> value) { 2922 auto option = llvm::find_if( 2923 options, [tag](auto option) { return option.first == tag; }); 2924 if (option != options.end()) { 2925 if (value.hasValue()) 2926 option->second = *value; 2927 else 2928 options.erase(option); 2929 } else { 2930 options.push_back(LoopOptionsAttr::OptionValuePair(tag, *value)); 2931 } 2932 return *this; 2933 } 2934 2935 LoopOptionsAttrBuilder & 2936 LoopOptionsAttrBuilder::setDisableLICM(Optional<bool> value) { 2937 return setOption(LoopOptionCase::disable_licm, value); 2938 } 2939 2940 /// Set the `interleave_count` option to the provided value. If no value 2941 /// is provided the option is deleted. 2942 LoopOptionsAttrBuilder & 2943 LoopOptionsAttrBuilder::setInterleaveCount(Optional<uint64_t> count) { 2944 return setOption(LoopOptionCase::interleave_count, count); 2945 } 2946 2947 /// Set the `disable_unroll` option to the provided value. If no value 2948 /// is provided the option is deleted. 
2949 LoopOptionsAttrBuilder & 2950 LoopOptionsAttrBuilder::setDisableUnroll(Optional<bool> value) { 2951 return setOption(LoopOptionCase::disable_unroll, value); 2952 } 2953 2954 /// Set the `disable_pipeline` option to the provided value. If no value 2955 /// is provided the option is deleted. 2956 LoopOptionsAttrBuilder & 2957 LoopOptionsAttrBuilder::setDisablePipeline(Optional<bool> value) { 2958 return setOption(LoopOptionCase::disable_pipeline, value); 2959 } 2960 2961 /// Set the `pipeline_initiation_interval` option to the provided value. 2962 /// If no value is provided the option is deleted. 2963 LoopOptionsAttrBuilder &LoopOptionsAttrBuilder::setPipelineInitiationInterval( 2964 Optional<uint64_t> count) { 2965 return setOption(LoopOptionCase::pipeline_initiation_interval, count); 2966 } 2967 2968 template <typename T> 2969 static Optional<T> 2970 getOption(ArrayRef<std::pair<LoopOptionCase, int64_t>> options, 2971 LoopOptionCase option) { 2972 auto it = 2973 lower_bound(options, option, [](auto optionPair, LoopOptionCase option) { 2974 return optionPair.first < option; 2975 }); 2976 if (it == options.end()) 2977 return {}; 2978 return static_cast<T>(it->second); 2979 } 2980 2981 Optional<bool> LoopOptionsAttr::disableUnroll() { 2982 return getOption<bool>(getOptions(), LoopOptionCase::disable_unroll); 2983 } 2984 2985 Optional<bool> LoopOptionsAttr::disableLICM() { 2986 return getOption<bool>(getOptions(), LoopOptionCase::disable_licm); 2987 } 2988 2989 Optional<int64_t> LoopOptionsAttr::interleaveCount() { 2990 return getOption<int64_t>(getOptions(), LoopOptionCase::interleave_count); 2991 } 2992 2993 /// Build the LoopOptions Attribute from a sorted array of individual options. 
2994 LoopOptionsAttr LoopOptionsAttr::get( 2995 MLIRContext *context, 2996 ArrayRef<std::pair<LoopOptionCase, int64_t>> sortedOptions) { 2997 assert(llvm::is_sorted(sortedOptions, llvm::less_first()) && 2998 "LoopOptionsAttr ctor expects a sorted options array"); 2999 return Base::get(context, sortedOptions); 3000 } 3001 3002 /// Build the LoopOptions Attribute from a sorted array of individual options. 3003 LoopOptionsAttr LoopOptionsAttr::get(MLIRContext *context, 3004 LoopOptionsAttrBuilder &optionBuilders) { 3005 llvm::sort(optionBuilders.options, llvm::less_first()); 3006 return Base::get(context, optionBuilders.options); 3007 } 3008 3009 void LoopOptionsAttr::print(AsmPrinter &printer) const { 3010 printer << "<"; 3011 llvm::interleaveComma(getOptions(), printer, [&](auto option) { 3012 printer << stringifyEnum(option.first) << " = "; 3013 switch (option.first) { 3014 case LoopOptionCase::disable_licm: 3015 case LoopOptionCase::disable_unroll: 3016 case LoopOptionCase::disable_pipeline: 3017 printer << (option.second ? 
"true" : "false"); 3018 break; 3019 case LoopOptionCase::interleave_count: 3020 case LoopOptionCase::pipeline_initiation_interval: 3021 printer << option.second; 3022 break; 3023 } 3024 }); 3025 printer << ">"; 3026 } 3027 3028 Attribute LoopOptionsAttr::parse(AsmParser &parser, Type type) { 3029 if (failed(parser.parseLess())) 3030 return {}; 3031 3032 SmallVector<std::pair<LoopOptionCase, int64_t>> options; 3033 llvm::SmallDenseSet<LoopOptionCase> seenOptions; 3034 do { 3035 StringRef optionName; 3036 if (parser.parseKeyword(&optionName)) 3037 return {}; 3038 3039 auto option = symbolizeLoopOptionCase(optionName); 3040 if (!option) { 3041 parser.emitError(parser.getNameLoc(), "unknown loop option: ") 3042 << optionName; 3043 return {}; 3044 } 3045 if (!seenOptions.insert(*option).second) { 3046 parser.emitError(parser.getNameLoc(), "loop option present twice"); 3047 return {}; 3048 } 3049 if (failed(parser.parseEqual())) 3050 return {}; 3051 3052 int64_t value; 3053 switch (*option) { 3054 case LoopOptionCase::disable_licm: 3055 case LoopOptionCase::disable_unroll: 3056 case LoopOptionCase::disable_pipeline: 3057 if (succeeded(parser.parseOptionalKeyword("true"))) 3058 value = 1; 3059 else if (succeeded(parser.parseOptionalKeyword("false"))) 3060 value = 0; 3061 else { 3062 parser.emitError(parser.getNameLoc(), 3063 "expected boolean value 'true' or 'false'"); 3064 return {}; 3065 } 3066 break; 3067 case LoopOptionCase::interleave_count: 3068 case LoopOptionCase::pipeline_initiation_interval: 3069 if (failed(parser.parseInteger(value))) { 3070 parser.emitError(parser.getNameLoc(), "expected integer value"); 3071 return {}; 3072 } 3073 break; 3074 } 3075 options.push_back(std::make_pair(*option, value)); 3076 } while (succeeded(parser.parseOptionalComma())); 3077 if (failed(parser.parseGreater())) 3078 return {}; 3079 3080 llvm::sort(options, llvm::less_first()); 3081 return get(parser.getContext(), options); 3082 } 3083