1 //===- LLVMDialect.cpp - LLVM IR Ops and Dialect registration -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the types and operation details for the LLVM IR dialect in 10 // MLIR, and the LLVM IR dialect. It also registers the dialect. 11 // 12 //===----------------------------------------------------------------------===// 13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 14 #include "TypeDetail.h" 15 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 16 #include "mlir/IR/Builders.h" 17 #include "mlir/IR/BuiltinOps.h" 18 #include "mlir/IR/BuiltinTypes.h" 19 #include "mlir/IR/DialectImplementation.h" 20 #include "mlir/IR/FunctionImplementation.h" 21 #include "mlir/IR/MLIRContext.h" 22 #include "mlir/IR/Matchers.h" 23 24 #include "llvm/ADT/StringSwitch.h" 25 #include "llvm/ADT/TypeSwitch.h" 26 #include "llvm/AsmParser/Parser.h" 27 #include "llvm/Bitcode/BitcodeReader.h" 28 #include "llvm/Bitcode/BitcodeWriter.h" 29 #include "llvm/IR/Attributes.h" 30 #include "llvm/IR/Function.h" 31 #include "llvm/IR/Type.h" 32 #include "llvm/Support/Mutex.h" 33 #include "llvm/Support/SourceMgr.h" 34 35 #include <numeric> 36 37 using namespace mlir; 38 using namespace mlir::LLVM; 39 using mlir::LLVM::linkage::getMaxEnumValForLinkage; 40 41 #include "mlir/Dialect/LLVMIR/LLVMOpsDialect.cpp.inc" 42 43 static constexpr const char kVolatileAttrName[] = "volatile_"; 44 static constexpr const char kNonTemporalAttrName[] = "nontemporal"; 45 46 #include "mlir/Dialect/LLVMIR/LLVMOpsEnums.cpp.inc" 47 #include "mlir/Dialect/LLVMIR/LLVMOpsInterfaces.cpp.inc" 48 #define GET_ATTRDEF_CLASSES 49 #include "mlir/Dialect/LLVMIR/LLVMOpsAttrDefs.cpp.inc" 50 51 static auto processFMFAttr(ArrayRef<NamedAttribute> attrs) { 52 SmallVector<NamedAttribute, 8> filteredAttrs( 53 llvm::make_filter_range(attrs, [&](NamedAttribute attr) { 54 if (attr.getName() == "fastmathFlags") { 55 auto defAttr = FMFAttr::get(attr.getValue().getContext(), {}); 56 return defAttr != attr.getValue(); 57 } 58 return true; 59 })); 60 return filteredAttrs; 61 } 62 63 static ParseResult parseLLVMOpAttrs(OpAsmParser &parser, 64 NamedAttrList &result) { 65 return parser.parseOptionalAttrDict(result); 66 } 67 68 static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op, 69 DictionaryAttr attrs) { 70 printer.printOptionalAttrDict(processFMFAttr(attrs.getValue())); 71 } 72 73 /// Verifies `symbol`'s use in `op` to ensure the symbol is a valid and 74 /// fully defined llvm.func. 75 static LogicalResult verifySymbolAttrUse(FlatSymbolRefAttr symbol, 76 Operation *op, 77 SymbolTableCollection &symbolTable) { 78 StringRef name = symbol.getValue(); 79 auto func = 80 symbolTable.lookupNearestSymbolFrom<LLVMFuncOp>(op, symbol.getAttr()); 81 if (!func) 82 return op->emitOpError("'") 83 << name << "' does not reference a valid LLVM function"; 84 if (func.isExternal()) 85 return op->emitOpError("'") << name << "' does not have a definition"; 86 return success(); 87 } 88 89 //===----------------------------------------------------------------------===// 90 // Printing/parsing for LLVM::CmpOp. 91 //===----------------------------------------------------------------------===// 92 93 void ICmpOp::print(OpAsmPrinter &p) { 94 p << " \"" << stringifyICmpPredicate(getPredicate()) << "\" " << getOperand(0) 95 << ", " << getOperand(1); 96 p.printOptionalAttrDict((*this)->getAttrs(), {"predicate"}); 97 p << " : " << getLhs().getType(); 98 } 99 100 void FCmpOp::print(OpAsmPrinter &p) { 101 p << " \"" << stringifyFCmpPredicate(getPredicate()) << "\" " << getOperand(0) 102 << ", " << getOperand(1); 103 p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()), {"predicate"}); 104 p << " : " << getLhs().getType(); 105 } 106 107 // <operation> ::= `llvm.icmp` string-literal ssa-use `,` ssa-use 108 // attribute-dict? `:` type 109 // <operation> ::= `llvm.fcmp` string-literal ssa-use `,` ssa-use 110 // attribute-dict? `:` type 111 template <typename CmpPredicateType> 112 static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) { 113 Builder &builder = parser.getBuilder(); 114 115 StringAttr predicateAttr; 116 OpAsmParser::OperandType lhs, rhs; 117 Type type; 118 SMLoc predicateLoc, trailingTypeLoc; 119 if (parser.getCurrentLocation(&predicateLoc) || 120 parser.parseAttribute(predicateAttr, "predicate", result.attributes) || 121 parser.parseOperand(lhs) || parser.parseComma() || 122 parser.parseOperand(rhs) || 123 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 124 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) || 125 parser.resolveOperand(lhs, type, result.operands) || 126 parser.resolveOperand(rhs, type, result.operands)) 127 return failure(); 128 129 // Replace the string attribute `predicate` with an integer attribute. 130 int64_t predicateValue = 0; 131 if (std::is_same<CmpPredicateType, ICmpPredicate>()) { 132 Optional<ICmpPredicate> predicate = 133 symbolizeICmpPredicate(predicateAttr.getValue()); 134 if (!predicate) 135 return parser.emitError(predicateLoc) 136 << "'" << predicateAttr.getValue() 137 << "' is an incorrect value of the 'predicate' attribute"; 138 predicateValue = static_cast<int64_t>(predicate.getValue()); 139 } else { 140 Optional<FCmpPredicate> predicate = 141 symbolizeFCmpPredicate(predicateAttr.getValue()); 142 if (!predicate) 143 return parser.emitError(predicateLoc) 144 << "'" << predicateAttr.getValue() 145 << "' is an incorrect value of the 'predicate' attribute"; 146 predicateValue = static_cast<int64_t>(predicate.getValue()); 147 } 148 149 result.attributes.set("predicate", 150 parser.getBuilder().getI64IntegerAttr(predicateValue)); 151 152 // The result type is either i1 or a vector type <? x i1> if the inputs are 153 // vectors. 154 Type resultType = IntegerType::get(builder.getContext(), 1); 155 if (!isCompatibleType(type)) 156 return parser.emitError(trailingTypeLoc, 157 "expected LLVM dialect-compatible type"); 158 if (LLVM::isCompatibleVectorType(type)) { 159 if (LLVM::isScalableVectorType(type)) { 160 resultType = LLVM::getVectorType( 161 resultType, LLVM::getVectorNumElements(type).getKnownMinValue(), 162 /*isScalable=*/true); 163 } else { 164 resultType = LLVM::getVectorType( 165 resultType, LLVM::getVectorNumElements(type).getFixedValue(), 166 /*isScalable=*/false); 167 } 168 } 169 170 result.addTypes({resultType}); 171 return success(); 172 } 173 174 ParseResult ICmpOp::parse(OpAsmParser &parser, OperationState &result) { 175 return parseCmpOp<ICmpPredicate>(parser, result); 176 } 177 178 ParseResult FCmpOp::parse(OpAsmParser &parser, OperationState &result) { 179 return parseCmpOp<FCmpPredicate>(parser, result); 180 } 181 182 //===----------------------------------------------------------------------===// 183 // Printing/parsing for LLVM::AllocaOp. 184 //===----------------------------------------------------------------------===// 185 186 void AllocaOp::print(OpAsmPrinter &p) { 187 auto elemTy = getType().cast<LLVM::LLVMPointerType>().getElementType(); 188 189 auto funcTy = 190 FunctionType::get(getContext(), {getArraySize().getType()}, {getType()}); 191 192 p << ' ' << getArraySize() << " x " << elemTy; 193 if (getAlignment().hasValue() && *getAlignment() != 0) 194 p.printOptionalAttrDict((*this)->getAttrs()); 195 else 196 p.printOptionalAttrDict((*this)->getAttrs(), {"alignment"}); 197 p << " : " << funcTy; 198 } 199 200 // <operation> ::= `llvm.alloca` ssa-use `x` type attribute-dict? 201 // `:` type `,` type 202 ParseResult AllocaOp::parse(OpAsmParser &parser, OperationState &result) { 203 OpAsmParser::OperandType arraySize; 204 Type type, elemType; 205 SMLoc trailingTypeLoc; 206 if (parser.parseOperand(arraySize) || parser.parseKeyword("x") || 207 parser.parseType(elemType) || 208 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 209 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type)) 210 return failure(); 211 212 Optional<NamedAttribute> alignmentAttr = 213 result.attributes.getNamed("alignment"); 214 if (alignmentAttr.hasValue()) { 215 auto alignmentInt = 216 alignmentAttr.getValue().getValue().dyn_cast<IntegerAttr>(); 217 if (!alignmentInt) 218 return parser.emitError(parser.getNameLoc(), 219 "expected integer alignment"); 220 if (alignmentInt.getValue().isNullValue()) 221 result.attributes.erase("alignment"); 222 } 223 224 // Extract the result type from the trailing function type. 225 auto funcType = type.dyn_cast<FunctionType>(); 226 if (!funcType || funcType.getNumInputs() != 1 || 227 funcType.getNumResults() != 1) 228 return parser.emitError( 229 trailingTypeLoc, 230 "expected trailing function type with one argument and one result"); 231 232 if (parser.resolveOperand(arraySize, funcType.getInput(0), result.operands)) 233 return failure(); 234 235 result.addTypes({funcType.getResult(0)}); 236 return success(); 237 } 238 239 //===----------------------------------------------------------------------===// 240 // LLVM::BrOp 241 //===----------------------------------------------------------------------===// 242 243 Optional<MutableOperandRange> 244 BrOp::getMutableSuccessorOperands(unsigned index) { 245 assert(index == 0 && "invalid successor index"); 246 return getDestOperandsMutable(); 247 } 248 249 //===----------------------------------------------------------------------===// 250 // LLVM::CondBrOp 251 //===----------------------------------------------------------------------===// 252 253 Optional<MutableOperandRange> 254 CondBrOp::getMutableSuccessorOperands(unsigned index) { 255 assert(index < getNumSuccessors() && "invalid successor index"); 256 return index == 0 ? getTrueDestOperandsMutable() 257 : getFalseDestOperandsMutable(); 258 } 259 260 //===----------------------------------------------------------------------===// 261 // LLVM::SwitchOp 262 //===----------------------------------------------------------------------===// 263 264 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value, 265 Block *defaultDestination, ValueRange defaultOperands, 266 ArrayRef<int32_t> caseValues, BlockRange caseDestinations, 267 ArrayRef<ValueRange> caseOperands, 268 ArrayRef<int32_t> branchWeights) { 269 ElementsAttr caseValuesAttr; 270 if (!caseValues.empty()) 271 caseValuesAttr = builder.getI32VectorAttr(caseValues); 272 273 ElementsAttr weightsAttr; 274 if (!branchWeights.empty()) 275 weightsAttr = builder.getI32VectorAttr(llvm::to_vector<4>(branchWeights)); 276 277 build(builder, result, value, defaultOperands, caseOperands, caseValuesAttr, 278 weightsAttr, defaultDestination, caseDestinations); 279 } 280 281 /// <cases> ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)? 282 /// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )? 283 static ParseResult parseSwitchOpCases( 284 OpAsmParser &parser, Type flagType, ElementsAttr &caseValues, 285 SmallVectorImpl<Block *> &caseDestinations, 286 SmallVectorImpl<SmallVector<OpAsmParser::OperandType>> &caseOperands, 287 SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) { 288 SmallVector<APInt> values; 289 unsigned bitWidth = flagType.getIntOrFloatBitWidth(); 290 do { 291 int64_t value = 0; 292 OptionalParseResult integerParseResult = parser.parseOptionalInteger(value); 293 if (values.empty() && !integerParseResult.hasValue()) 294 return success(); 295 296 if (!integerParseResult.hasValue() || integerParseResult.getValue()) 297 return failure(); 298 values.push_back(APInt(bitWidth, value)); 299 300 Block *destination; 301 SmallVector<OpAsmParser::OperandType> operands; 302 SmallVector<Type> operandTypes; 303 if (parser.parseColon() || parser.parseSuccessor(destination)) 304 return failure(); 305 if (!parser.parseOptionalLParen()) { 306 if (parser.parseRegionArgumentList(operands) || 307 parser.parseColonTypeList(operandTypes) || parser.parseRParen()) 308 return failure(); 309 } 310 caseDestinations.push_back(destination); 311 caseOperands.emplace_back(operands); 312 caseOperandTypes.emplace_back(operandTypes); 313 } while (!parser.parseOptionalComma()); 314 315 ShapedType caseValueType = 316 VectorType::get(static_cast<int64_t>(values.size()), flagType); 317 caseValues = DenseIntElementsAttr::get(caseValueType, values); 318 return success(); 319 } 320 321 static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type flagType, 322 ElementsAttr caseValues, 323 SuccessorRange caseDestinations, 324 OperandRangeRange caseOperands, 325 const TypeRangeRange &caseOperandTypes) { 326 if (!caseValues) 327 return; 328 329 size_t index = 0; 330 llvm::interleave( 331 llvm::zip(caseValues.cast<DenseIntElementsAttr>(), caseDestinations), 332 [&](auto i) { 333 p << " "; 334 p << std::get<0>(i).getLimitedValue(); 335 p << ": "; 336 p.printSuccessorAndUseList(std::get<1>(i), caseOperands[index++]); 337 }, 338 [&] { 339 p << ','; 340 p.printNewline(); 341 }); 342 p.printNewline(); 343 } 344 345 LogicalResult SwitchOp::verify() { 346 if ((!getCaseValues() && !getCaseDestinations().empty()) || 347 (getCaseValues() && 348 getCaseValues()->size() != 349 static_cast<int64_t>(getCaseDestinations().size()))) 350 return emitOpError("expects number of case values to match number of " 351 "case destinations"); 352 if (getBranchWeights() && getBranchWeights()->size() != getNumSuccessors()) 353 return emitError("expects number of branch weights to match number of " 354 "successors: ") 355 << getBranchWeights()->size() << " vs " << getNumSuccessors(); 356 return success(); 357 } 358 359 Optional<MutableOperandRange> 360 SwitchOp::getMutableSuccessorOperands(unsigned index) { 361 assert(index < getNumSuccessors() && "invalid successor index"); 362 return index == 0 ? getDefaultOperandsMutable() 363 : getCaseOperandsMutable(index - 1); 364 } 365 366 //===----------------------------------------------------------------------===// 367 // Code for LLVM::GEPOp. 368 //===----------------------------------------------------------------------===// 369 370 constexpr int GEPOp::kDynamicIndex; 371 372 /// Populates `indices` with positions of GEP indices that would correspond to 373 /// LLVMStructTypes potentially nested in the given type. The type currently 374 /// visited gets `currentIndex` and LLVM container types are visited 375 /// recursively. The recursion is bounded and takes care of recursive types by 376 /// means of the `visited` set. 377 static void recordStructIndices(Type type, unsigned currentIndex, 378 SmallVectorImpl<unsigned> &indices, 379 SmallVectorImpl<unsigned> *structSizes, 380 SmallPtrSet<Type, 4> &visited) { 381 if (visited.contains(type)) 382 return; 383 384 visited.insert(type); 385 386 llvm::TypeSwitch<Type>(type) 387 .Case<LLVMStructType>([&](LLVMStructType structType) { 388 indices.push_back(currentIndex); 389 if (structSizes) 390 structSizes->push_back(structType.getBody().size()); 391 for (Type elementType : structType.getBody()) 392 recordStructIndices(elementType, currentIndex + 1, indices, 393 structSizes, visited); 394 }) 395 .Case<VectorType, LLVMScalableVectorType, LLVMFixedVectorType, 396 LLVMArrayType>([&](auto containerType) { 397 recordStructIndices(containerType.getElementType(), currentIndex + 1, 398 indices, structSizes, visited); 399 }); 400 } 401 402 /// Populates `indices` with positions of GEP indices that correspond to 403 /// LLVMStructTypes potentially nested in the given `baseGEPType`, which must 404 /// be either an LLVMPointer type or a vector thereof. If `structSizes` is 405 /// provided, it is populated with sizes of the indexed structs for bounds 406 /// verification purposes. 407 static void 408 findKnownStructIndices(Type baseGEPType, SmallVectorImpl<unsigned> &indices, 409 SmallVectorImpl<unsigned> *structSizes = nullptr) { 410 Type type = baseGEPType; 411 if (auto vectorType = type.dyn_cast<VectorType>()) 412 type = vectorType.getElementType(); 413 if (auto scalableVectorType = type.dyn_cast<LLVMScalableVectorType>()) 414 type = scalableVectorType.getElementType(); 415 if (auto fixedVectorType = type.dyn_cast<LLVMFixedVectorType>()) 416 type = fixedVectorType.getElementType(); 417 418 Type pointeeType = type.cast<LLVMPointerType>().getElementType(); 419 SmallPtrSet<Type, 4> visited; 420 recordStructIndices(pointeeType, /*currentIndex=*/1, indices, structSizes, 421 visited); 422 } 423 424 void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType, 425 Value basePtr, ValueRange operands, 426 ArrayRef<NamedAttribute> attributes) { 427 build(builder, result, resultType, basePtr, operands, 428 SmallVector<int32_t>(operands.size(), LLVM::GEPOp::kDynamicIndex), 429 attributes); 430 } 431 432 void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType, 433 Value basePtr, ValueRange indices, 434 ArrayRef<int32_t> structIndices, 435 ArrayRef<NamedAttribute> attributes) { 436 SmallVector<Value> remainingIndices; 437 SmallVector<int32_t> updatedStructIndices(structIndices.begin(), 438 structIndices.end()); 439 SmallVector<unsigned> structRelatedPositions; 440 findKnownStructIndices(basePtr.getType(), structRelatedPositions); 441 442 SmallVector<unsigned> operandsToErase; 443 for (unsigned pos : structRelatedPositions) { 444 // GEP may not be indexing as deep as some structs are located. 445 if (pos >= structIndices.size()) 446 continue; 447 448 // If the index is already static, it's fine. 449 if (structIndices[pos] != kDynamicIndex) 450 continue; 451 452 // Find the corresponding operand. 453 unsigned operandPos = 454 std::count(structIndices.begin(), std::next(structIndices.begin(), pos), 455 kDynamicIndex); 456 457 // Extract the constant value from the operand and put it into the attribute 458 // instead. 459 APInt staticIndexValue; 460 bool matched = 461 matchPattern(indices[operandPos], m_ConstantInt(&staticIndexValue)); 462 (void)matched; 463 assert(matched && "index into a struct must be a constant"); 464 assert(staticIndexValue.sge(APInt::getSignedMinValue(/*numBits=*/32)) && 465 "struct index underflows 32-bit integer"); 466 assert(staticIndexValue.sle(APInt::getSignedMaxValue(/*numBits=*/32)) && 467 "struct index overflows 32-bit integer"); 468 auto staticIndex = static_cast<int32_t>(staticIndexValue.getSExtValue()); 469 updatedStructIndices[pos] = staticIndex; 470 operandsToErase.push_back(operandPos); 471 } 472 473 for (unsigned i = 0, e = indices.size(); i < e; ++i) { 474 if (!llvm::is_contained(operandsToErase, i)) 475 remainingIndices.push_back(indices[i]); 476 } 477 478 assert(remainingIndices.size() == static_cast<size_t>(llvm::count( 479 updatedStructIndices, kDynamicIndex)) && 480 "expected as many index operands as dynamic index attr elements"); 481 482 result.addTypes(resultType); 483 result.addAttributes(attributes); 484 result.addAttribute("structIndices", 485 builder.getI32TensorAttr(updatedStructIndices)); 486 result.addOperands(basePtr); 487 result.addOperands(remainingIndices); 488 } 489 490 static ParseResult 491 parseGEPIndices(OpAsmParser &parser, 492 SmallVectorImpl<OpAsmParser::OperandType> &indices, 493 DenseIntElementsAttr &structIndices) { 494 SmallVector<int32_t> constantIndices; 495 do { 496 int32_t constantIndex; 497 OptionalParseResult parsedInteger = 498 parser.parseOptionalInteger(constantIndex); 499 if (parsedInteger.hasValue()) { 500 if (failed(parsedInteger.getValue())) 501 return failure(); 502 constantIndices.push_back(constantIndex); 503 continue; 504 } 505 506 constantIndices.push_back(LLVM::GEPOp::kDynamicIndex); 507 if (failed(parser.parseOperand(indices.emplace_back()))) 508 return failure(); 509 } while (succeeded(parser.parseOptionalComma())); 510 511 structIndices = parser.getBuilder().getI32TensorAttr(constantIndices); 512 return success(); 513 } 514 515 static void printGEPIndices(OpAsmPrinter &printer, LLVM::GEPOp gepOp, 516 OperandRange indices, 517 DenseIntElementsAttr structIndices) { 518 unsigned operandIdx = 0; 519 llvm::interleaveComma(structIndices.getValues<int32_t>(), printer, 520 [&](int32_t cst) { 521 if (cst == LLVM::GEPOp::kDynamicIndex) 522 printer.printOperand(indices[operandIdx++]); 523 else 524 printer << cst; 525 }); 526 } 527 528 LogicalResult LLVM::GEPOp::verify() { 529 SmallVector<unsigned> indices; 530 SmallVector<unsigned> structSizes; 531 findKnownStructIndices(getBase().getType(), indices, &structSizes); 532 DenseIntElementsAttr structIndices = getStructIndices(); 533 for (unsigned i : llvm::seq<unsigned>(0, indices.size())) { 534 unsigned index = indices[i]; 535 // GEP may not be indexing as deep as some structs nested in the type. 536 if (index >= structIndices.getNumElements()) 537 continue; 538 539 int32_t staticIndex = structIndices.getValues<int32_t>()[index]; 540 if (staticIndex == LLVM::GEPOp::kDynamicIndex) 541 return emitOpError() << "expected index " << index 542 << " indexing a struct to be constant"; 543 if (staticIndex < 0 || static_cast<unsigned>(staticIndex) >= structSizes[i]) 544 return emitOpError() << "index " << index 545 << " indexing a struct is out of bounds"; 546 } 547 return success(); 548 } 549 550 //===----------------------------------------------------------------------===// 551 // Builder, printer and parser for for LLVM::LoadOp. 552 //===----------------------------------------------------------------------===// 553 554 LogicalResult verifySymbolAttribute( 555 Operation *op, StringRef attributeName, 556 llvm::function_ref<LogicalResult(Operation *, SymbolRefAttr)> 557 verifySymbolType) { 558 if (Attribute attribute = op->getAttr(attributeName)) { 559 // The attribute is already verified to be a symbol ref array attribute via 560 // a constraint in the operation definition. 561 for (SymbolRefAttr symbolRef : 562 attribute.cast<ArrayAttr>().getAsRange<SymbolRefAttr>()) { 563 StringAttr metadataName = symbolRef.getRootReference(); 564 StringAttr symbolName = symbolRef.getLeafReference(); 565 // We want @metadata::@symbol, not just @symbol 566 if (metadataName == symbolName) { 567 return op->emitOpError() << "expected '" << symbolRef 568 << "' to specify a fully qualified reference"; 569 } 570 auto metadataOp = SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>( 571 op->getParentOp(), metadataName); 572 if (!metadataOp) 573 return op->emitOpError() 574 << "expected '" << symbolRef << "' to reference a metadata op"; 575 Operation *symbolOp = 576 SymbolTable::lookupNearestSymbolFrom(metadataOp, symbolName); 577 if (!symbolOp) 578 return op->emitOpError() 579 << "expected '" << symbolRef << "' to be a valid reference"; 580 if (failed(verifySymbolType(symbolOp, symbolRef))) { 581 return failure(); 582 } 583 } 584 } 585 return success(); 586 } 587 588 // Verifies that metadata ops are wired up properly. 589 template <typename OpTy> 590 static LogicalResult verifyOpMetadata(Operation *op, StringRef attributeName) { 591 auto verifySymbolType = [op](Operation *symbolOp, 592 SymbolRefAttr symbolRef) -> LogicalResult { 593 if (!isa<OpTy>(symbolOp)) { 594 return op->emitOpError() 595 << "expected '" << symbolRef << "' to resolve to a " 596 << OpTy::getOperationName(); 597 } 598 return success(); 599 }; 600 601 return verifySymbolAttribute(op, attributeName, verifySymbolType); 602 } 603 604 static LogicalResult verifyMemoryOpMetadata(Operation *op) { 605 // access_groups 606 if (failed(verifyOpMetadata<LLVM::AccessGroupMetadataOp>( 607 op, LLVMDialect::getAccessGroupsAttrName()))) 608 return failure(); 609 610 // alias_scopes 611 if (failed(verifyOpMetadata<LLVM::AliasScopeMetadataOp>( 612 op, LLVMDialect::getAliasScopesAttrName()))) 613 return failure(); 614 615 // noalias_scopes 616 if (failed(verifyOpMetadata<LLVM::AliasScopeMetadataOp>( 617 op, LLVMDialect::getNoAliasScopesAttrName()))) 618 return failure(); 619 620 return success(); 621 } 622 623 LogicalResult LoadOp::verify() { return verifyMemoryOpMetadata(*this); } 624 625 void LoadOp::build(OpBuilder &builder, OperationState &result, Type t, 626 Value addr, unsigned alignment, bool isVolatile, 627 bool isNonTemporal) { 628 result.addOperands(addr); 629 result.addTypes(t); 630 if (isVolatile) 631 result.addAttribute(kVolatileAttrName, builder.getUnitAttr()); 632 if (isNonTemporal) 633 result.addAttribute(kNonTemporalAttrName, builder.getUnitAttr()); 634 if (alignment != 0) 635 result.addAttribute("alignment", builder.getI64IntegerAttr(alignment)); 636 } 637 638 void LoadOp::print(OpAsmPrinter &p) { 639 p << ' '; 640 if (getVolatile_()) 641 p << "volatile "; 642 p << getAddr(); 643 p.printOptionalAttrDict((*this)->getAttrs(), {kVolatileAttrName}); 644 p << " : " << getAddr().getType(); 645 } 646 647 // Extract the pointee type from the LLVM pointer type wrapped in MLIR. Return 648 // the resulting type wrapped in MLIR, or nullptr on error. 649 static Type getLoadStoreElementType(OpAsmParser &parser, Type type, 650 SMLoc trailingTypeLoc) { 651 auto llvmTy = type.dyn_cast<LLVM::LLVMPointerType>(); 652 if (!llvmTy) 653 return parser.emitError(trailingTypeLoc, "expected LLVM pointer type"), 654 nullptr; 655 return llvmTy.getElementType(); 656 } 657 658 // <operation> ::= `llvm.load` `volatile` ssa-use attribute-dict? `:` type 659 ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) { 660 OpAsmParser::OperandType addr; 661 Type type; 662 SMLoc trailingTypeLoc; 663 664 if (succeeded(parser.parseOptionalKeyword("volatile"))) 665 result.addAttribute(kVolatileAttrName, parser.getBuilder().getUnitAttr()); 666 667 if (parser.parseOperand(addr) || 668 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 669 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) || 670 parser.resolveOperand(addr, type, result.operands)) 671 return failure(); 672 673 Type elemTy = getLoadStoreElementType(parser, type, trailingTypeLoc); 674 675 result.addTypes(elemTy); 676 return success(); 677 } 678 679 //===----------------------------------------------------------------------===// 680 // Builder, printer and parser for LLVM::StoreOp. 681 //===----------------------------------------------------------------------===// 682 683 LogicalResult StoreOp::verify() { return verifyMemoryOpMetadata(*this); } 684 685 void StoreOp::build(OpBuilder &builder, OperationState &result, Value value, 686 Value addr, unsigned alignment, bool isVolatile, 687 bool isNonTemporal) { 688 result.addOperands({value, addr}); 689 result.addTypes({}); 690 if (isVolatile) 691 result.addAttribute(kVolatileAttrName, builder.getUnitAttr()); 692 if (isNonTemporal) 693 result.addAttribute(kNonTemporalAttrName, builder.getUnitAttr()); 694 if (alignment != 0) 695 result.addAttribute("alignment", builder.getI64IntegerAttr(alignment)); 696 } 697 698 void StoreOp::print(OpAsmPrinter &p) { 699 p << ' '; 700 if (getVolatile_()) 701 p << "volatile "; 702 p << getValue() << ", " << getAddr(); 703 p.printOptionalAttrDict((*this)->getAttrs(), {kVolatileAttrName}); 704 p << " : " << getAddr().getType(); 705 } 706 707 // <operation> ::= `llvm.store` `volatile` ssa-use `,` ssa-use 708 // attribute-dict? `:` type 709 ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) { 710 OpAsmParser::OperandType addr, value; 711 Type type; 712 SMLoc trailingTypeLoc; 713 714 if (succeeded(parser.parseOptionalKeyword("volatile"))) 715 result.addAttribute(kVolatileAttrName, parser.getBuilder().getUnitAttr()); 716 717 if (parser.parseOperand(value) || parser.parseComma() || 718 parser.parseOperand(addr) || 719 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 720 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type)) 721 return failure(); 722 723 Type elemTy = getLoadStoreElementType(parser, type, trailingTypeLoc); 724 if (!elemTy) 725 return failure(); 726 727 if (parser.resolveOperand(value, elemTy, result.operands) || 728 parser.resolveOperand(addr, type, result.operands)) 729 return failure(); 730 731 return success(); 732 } 733 734 ///===---------------------------------------------------------------------===// 735 /// LLVM::InvokeOp 736 ///===---------------------------------------------------------------------===// 737 738 Optional<MutableOperandRange> 739 InvokeOp::getMutableSuccessorOperands(unsigned index) { 740 assert(index < getNumSuccessors() && "invalid successor index"); 741 return index == 0 ? getNormalDestOperandsMutable() 742 : getUnwindDestOperandsMutable(); 743 } 744 745 LogicalResult InvokeOp::verify() { 746 if (getNumResults() > 1) 747 return emitOpError("must have 0 or 1 result"); 748 749 Block *unwindDest = getUnwindDest(); 750 if (unwindDest->empty()) 751 return emitError("must have at least one operation in unwind destination"); 752 753 // In unwind destination, first operation must be LandingpadOp 754 if (!isa<LandingpadOp>(unwindDest->front())) 755 return emitError("first operation in unwind destination should be a " 756 "llvm.landingpad operation"); 757 758 return success(); 759 } 760 761 void InvokeOp::print(OpAsmPrinter &p) { 762 auto callee = getCallee(); 763 bool isDirect = callee.hasValue(); 764 765 p << ' '; 766 767 // Either function name or pointer 768 if (isDirect) 769 p.printSymbolName(callee.getValue()); 770 else 771 p << getOperand(0); 772 773 p << '(' << getOperands().drop_front(isDirect ? 0 : 1) << ')'; 774 p << " to "; 775 p.printSuccessorAndUseList(getNormalDest(), getNormalDestOperands()); 776 p << " unwind "; 777 p.printSuccessorAndUseList(getUnwindDest(), getUnwindDestOperands()); 778 779 p.printOptionalAttrDict((*this)->getAttrs(), 780 {InvokeOp::getOperandSegmentSizeAttr(), "callee"}); 781 p << " : "; 782 p.printFunctionalType(llvm::drop_begin(getOperandTypes(), isDirect ? 0 : 1), 783 getResultTypes()); 784 } 785 786 /// <operation> ::= `llvm.invoke` (function-id | ssa-use) `(` ssa-use-list `)` 787 /// `to` bb-id (`[` ssa-use-and-type-list `]`)? 788 /// `unwind` bb-id (`[` ssa-use-and-type-list `]`)? 789 /// attribute-dict? `:` function-type 790 ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) { 791 SmallVector<OpAsmParser::OperandType, 8> operands; 792 FunctionType funcType; 793 SymbolRefAttr funcAttr; 794 SMLoc trailingTypeLoc; 795 Block *normalDest, *unwindDest; 796 SmallVector<Value, 4> normalOperands, unwindOperands; 797 Builder &builder = parser.getBuilder(); 798 799 // Parse an operand list that will, in practice, contain 0 or 1 operand. In 800 // case of an indirect call, there will be 1 operand before `(`. In case of a 801 // direct call, there will be no operands and the parser will stop at the 802 // function identifier without complaining. 803 if (parser.parseOperandList(operands)) 804 return failure(); 805 bool isDirect = operands.empty(); 806 807 // Optionally parse a function identifier. 808 if (isDirect && parser.parseAttribute(funcAttr, "callee", result.attributes)) 809 return failure(); 810 811 if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || 812 parser.parseKeyword("to") || 813 parser.parseSuccessorAndUseList(normalDest, normalOperands) || 814 parser.parseKeyword("unwind") || 815 parser.parseSuccessorAndUseList(unwindDest, unwindOperands) || 816 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 817 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(funcType)) 818 return failure(); 819 820 if (isDirect) { 821 // Make sure types match. 822 if (parser.resolveOperands(operands, funcType.getInputs(), 823 parser.getNameLoc(), result.operands)) 824 return failure(); 825 result.addTypes(funcType.getResults()); 826 } else { 827 // Construct the LLVM IR Dialect function type that the first operand 828 // should match. 829 if (funcType.getNumResults() > 1) 830 return parser.emitError(trailingTypeLoc, 831 "expected function with 0 or 1 result"); 832 833 Type llvmResultType; 834 if (funcType.getNumResults() == 0) { 835 llvmResultType = LLVM::LLVMVoidType::get(builder.getContext()); 836 } else { 837 llvmResultType = funcType.getResult(0); 838 if (!isCompatibleType(llvmResultType)) 839 return parser.emitError(trailingTypeLoc, 840 "expected result to have LLVM type"); 841 } 842 843 SmallVector<Type, 8> argTypes; 844 argTypes.reserve(funcType.getNumInputs()); 845 for (Type ty : funcType.getInputs()) { 846 if (isCompatibleType(ty)) 847 argTypes.push_back(ty); 848 else 849 return parser.emitError(trailingTypeLoc, 850 "expected LLVM types as inputs"); 851 } 852 853 auto llvmFuncType = LLVM::LLVMFunctionType::get(llvmResultType, argTypes); 854 auto wrappedFuncType = LLVM::LLVMPointerType::get(llvmFuncType); 855 856 auto funcArguments = llvm::makeArrayRef(operands).drop_front(); 857 858 // Make sure that the first operand (indirect callee) matches the wrapped 859 // LLVM IR function type, and that the types of the other call operands 860 // match the types of the function arguments. 861 if (parser.resolveOperand(operands[0], wrappedFuncType, result.operands) || 862 parser.resolveOperands(funcArguments, funcType.getInputs(), 863 parser.getNameLoc(), result.operands)) 864 return failure(); 865 866 result.addTypes(llvmResultType); 867 } 868 result.addSuccessors({normalDest, unwindDest}); 869 result.addOperands(normalOperands); 870 result.addOperands(unwindOperands); 871 872 result.addAttribute( 873 InvokeOp::getOperandSegmentSizeAttr(), 874 builder.getI32VectorAttr({static_cast<int32_t>(operands.size()), 875 static_cast<int32_t>(normalOperands.size()), 876 static_cast<int32_t>(unwindOperands.size())})); 877 return success(); 878 } 879 880 ///===----------------------------------------------------------------------===// 881 /// Verifying/Printing/Parsing for LLVM::LandingpadOp. 882 ///===----------------------------------------------------------------------===// 883 884 LogicalResult LandingpadOp::verify() { 885 Value value; 886 if (LLVMFuncOp func = (*this)->getParentOfType<LLVMFuncOp>()) { 887 if (!func.getPersonality().hasValue()) 888 return emitError( 889 "llvm.landingpad needs to be in a function with a personality"); 890 } 891 892 if (!getCleanup() && getOperands().empty()) 893 return emitError("landingpad instruction expects at least one clause or " 894 "cleanup attribute"); 895 896 for (unsigned idx = 0, ie = getNumOperands(); idx < ie; idx++) { 897 value = getOperand(idx); 898 bool isFilter = value.getType().isa<LLVMArrayType>(); 899 if (isFilter) { 900 // FIXME: Verify filter clauses when arrays are appropriately handled 901 } else { 902 // catch - global addresses only. 903 // Bitcast ops should have global addresses as their args. 904 if (auto bcOp = value.getDefiningOp<BitcastOp>()) { 905 if (auto addrOp = bcOp.getArg().getDefiningOp<AddressOfOp>()) 906 continue; 907 return emitError("constant clauses expected").attachNote(bcOp.getLoc()) 908 << "global addresses expected as operand to " 909 "bitcast used in clauses for landingpad"; 910 } 911 // NullOp and AddressOfOp allowed 912 if (value.getDefiningOp<NullOp>()) 913 continue; 914 if (value.getDefiningOp<AddressOfOp>()) 915 continue; 916 return emitError("clause #") 917 << idx << " is not a known constant - null, addressof, bitcast"; 918 } 919 } 920 return success(); 921 } 922 923 void LandingpadOp::print(OpAsmPrinter &p) { 924 p << (getCleanup() ? " cleanup " : " "); 925 926 // Clauses 927 for (auto value : getOperands()) { 928 // Similar to llvm - if clause is an array type then it is filter 929 // clause else catch clause 930 bool isArrayTy = value.getType().isa<LLVMArrayType>(); 931 p << '(' << (isArrayTy ? "filter " : "catch ") << value << " : " 932 << value.getType() << ") "; 933 } 934 935 p.printOptionalAttrDict((*this)->getAttrs(), {"cleanup"}); 936 937 p << ": " << getType(); 938 } 939 940 /// <operation> ::= `llvm.landingpad` `cleanup`? 941 /// ((`catch` | `filter`) operand-type ssa-use)* attribute-dict? 942 ParseResult LandingpadOp::parse(OpAsmParser &parser, OperationState &result) { 943 // Check for cleanup 944 if (succeeded(parser.parseOptionalKeyword("cleanup"))) 945 result.addAttribute("cleanup", parser.getBuilder().getUnitAttr()); 946 947 // Parse clauses with types 948 while (succeeded(parser.parseOptionalLParen()) && 949 (succeeded(parser.parseOptionalKeyword("filter")) || 950 succeeded(parser.parseOptionalKeyword("catch")))) { 951 OpAsmParser::OperandType operand; 952 Type ty; 953 if (parser.parseOperand(operand) || parser.parseColon() || 954 parser.parseType(ty) || 955 parser.resolveOperand(operand, ty, result.operands) || 956 parser.parseRParen()) 957 return failure(); 958 } 959 960 Type type; 961 if (parser.parseColon() || parser.parseType(type)) 962 return failure(); 963 964 result.addTypes(type); 965 return success(); 966 } 967 968 //===----------------------------------------------------------------------===// 969 // Verifying/Printing/parsing for LLVM::CallOp. 970 //===----------------------------------------------------------------------===// 971 972 LogicalResult CallOp::verify() { 973 if (getNumResults() > 1) 974 return emitOpError("must have 0 or 1 result"); 975 976 // Type for the callee, we'll get it differently depending if it is a direct 977 // or indirect call. 978 Type fnType; 979 980 bool isIndirect = false; 981 982 // If this is an indirect call, the callee attribute is missing. 983 FlatSymbolRefAttr calleeName = getCalleeAttr(); 984 if (!calleeName) { 985 isIndirect = true; 986 if (!getNumOperands()) 987 return emitOpError( 988 "must have either a `callee` attribute or at least an operand"); 989 auto ptrType = getOperand(0).getType().dyn_cast<LLVMPointerType>(); 990 if (!ptrType) 991 return emitOpError("indirect call expects a pointer as callee: ") 992 << ptrType; 993 fnType = ptrType.getElementType(); 994 } else { 995 Operation *callee = 996 SymbolTable::lookupNearestSymbolFrom(*this, calleeName.getAttr()); 997 if (!callee) 998 return emitOpError() 999 << "'" << calleeName.getValue() 1000 << "' does not reference a symbol in the current scope"; 1001 auto fn = dyn_cast<LLVMFuncOp>(callee); 1002 if (!fn) 1003 return emitOpError() << "'" << calleeName.getValue() 1004 << "' does not reference a valid LLVM function"; 1005 1006 fnType = fn.getType(); 1007 } 1008 1009 LLVMFunctionType funcType = fnType.dyn_cast<LLVMFunctionType>(); 1010 if (!funcType) 1011 return emitOpError("callee does not have a functional type: ") << fnType; 1012 1013 // Verify that the operand and result types match the callee. 1014 1015 if (!funcType.isVarArg() && 1016 funcType.getNumParams() != (getNumOperands() - isIndirect)) 1017 return emitOpError() << "incorrect number of operands (" 1018 << (getNumOperands() - isIndirect) 1019 << ") for callee (expecting: " 1020 << funcType.getNumParams() << ")"; 1021 1022 if (funcType.getNumParams() > (getNumOperands() - isIndirect)) 1023 return emitOpError() << "incorrect number of operands (" 1024 << (getNumOperands() - isIndirect) 1025 << ") for varargs callee (expecting at least: " 1026 << funcType.getNumParams() << ")"; 1027 1028 for (unsigned i = 0, e = funcType.getNumParams(); i != e; ++i) 1029 if (getOperand(i + isIndirect).getType() != funcType.getParamType(i)) 1030 return emitOpError() << "operand type mismatch for operand " << i << ": " 1031 << getOperand(i + isIndirect).getType() 1032 << " != " << funcType.getParamType(i); 1033 1034 if (getNumResults() == 0 && 1035 !funcType.getReturnType().isa<LLVM::LLVMVoidType>()) 1036 return emitOpError() << "expected function call to produce a value"; 1037 1038 if (getNumResults() != 0 && 1039 funcType.getReturnType().isa<LLVM::LLVMVoidType>()) 1040 return emitOpError() 1041 << "calling function with void result must not produce values"; 1042 1043 if (getNumResults() > 1) 1044 return emitOpError() 1045 << "expected LLVM function call to produce 0 or 1 result"; 1046 1047 if (getNumResults() && getResult(0).getType() != funcType.getReturnType()) 1048 return emitOpError() << "result type mismatch: " << getResult(0).getType() 1049 << " != " << funcType.getReturnType(); 1050 1051 return success(); 1052 } 1053 1054 void CallOp::print(OpAsmPrinter &p) { 1055 auto callee = getCallee(); 1056 bool isDirect = callee.hasValue(); 1057 1058 // Print the direct callee if present as a function attribute, or an indirect 1059 // callee (first operand) otherwise. 1060 p << ' '; 1061 if (isDirect) 1062 p.printSymbolName(callee.getValue()); 1063 else 1064 p << getOperand(0); 1065 1066 auto args = getOperands().drop_front(isDirect ? 0 : 1); 1067 p << '(' << args << ')'; 1068 p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()), {"callee"}); 1069 1070 // Reconstruct the function MLIR function type from operand and result types. 1071 p << " : "; 1072 p.printFunctionalType(args.getTypes(), getResultTypes()); 1073 } 1074 1075 // <operation> ::= `llvm.call` (function-id | ssa-use) `(` ssa-use-list `)` 1076 // attribute-dict? `:` function-type 1077 ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) { 1078 SmallVector<OpAsmParser::OperandType, 8> operands; 1079 Type type; 1080 SymbolRefAttr funcAttr; 1081 SMLoc trailingTypeLoc; 1082 1083 // Parse an operand list that will, in practice, contain 0 or 1 operand. In 1084 // case of an indirect call, there will be 1 operand before `(`. In case of a 1085 // direct call, there will be no operands and the parser will stop at the 1086 // function identifier without complaining. 1087 if (parser.parseOperandList(operands)) 1088 return failure(); 1089 bool isDirect = operands.empty(); 1090 1091 // Optionally parse a function identifier. 1092 if (isDirect) 1093 if (parser.parseAttribute(funcAttr, "callee", result.attributes)) 1094 return failure(); 1095 1096 if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || 1097 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 1098 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type)) 1099 return failure(); 1100 1101 auto funcType = type.dyn_cast<FunctionType>(); 1102 if (!funcType) 1103 return parser.emitError(trailingTypeLoc, "expected function type"); 1104 if (funcType.getNumResults() > 1) 1105 return parser.emitError(trailingTypeLoc, 1106 "expected function with 0 or 1 result"); 1107 if (isDirect) { 1108 // Make sure types match. 1109 if (parser.resolveOperands(operands, funcType.getInputs(), 1110 parser.getNameLoc(), result.operands)) 1111 return failure(); 1112 if (funcType.getNumResults() != 0 && 1113 !funcType.getResult(0).isa<LLVM::LLVMVoidType>()) 1114 result.addTypes(funcType.getResults()); 1115 } else { 1116 Builder &builder = parser.getBuilder(); 1117 Type llvmResultType; 1118 if (funcType.getNumResults() == 0) { 1119 llvmResultType = LLVM::LLVMVoidType::get(builder.getContext()); 1120 } else { 1121 llvmResultType = funcType.getResult(0); 1122 if (!isCompatibleType(llvmResultType)) 1123 return parser.emitError(trailingTypeLoc, 1124 "expected result to have LLVM type"); 1125 } 1126 1127 SmallVector<Type, 8> argTypes; 1128 argTypes.reserve(funcType.getNumInputs()); 1129 for (int i = 0, e = funcType.getNumInputs(); i < e; ++i) { 1130 auto argType = funcType.getInput(i); 1131 if (!isCompatibleType(argType)) 1132 return parser.emitError(trailingTypeLoc, 1133 "expected LLVM types as inputs"); 1134 argTypes.push_back(argType); 1135 } 1136 auto llvmFuncType = LLVM::LLVMFunctionType::get(llvmResultType, argTypes); 1137 auto wrappedFuncType = LLVM::LLVMPointerType::get(llvmFuncType); 1138 1139 auto funcArguments = 1140 ArrayRef<OpAsmParser::OperandType>(operands).drop_front(); 1141 1142 // Make sure that the first operand (indirect callee) matches the wrapped 1143 // LLVM IR function type, and that the types of the other call operands 1144 // match the types of the function arguments. 1145 if (parser.resolveOperand(operands[0], wrappedFuncType, result.operands) || 1146 parser.resolveOperands(funcArguments, funcType.getInputs(), 1147 parser.getNameLoc(), result.operands)) 1148 return failure(); 1149 1150 if (!llvmResultType.isa<LLVM::LLVMVoidType>()) 1151 result.addTypes(llvmResultType); 1152 } 1153 1154 return success(); 1155 } 1156 1157 //===----------------------------------------------------------------------===// 1158 // Printing/parsing for LLVM::ExtractElementOp. 1159 //===----------------------------------------------------------------------===// 1160 // Expects vector to be of wrapped LLVM vector type and position to be of 1161 // wrapped LLVM i32 type. 1162 void LLVM::ExtractElementOp::build(OpBuilder &b, OperationState &result, 1163 Value vector, Value position, 1164 ArrayRef<NamedAttribute> attrs) { 1165 auto vectorType = vector.getType(); 1166 auto llvmType = LLVM::getVectorElementType(vectorType); 1167 build(b, result, llvmType, vector, position); 1168 result.addAttributes(attrs); 1169 } 1170 1171 void ExtractElementOp::print(OpAsmPrinter &p) { 1172 p << ' ' << getVector() << "[" << getPosition() << " : " 1173 << getPosition().getType() << "]"; 1174 p.printOptionalAttrDict((*this)->getAttrs()); 1175 p << " : " << getVector().getType(); 1176 } 1177 1178 // <operation> ::= `llvm.extractelement` ssa-use `, ` ssa-use 1179 // attribute-dict? `:` type 1180 ParseResult ExtractElementOp::parse(OpAsmParser &parser, 1181 OperationState &result) { 1182 SMLoc loc; 1183 OpAsmParser::OperandType vector, position; 1184 Type type, positionType; 1185 if (parser.getCurrentLocation(&loc) || parser.parseOperand(vector) || 1186 parser.parseLSquare() || parser.parseOperand(position) || 1187 parser.parseColonType(positionType) || parser.parseRSquare() || 1188 parser.parseOptionalAttrDict(result.attributes) || 1189 parser.parseColonType(type) || 1190 parser.resolveOperand(vector, type, result.operands) || 1191 parser.resolveOperand(position, positionType, result.operands)) 1192 return failure(); 1193 if (!LLVM::isCompatibleVectorType(type)) 1194 return parser.emitError( 1195 loc, "expected LLVM dialect-compatible vector type for operand #1"); 1196 result.addTypes(LLVM::getVectorElementType(type)); 1197 return success(); 1198 } 1199 1200 LogicalResult ExtractElementOp::verify() { 1201 Type vectorType = getVector().getType(); 1202 if (!LLVM::isCompatibleVectorType(vectorType)) 1203 return emitOpError("expected LLVM dialect-compatible vector type for " 1204 "operand #1, got") 1205 << vectorType; 1206 Type valueType = LLVM::getVectorElementType(vectorType); 1207 if (valueType != getRes().getType()) 1208 return emitOpError() << "Type mismatch: extracting from " << vectorType 1209 << " should produce " << valueType 1210 << " but this op returns " << getRes().getType(); 1211 return success(); 1212 } 1213 1214 //===----------------------------------------------------------------------===// 1215 // Printing/parsing for LLVM::ExtractValueOp. 1216 //===----------------------------------------------------------------------===// 1217 1218 void ExtractValueOp::print(OpAsmPrinter &p) { 1219 p << ' ' << getContainer() << getPosition(); 1220 p.printOptionalAttrDict((*this)->getAttrs(), {"position"}); 1221 p << " : " << getContainer().getType(); 1222 } 1223 1224 // Extract the type at `position` in the wrapped LLVM IR aggregate type 1225 // `containerType`. Position is an integer array attribute where each value 1226 // is a zero-based position of the element in the aggregate type. Return the 1227 // resulting type wrapped in MLIR, or nullptr on error. 1228 static Type getInsertExtractValueElementType(OpAsmParser &parser, 1229 Type containerType, 1230 ArrayAttr positionAttr, 1231 SMLoc attributeLoc, 1232 SMLoc typeLoc) { 1233 Type llvmType = containerType; 1234 if (!isCompatibleType(containerType)) 1235 return parser.emitError(typeLoc, "expected LLVM IR Dialect type"), nullptr; 1236 1237 // Infer the element type from the structure type: iteratively step inside the 1238 // type by taking the element type, indexed by the position attribute for 1239 // structures. Check the position index before accessing, it is supposed to 1240 // be in bounds. 1241 for (Attribute subAttr : positionAttr) { 1242 auto positionElementAttr = subAttr.dyn_cast<IntegerAttr>(); 1243 if (!positionElementAttr) 1244 return parser.emitError(attributeLoc, 1245 "expected an array of integer literals"), 1246 nullptr; 1247 int position = positionElementAttr.getInt(); 1248 if (auto arrayType = llvmType.dyn_cast<LLVMArrayType>()) { 1249 if (position < 0 || 1250 static_cast<unsigned>(position) >= arrayType.getNumElements()) 1251 return parser.emitError(attributeLoc, "position out of bounds"), 1252 nullptr; 1253 llvmType = arrayType.getElementType(); 1254 } else if (auto structType = llvmType.dyn_cast<LLVMStructType>()) { 1255 if (position < 0 || 1256 static_cast<unsigned>(position) >= structType.getBody().size()) 1257 return parser.emitError(attributeLoc, "position out of bounds"), 1258 nullptr; 1259 llvmType = structType.getBody()[position]; 1260 } else { 1261 return parser.emitError(typeLoc, "expected LLVM IR structure/array type"), 1262 nullptr; 1263 } 1264 } 1265 return llvmType; 1266 } 1267 1268 // Extract the type at `position` in the wrapped LLVM IR aggregate type 1269 // `containerType`. Returns null on failure. 1270 static Type getInsertExtractValueElementType(Type containerType, 1271 ArrayAttr positionAttr, 1272 Operation *op) { 1273 Type llvmType = containerType; 1274 if (!isCompatibleType(containerType)) { 1275 op->emitError("expected LLVM IR Dialect type, got ") << containerType; 1276 return {}; 1277 } 1278 1279 // Infer the element type from the structure type: iteratively step inside the 1280 // type by taking the element type, indexed by the position attribute for 1281 // structures. Check the position index before accessing, it is supposed to 1282 // be in bounds. 1283 for (Attribute subAttr : positionAttr) { 1284 auto positionElementAttr = subAttr.dyn_cast<IntegerAttr>(); 1285 if (!positionElementAttr) { 1286 op->emitOpError("expected an array of integer literals, got: ") 1287 << subAttr; 1288 return {}; 1289 } 1290 int position = positionElementAttr.getInt(); 1291 if (auto arrayType = llvmType.dyn_cast<LLVMArrayType>()) { 1292 if (position < 0 || 1293 static_cast<unsigned>(position) >= arrayType.getNumElements()) { 1294 op->emitOpError("position out of bounds: ") << position; 1295 return {}; 1296 } 1297 llvmType = arrayType.getElementType(); 1298 } else if (auto structType = llvmType.dyn_cast<LLVMStructType>()) { 1299 if (position < 0 || 1300 static_cast<unsigned>(position) >= structType.getBody().size()) { 1301 op->emitOpError("position out of bounds") << position; 1302 return {}; 1303 } 1304 llvmType = structType.getBody()[position]; 1305 } else { 1306 op->emitOpError("expected LLVM IR structure/array type, got: ") 1307 << llvmType; 1308 return {}; 1309 } 1310 } 1311 return llvmType; 1312 } 1313 1314 // <operation> ::= `llvm.extractvalue` ssa-use 1315 // `[` integer-literal (`,` integer-literal)* `]` 1316 // attribute-dict? `:` type 1317 ParseResult ExtractValueOp::parse(OpAsmParser &parser, OperationState &result) { 1318 OpAsmParser::OperandType container; 1319 Type containerType; 1320 ArrayAttr positionAttr; 1321 SMLoc attributeLoc, trailingTypeLoc; 1322 1323 if (parser.parseOperand(container) || 1324 parser.getCurrentLocation(&attributeLoc) || 1325 parser.parseAttribute(positionAttr, "position", result.attributes) || 1326 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 1327 parser.getCurrentLocation(&trailingTypeLoc) || 1328 parser.parseType(containerType) || 1329 parser.resolveOperand(container, containerType, result.operands)) 1330 return failure(); 1331 1332 auto elementType = getInsertExtractValueElementType( 1333 parser, containerType, positionAttr, attributeLoc, trailingTypeLoc); 1334 if (!elementType) 1335 return failure(); 1336 1337 result.addTypes(elementType); 1338 return success(); 1339 } 1340 1341 OpFoldResult LLVM::ExtractValueOp::fold(ArrayRef<Attribute> operands) { 1342 auto insertValueOp = getContainer().getDefiningOp<InsertValueOp>(); 1343 OpFoldResult result = {}; 1344 while (insertValueOp) { 1345 if (getPosition() == insertValueOp.getPosition()) 1346 return insertValueOp.getValue(); 1347 unsigned min = 1348 std::min(getPosition().size(), insertValueOp.getPosition().size()); 1349 // If one is fully prefix of the other, stop propagating back as it will 1350 // miss dependencies. For instance, %3 should not fold to %f0 in the 1351 // following example: 1352 // ``` 1353 // %1 = llvm.insertvalue %f0, %0[0, 0] : 1354 // !llvm.array<4 x !llvm.array<4xf32>> 1355 // %2 = llvm.insertvalue %arr, %1[0] : 1356 // !llvm.array<4 x !llvm.array<4xf32>> 1357 // %3 = llvm.extractvalue %2[0, 0] : !llvm.array<4 x !llvm.array<4xf32>> 1358 // ``` 1359 if (getPosition().getValue().take_front(min) == 1360 insertValueOp.getPosition().getValue().take_front(min)) 1361 return result; 1362 1363 // If neither a prefix, nor the exact position, we can extract out of the 1364 // value being inserted into. Moreover, we can try again if that operand 1365 // is itself an insertvalue expression. 1366 getContainerMutable().assign(insertValueOp.getContainer()); 1367 result = getResult(); 1368 insertValueOp = insertValueOp.getContainer().getDefiningOp<InsertValueOp>(); 1369 } 1370 return result; 1371 } 1372 1373 LogicalResult ExtractValueOp::verify() { 1374 Type valueType = getInsertExtractValueElementType(getContainer().getType(), 1375 getPositionAttr(), *this); 1376 if (!valueType) 1377 return failure(); 1378 1379 if (getRes().getType() != valueType) 1380 return emitOpError() << "Type mismatch: extracting from " 1381 << getContainer().getType() << " should produce " 1382 << valueType << " but this op returns " 1383 << getRes().getType(); 1384 return success(); 1385 } 1386 1387 //===----------------------------------------------------------------------===// 1388 // Printing/parsing for LLVM::InsertElementOp. 1389 //===----------------------------------------------------------------------===// 1390 1391 void InsertElementOp::print(OpAsmPrinter &p) { 1392 p << ' ' << getValue() << ", " << getVector() << "[" << getPosition() << " : " 1393 << getPosition().getType() << "]"; 1394 p.printOptionalAttrDict((*this)->getAttrs()); 1395 p << " : " << getVector().getType(); 1396 } 1397 1398 // <operation> ::= `llvm.insertelement` ssa-use `,` ssa-use `,` ssa-use 1399 // attribute-dict? `:` type 1400 ParseResult InsertElementOp::parse(OpAsmParser &parser, 1401 OperationState &result) { 1402 SMLoc loc; 1403 OpAsmParser::OperandType vector, value, position; 1404 Type vectorType, positionType; 1405 if (parser.getCurrentLocation(&loc) || parser.parseOperand(value) || 1406 parser.parseComma() || parser.parseOperand(vector) || 1407 parser.parseLSquare() || parser.parseOperand(position) || 1408 parser.parseColonType(positionType) || parser.parseRSquare() || 1409 parser.parseOptionalAttrDict(result.attributes) || 1410 parser.parseColonType(vectorType)) 1411 return failure(); 1412 1413 if (!LLVM::isCompatibleVectorType(vectorType)) 1414 return parser.emitError( 1415 loc, "expected LLVM dialect-compatible vector type for operand #1"); 1416 Type valueType = LLVM::getVectorElementType(vectorType); 1417 if (!valueType) 1418 return failure(); 1419 1420 if (parser.resolveOperand(vector, vectorType, result.operands) || 1421 parser.resolveOperand(value, valueType, result.operands) || 1422 parser.resolveOperand(position, positionType, result.operands)) 1423 return failure(); 1424 1425 result.addTypes(vectorType); 1426 return success(); 1427 } 1428 1429 LogicalResult InsertElementOp::verify() { 1430 Type valueType = LLVM::getVectorElementType(getVector().getType()); 1431 if (valueType != getValue().getType()) 1432 return emitOpError() << "Type mismatch: cannot insert " 1433 << getValue().getType() << " into " 1434 << getVector().getType(); 1435 return success(); 1436 } 1437 1438 //===----------------------------------------------------------------------===// 1439 // Printing/parsing for LLVM::InsertValueOp. 1440 //===----------------------------------------------------------------------===// 1441 1442 void InsertValueOp::print(OpAsmPrinter &p) { 1443 p << ' ' << getValue() << ", " << getContainer() << getPosition(); 1444 p.printOptionalAttrDict((*this)->getAttrs(), {"position"}); 1445 p << " : " << getContainer().getType(); 1446 } 1447 1448 // <operation> ::= `llvm.insertvaluevalue` ssa-use `,` ssa-use 1449 // `[` integer-literal (`,` integer-literal)* `]` 1450 // attribute-dict? `:` type 1451 ParseResult InsertValueOp::parse(OpAsmParser &parser, OperationState &result) { 1452 OpAsmParser::OperandType container, value; 1453 Type containerType; 1454 ArrayAttr positionAttr; 1455 SMLoc attributeLoc, trailingTypeLoc; 1456 1457 if (parser.parseOperand(value) || parser.parseComma() || 1458 parser.parseOperand(container) || 1459 parser.getCurrentLocation(&attributeLoc) || 1460 parser.parseAttribute(positionAttr, "position", result.attributes) || 1461 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 1462 parser.getCurrentLocation(&trailingTypeLoc) || 1463 parser.parseType(containerType)) 1464 return failure(); 1465 1466 auto valueType = getInsertExtractValueElementType( 1467 parser, containerType, positionAttr, attributeLoc, trailingTypeLoc); 1468 if (!valueType) 1469 return failure(); 1470 1471 if (parser.resolveOperand(container, containerType, result.operands) || 1472 parser.resolveOperand(value, valueType, result.operands)) 1473 return failure(); 1474 1475 result.addTypes(containerType); 1476 return success(); 1477 } 1478 1479 LogicalResult InsertValueOp::verify() { 1480 Type valueType = getInsertExtractValueElementType(getContainer().getType(), 1481 getPositionAttr(), *this); 1482 if (!valueType) 1483 return failure(); 1484 1485 if (getValue().getType() != valueType) 1486 return emitOpError() << "Type mismatch: cannot insert " 1487 << getValue().getType() << " into " 1488 << getContainer().getType(); 1489 1490 return success(); 1491 } 1492 1493 //===----------------------------------------------------------------------===// 1494 // Printing, parsing and verification for LLVM::ReturnOp. 1495 //===----------------------------------------------------------------------===// 1496 1497 LogicalResult ReturnOp::verify() { 1498 if (getNumOperands() > 1) 1499 return emitOpError("expected at most 1 operand"); 1500 1501 if (auto parent = (*this)->getParentOfType<LLVMFuncOp>()) { 1502 Type expectedType = parent.getType().getReturnType(); 1503 if (expectedType.isa<LLVMVoidType>()) { 1504 if (getNumOperands() == 0) 1505 return success(); 1506 InFlightDiagnostic diag = emitOpError("expected no operands"); 1507 diag.attachNote(parent->getLoc()) << "when returning from function"; 1508 return diag; 1509 } 1510 if (getNumOperands() == 0) { 1511 if (expectedType.isa<LLVMVoidType>()) 1512 return success(); 1513 InFlightDiagnostic diag = emitOpError("expected 1 operand"); 1514 diag.attachNote(parent->getLoc()) << "when returning from function"; 1515 return diag; 1516 } 1517 if (expectedType != getOperand(0).getType()) { 1518 InFlightDiagnostic diag = emitOpError("mismatching result types"); 1519 diag.attachNote(parent->getLoc()) << "when returning from function"; 1520 return diag; 1521 } 1522 } 1523 return success(); 1524 } 1525 1526 //===----------------------------------------------------------------------===// 1527 // ResumeOp 1528 //===----------------------------------------------------------------------===// 1529 1530 LogicalResult ResumeOp::verify() { 1531 if (!getValue().getDefiningOp<LandingpadOp>()) 1532 return emitOpError("expects landingpad value as operand"); 1533 // No check for personality of function - landingpad op verifies it. 1534 return success(); 1535 } 1536 1537 //===----------------------------------------------------------------------===// 1538 // Verifier for LLVM::AddressOfOp. 1539 //===----------------------------------------------------------------------===// 1540 1541 template <typename OpTy> 1542 static OpTy lookupSymbolInModule(Operation *parent, StringRef name) { 1543 Operation *module = parent; 1544 while (module && !satisfiesLLVMModule(module)) 1545 module = module->getParentOp(); 1546 assert(module && "unexpected operation outside of a module"); 1547 return dyn_cast_or_null<OpTy>( 1548 mlir::SymbolTable::lookupSymbolIn(module, name)); 1549 } 1550 1551 GlobalOp AddressOfOp::getGlobal() { 1552 return lookupSymbolInModule<LLVM::GlobalOp>((*this)->getParentOp(), 1553 getGlobalName()); 1554 } 1555 1556 LLVMFuncOp AddressOfOp::getFunction() { 1557 return lookupSymbolInModule<LLVM::LLVMFuncOp>((*this)->getParentOp(), 1558 getGlobalName()); 1559 } 1560 1561 LogicalResult AddressOfOp::verify() { 1562 auto global = getGlobal(); 1563 auto function = getFunction(); 1564 if (!global && !function) 1565 return emitOpError( 1566 "must reference a global defined by 'llvm.mlir.global' or 'llvm.func'"); 1567 1568 if (global && 1569 LLVM::LLVMPointerType::get(global.getType(), global.getAddrSpace()) != 1570 getResult().getType()) 1571 return emitOpError( 1572 "the type must be a pointer to the type of the referenced global"); 1573 1574 if (function && 1575 LLVM::LLVMPointerType::get(function.getType()) != getResult().getType()) 1576 return emitOpError( 1577 "the type must be a pointer to the type of the referenced function"); 1578 1579 return success(); 1580 } 1581 1582 //===----------------------------------------------------------------------===// 1583 // Builder, printer and verifier for LLVM::GlobalOp. 1584 //===----------------------------------------------------------------------===// 1585 1586 /// Returns the name used for the linkage attribute. This *must* correspond to 1587 /// the name of the attribute in ODS. 1588 static StringRef getLinkageAttrName() { return "linkage"; } 1589 1590 /// Returns the name used for the unnamed_addr attribute. This *must* correspond 1591 /// to the name of the attribute in ODS. 1592 static StringRef getUnnamedAddrAttrName() { return "unnamed_addr"; } 1593 1594 void GlobalOp::build(OpBuilder &builder, OperationState &result, Type type, 1595 bool isConstant, Linkage linkage, StringRef name, 1596 Attribute value, uint64_t alignment, unsigned addrSpace, 1597 bool dsoLocal, ArrayRef<NamedAttribute> attrs) { 1598 result.addAttribute(SymbolTable::getSymbolAttrName(), 1599 builder.getStringAttr(name)); 1600 result.addAttribute("global_type", TypeAttr::get(type)); 1601 if (isConstant) 1602 result.addAttribute("constant", builder.getUnitAttr()); 1603 if (value) 1604 result.addAttribute("value", value); 1605 if (dsoLocal) 1606 result.addAttribute("dso_local", builder.getUnitAttr()); 1607 1608 // Only add an alignment attribute if the "alignment" input 1609 // is different from 0. The value must also be a power of two, but 1610 // this is tested in GlobalOp::verify, not here. 1611 if (alignment != 0) 1612 result.addAttribute("alignment", builder.getI64IntegerAttr(alignment)); 1613 1614 result.addAttribute(::getLinkageAttrName(), 1615 LinkageAttr::get(builder.getContext(), linkage)); 1616 if (addrSpace != 0) 1617 result.addAttribute("addr_space", builder.getI32IntegerAttr(addrSpace)); 1618 result.attributes.append(attrs.begin(), attrs.end()); 1619 result.addRegion(); 1620 } 1621 1622 void GlobalOp::print(OpAsmPrinter &p) { 1623 p << ' ' << stringifyLinkage(getLinkage()) << ' '; 1624 if (auto unnamedAddr = getUnnamedAddr()) { 1625 StringRef str = stringifyUnnamedAddr(*unnamedAddr); 1626 if (!str.empty()) 1627 p << str << ' '; 1628 } 1629 if (getConstant()) 1630 p << "constant "; 1631 p.printSymbolName(getSymName()); 1632 p << '('; 1633 if (auto value = getValueOrNull()) 1634 p.printAttribute(value); 1635 p << ')'; 1636 // Note that the alignment attribute is printed using the 1637 // default syntax here, even though it is an inherent attribute 1638 // (as defined in https://mlir.llvm.org/docs/LangRef/#attributes) 1639 p.printOptionalAttrDict((*this)->getAttrs(), 1640 {SymbolTable::getSymbolAttrName(), "global_type", 1641 "constant", "value", getLinkageAttrName(), 1642 getUnnamedAddrAttrName()}); 1643 1644 // Print the trailing type unless it's a string global. 1645 if (getValueOrNull().dyn_cast_or_null<StringAttr>()) 1646 return; 1647 p << " : " << getType(); 1648 1649 Region &initializer = getInitializerRegion(); 1650 if (!initializer.empty()) { 1651 p << ' '; 1652 p.printRegion(initializer, /*printEntryBlockArgs=*/false); 1653 } 1654 } 1655 1656 // Parses one of the keywords provided in the list `keywords` and returns the 1657 // position of the parsed keyword in the list. If none of the keywords from the 1658 // list is parsed, returns -1. 1659 static int parseOptionalKeywordAlternative(OpAsmParser &parser, 1660 ArrayRef<StringRef> keywords) { 1661 for (const auto &en : llvm::enumerate(keywords)) { 1662 if (succeeded(parser.parseOptionalKeyword(en.value()))) 1663 return en.index(); 1664 } 1665 return -1; 1666 } 1667 1668 namespace { 1669 template <typename Ty> 1670 struct EnumTraits {}; 1671 1672 #define REGISTER_ENUM_TYPE(Ty) \ 1673 template <> \ 1674 struct EnumTraits<Ty> { \ 1675 static StringRef stringify(Ty value) { return stringify##Ty(value); } \ 1676 static unsigned getMaxEnumVal() { return getMaxEnumValFor##Ty(); } \ 1677 } 1678 1679 REGISTER_ENUM_TYPE(Linkage); 1680 REGISTER_ENUM_TYPE(UnnamedAddr); 1681 } // namespace 1682 1683 /// Parse an enum from the keyword, or default to the provided default value. 1684 /// The return type is the enum type by default, unless overriden with the 1685 /// second template argument. 1686 template <typename EnumTy, typename RetTy = EnumTy> 1687 static RetTy parseOptionalLLVMKeyword(OpAsmParser &parser, 1688 OperationState &result, 1689 EnumTy defaultValue) { 1690 SmallVector<StringRef, 10> names; 1691 for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i) 1692 names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i))); 1693 1694 int index = parseOptionalKeywordAlternative(parser, names); 1695 if (index == -1) 1696 return static_cast<RetTy>(defaultValue); 1697 return static_cast<RetTy>(index); 1698 } 1699 1700 // operation ::= `llvm.mlir.global` linkage? `constant`? `@` identifier 1701 // `(` attribute? `)` align? attribute-list? (`:` type)? region? 1702 // align ::= `align` `=` UINT64 1703 // 1704 // The type can be omitted for string attributes, in which case it will be 1705 // inferred from the value of the string as [strlen(value) x i8]. 1706 ParseResult GlobalOp::parse(OpAsmParser &parser, OperationState &result) { 1707 MLIRContext *ctx = parser.getContext(); 1708 // Parse optional linkage, default to External. 1709 result.addAttribute(::getLinkageAttrName(), 1710 LLVM::LinkageAttr::get( 1711 ctx, parseOptionalLLVMKeyword<Linkage>( 1712 parser, result, LLVM::Linkage::External))); 1713 // Parse optional UnnamedAddr, default to None. 1714 result.addAttribute(::getUnnamedAddrAttrName(), 1715 parser.getBuilder().getI64IntegerAttr( 1716 parseOptionalLLVMKeyword<UnnamedAddr, int64_t>( 1717 parser, result, LLVM::UnnamedAddr::None))); 1718 1719 if (succeeded(parser.parseOptionalKeyword("constant"))) 1720 result.addAttribute("constant", parser.getBuilder().getUnitAttr()); 1721 1722 StringAttr name; 1723 if (parser.parseSymbolName(name, SymbolTable::getSymbolAttrName(), 1724 result.attributes) || 1725 parser.parseLParen()) 1726 return failure(); 1727 1728 Attribute value; 1729 if (parser.parseOptionalRParen()) { 1730 if (parser.parseAttribute(value, "value", result.attributes) || 1731 parser.parseRParen()) 1732 return failure(); 1733 } 1734 1735 SmallVector<Type, 1> types; 1736 if (parser.parseOptionalAttrDict(result.attributes) || 1737 parser.parseOptionalColonTypeList(types)) 1738 return failure(); 1739 1740 if (types.size() > 1) 1741 return parser.emitError(parser.getNameLoc(), "expected zero or one type"); 1742 1743 Region &initRegion = *result.addRegion(); 1744 if (types.empty()) { 1745 if (auto strAttr = value.dyn_cast_or_null<StringAttr>()) { 1746 MLIRContext *context = parser.getContext(); 1747 auto arrayType = LLVM::LLVMArrayType::get(IntegerType::get(context, 8), 1748 strAttr.getValue().size()); 1749 types.push_back(arrayType); 1750 } else { 1751 return parser.emitError(parser.getNameLoc(), 1752 "type can only be omitted for string globals"); 1753 } 1754 } else { 1755 OptionalParseResult parseResult = 1756 parser.parseOptionalRegion(initRegion, /*arguments=*/{}, 1757 /*argTypes=*/{}); 1758 if (parseResult.hasValue() && failed(*parseResult)) 1759 return failure(); 1760 } 1761 1762 result.addAttribute("global_type", TypeAttr::get(types[0])); 1763 return success(); 1764 } 1765 1766 static bool isZeroAttribute(Attribute value) { 1767 if (auto intValue = value.dyn_cast<IntegerAttr>()) 1768 return intValue.getValue().isNullValue(); 1769 if (auto fpValue = value.dyn_cast<FloatAttr>()) 1770 return fpValue.getValue().isZero(); 1771 if (auto splatValue = value.dyn_cast<SplatElementsAttr>()) 1772 return isZeroAttribute(splatValue.getSplatValue<Attribute>()); 1773 if (auto elementsValue = value.dyn_cast<ElementsAttr>()) 1774 return llvm::all_of(elementsValue.getValues<Attribute>(), isZeroAttribute); 1775 if (auto arrayValue = value.dyn_cast<ArrayAttr>()) 1776 return llvm::all_of(arrayValue.getValue(), isZeroAttribute); 1777 return false; 1778 } 1779 1780 LogicalResult GlobalOp::verify() { 1781 if (!LLVMPointerType::isValidElementType(getType())) 1782 return emitOpError( 1783 "expects type to be a valid element type for an LLVM pointer"); 1784 if ((*this)->getParentOp() && !satisfiesLLVMModule((*this)->getParentOp())) 1785 return emitOpError("must appear at the module level"); 1786 1787 if (auto strAttr = getValueOrNull().dyn_cast_or_null<StringAttr>()) { 1788 auto type = getType().dyn_cast<LLVMArrayType>(); 1789 IntegerType elementType = 1790 type ? type.getElementType().dyn_cast<IntegerType>() : nullptr; 1791 if (!elementType || elementType.getWidth() != 8 || 1792 type.getNumElements() != strAttr.getValue().size()) 1793 return emitOpError( 1794 "requires an i8 array type of the length equal to that of the string " 1795 "attribute"); 1796 } 1797 1798 if (getLinkage() == Linkage::Common) { 1799 if (Attribute value = getValueOrNull()) { 1800 if (!isZeroAttribute(value)) { 1801 return emitOpError() 1802 << "expected zero value for '" 1803 << stringifyLinkage(Linkage::Common) << "' linkage"; 1804 } 1805 } 1806 } 1807 1808 if (getLinkage() == Linkage::Appending) { 1809 if (!getType().isa<LLVMArrayType>()) { 1810 return emitOpError() << "expected array type for '" 1811 << stringifyLinkage(Linkage::Appending) 1812 << "' linkage"; 1813 } 1814 } 1815 1816 Optional<uint64_t> alignAttr = getAlignment(); 1817 if (alignAttr.hasValue()) { 1818 uint64_t value = alignAttr.getValue(); 1819 if (!llvm::isPowerOf2_64(value)) 1820 return emitError() << "alignment attribute is not a power of 2"; 1821 } 1822 1823 return success(); 1824 } 1825 1826 LogicalResult GlobalOp::verifyRegions() { 1827 if (Block *b = getInitializerBlock()) { 1828 ReturnOp ret = cast<ReturnOp>(b->getTerminator()); 1829 if (ret.operand_type_begin() == ret.operand_type_end()) 1830 return emitOpError("initializer region cannot return void"); 1831 if (*ret.operand_type_begin() != getType()) 1832 return emitOpError("initializer region type ") 1833 << *ret.operand_type_begin() << " does not match global type " 1834 << getType(); 1835 1836 for (Operation &op : *b) { 1837 auto iface = dyn_cast<MemoryEffectOpInterface>(op); 1838 if (!iface || !iface.hasNoEffect()) 1839 return op.emitError() 1840 << "ops with side effects not allowed in global initializers"; 1841 } 1842 1843 if (getValueOrNull()) 1844 return emitOpError("cannot have both initializer value and region"); 1845 } 1846 1847 return success(); 1848 } 1849 1850 //===----------------------------------------------------------------------===// 1851 // LLVM::GlobalCtorsOp 1852 //===----------------------------------------------------------------------===// 1853 1854 LogicalResult 1855 GlobalCtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 1856 for (Attribute ctor : getCtors()) { 1857 if (failed(verifySymbolAttrUse(ctor.cast<FlatSymbolRefAttr>(), *this, 1858 symbolTable))) 1859 return failure(); 1860 } 1861 return success(); 1862 } 1863 1864 LogicalResult GlobalCtorsOp::verify() { 1865 if (getCtors().size() != getPriorities().size()) 1866 return emitError( 1867 "mismatch between the number of ctors and the number of priorities"); 1868 return success(); 1869 } 1870 1871 //===----------------------------------------------------------------------===// 1872 // LLVM::GlobalDtorsOp 1873 //===----------------------------------------------------------------------===// 1874 1875 LogicalResult 1876 GlobalDtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 1877 for (Attribute dtor : getDtors()) { 1878 if (failed(verifySymbolAttrUse(dtor.cast<FlatSymbolRefAttr>(), *this, 1879 symbolTable))) 1880 return failure(); 1881 } 1882 return success(); 1883 } 1884 1885 LogicalResult GlobalDtorsOp::verify() { 1886 if (getDtors().size() != getPriorities().size()) 1887 return emitError( 1888 "mismatch between the number of dtors and the number of priorities"); 1889 return success(); 1890 } 1891 1892 //===----------------------------------------------------------------------===// 1893 // Printing/parsing for LLVM::ShuffleVectorOp. 1894 //===----------------------------------------------------------------------===// 1895 // Expects vector to be of wrapped LLVM vector type and position to be of 1896 // wrapped LLVM i32 type. 1897 void LLVM::ShuffleVectorOp::build(OpBuilder &b, OperationState &result, 1898 Value v1, Value v2, ArrayAttr mask, 1899 ArrayRef<NamedAttribute> attrs) { 1900 auto containerType = v1.getType(); 1901 auto vType = LLVM::getVectorType( 1902 LLVM::getVectorElementType(containerType), mask.size(), 1903 containerType.cast<VectorType>().isScalable()); 1904 build(b, result, vType, v1, v2, mask); 1905 result.addAttributes(attrs); 1906 } 1907 1908 void ShuffleVectorOp::print(OpAsmPrinter &p) { 1909 p << ' ' << getV1() << ", " << getV2() << " " << getMask(); 1910 p.printOptionalAttrDict((*this)->getAttrs(), {"mask"}); 1911 p << " : " << getV1().getType() << ", " << getV2().getType(); 1912 } 1913 1914 // <operation> ::= `llvm.shufflevector` ssa-use `, ` ssa-use 1915 // `[` integer-literal (`,` integer-literal)* `]` 1916 // attribute-dict? `:` type 1917 ParseResult ShuffleVectorOp::parse(OpAsmParser &parser, 1918 OperationState &result) { 1919 SMLoc loc; 1920 OpAsmParser::OperandType v1, v2; 1921 ArrayAttr maskAttr; 1922 Type typeV1, typeV2; 1923 if (parser.getCurrentLocation(&loc) || parser.parseOperand(v1) || 1924 parser.parseComma() || parser.parseOperand(v2) || 1925 parser.parseAttribute(maskAttr, "mask", result.attributes) || 1926 parser.parseOptionalAttrDict(result.attributes) || 1927 parser.parseColonType(typeV1) || parser.parseComma() || 1928 parser.parseType(typeV2) || 1929 parser.resolveOperand(v1, typeV1, result.operands) || 1930 parser.resolveOperand(v2, typeV2, result.operands)) 1931 return failure(); 1932 if (!LLVM::isCompatibleVectorType(typeV1)) 1933 return parser.emitError( 1934 loc, "expected LLVM IR dialect vector type for operand #1"); 1935 auto vType = 1936 LLVM::getVectorType(LLVM::getVectorElementType(typeV1), maskAttr.size(), 1937 typeV1.cast<VectorType>().isScalable()); 1938 result.addTypes(vType); 1939 return success(); 1940 } 1941 1942 LogicalResult ShuffleVectorOp::verify() { 1943 Type type1 = getV1().getType(); 1944 Type type2 = getV2().getType(); 1945 if (LLVM::getVectorElementType(type1) != LLVM::getVectorElementType(type2)) 1946 return emitOpError("expected matching LLVM IR Dialect element types"); 1947 if (LLVM::isScalableVectorType(type1)) 1948 if (llvm::any_of(getMask(), [](Attribute attr) { 1949 return attr.cast<IntegerAttr>().getInt() != 0; 1950 })) 1951 return emitOpError("expected a splat operation for scalable vectors"); 1952 return success(); 1953 } 1954 1955 //===----------------------------------------------------------------------===// 1956 // Implementations for LLVM::LLVMFuncOp. 1957 //===----------------------------------------------------------------------===// 1958 1959 // Add the entry block to the function. 1960 Block *LLVMFuncOp::addEntryBlock() { 1961 assert(empty() && "function already has an entry block"); 1962 assert(!isVarArg() && "unimplemented: non-external variadic functions"); 1963 1964 auto *entry = new Block; 1965 push_back(entry); 1966 1967 // FIXME: Allow passing in proper locations for the entry arguments. 1968 LLVMFunctionType type = getType(); 1969 for (unsigned i = 0, e = type.getNumParams(); i < e; ++i) 1970 entry->addArgument(type.getParamType(i), getLoc()); 1971 return entry; 1972 } 1973 1974 void LLVMFuncOp::build(OpBuilder &builder, OperationState &result, 1975 StringRef name, Type type, LLVM::Linkage linkage, 1976 bool dsoLocal, ArrayRef<NamedAttribute> attrs, 1977 ArrayRef<DictionaryAttr> argAttrs) { 1978 result.addRegion(); 1979 result.addAttribute(SymbolTable::getSymbolAttrName(), 1980 builder.getStringAttr(name)); 1981 result.addAttribute("type", TypeAttr::get(type)); 1982 result.addAttribute(::getLinkageAttrName(), 1983 LinkageAttr::get(builder.getContext(), linkage)); 1984 result.attributes.append(attrs.begin(), attrs.end()); 1985 if (dsoLocal) 1986 result.addAttribute("dso_local", builder.getUnitAttr()); 1987 if (argAttrs.empty()) 1988 return; 1989 1990 assert(type.cast<LLVMFunctionType>().getNumParams() == argAttrs.size() && 1991 "expected as many argument attribute lists as arguments"); 1992 function_interface_impl::addArgAndResultAttrs(builder, result, argAttrs, 1993 /*resultAttrs=*/llvm::None); 1994 } 1995 1996 // Builds an LLVM function type from the given lists of input and output types. 1997 // Returns a null type if any of the types provided are non-LLVM types, or if 1998 // there is more than one output type. 1999 static Type 2000 buildLLVMFunctionType(OpAsmParser &parser, SMLoc loc, ArrayRef<Type> inputs, 2001 ArrayRef<Type> outputs, 2002 function_interface_impl::VariadicFlag variadicFlag) { 2003 Builder &b = parser.getBuilder(); 2004 if (outputs.size() > 1) { 2005 parser.emitError(loc, "failed to construct function type: expected zero or " 2006 "one function result"); 2007 return {}; 2008 } 2009 2010 // Convert inputs to LLVM types, exit early on error. 2011 SmallVector<Type, 4> llvmInputs; 2012 for (auto t : inputs) { 2013 if (!isCompatibleType(t)) { 2014 parser.emitError(loc, "failed to construct function type: expected LLVM " 2015 "type for function arguments"); 2016 return {}; 2017 } 2018 llvmInputs.push_back(t); 2019 } 2020 2021 // No output is denoted as "void" in LLVM type system. 2022 Type llvmOutput = 2023 outputs.empty() ? LLVMVoidType::get(b.getContext()) : outputs.front(); 2024 if (!isCompatibleType(llvmOutput)) { 2025 parser.emitError(loc, "failed to construct function type: expected LLVM " 2026 "type for function results") 2027 << llvmOutput; 2028 return {}; 2029 } 2030 return LLVMFunctionType::get(llvmOutput, llvmInputs, 2031 variadicFlag.isVariadic()); 2032 } 2033 2034 // Parses an LLVM function. 2035 // 2036 // operation ::= `llvm.func` linkage? function-signature function-attributes? 2037 // function-body 2038 // 2039 ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) { 2040 // Default to external linkage if no keyword is provided. 2041 result.addAttribute( 2042 ::getLinkageAttrName(), 2043 LinkageAttr::get(parser.getContext(), 2044 parseOptionalLLVMKeyword<Linkage>( 2045 parser, result, LLVM::Linkage::External))); 2046 2047 StringAttr nameAttr; 2048 SmallVector<OpAsmParser::OperandType> entryArgs; 2049 SmallVector<NamedAttrList> argAttrs; 2050 SmallVector<NamedAttrList> resultAttrs; 2051 SmallVector<Type> argTypes; 2052 SmallVector<Type> resultTypes; 2053 SmallVector<Location> argLocations; 2054 bool isVariadic; 2055 2056 auto signatureLocation = parser.getCurrentLocation(); 2057 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), 2058 result.attributes) || 2059 function_interface_impl::parseFunctionSignature( 2060 parser, /*allowVariadic=*/true, entryArgs, argTypes, argAttrs, 2061 argLocations, isVariadic, resultTypes, resultAttrs)) 2062 return failure(); 2063 2064 auto type = 2065 buildLLVMFunctionType(parser, signatureLocation, argTypes, resultTypes, 2066 function_interface_impl::VariadicFlag(isVariadic)); 2067 if (!type) 2068 return failure(); 2069 result.addAttribute(FunctionOpInterface::getTypeAttrName(), 2070 TypeAttr::get(type)); 2071 2072 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) 2073 return failure(); 2074 function_interface_impl::addArgAndResultAttrs(parser.getBuilder(), result, 2075 argAttrs, resultAttrs); 2076 2077 auto *body = result.addRegion(); 2078 OptionalParseResult parseResult = parser.parseOptionalRegion( 2079 *body, entryArgs, entryArgs.empty() ? ArrayRef<Type>() : argTypes); 2080 return failure(parseResult.hasValue() && failed(*parseResult)); 2081 } 2082 2083 // Print the LLVMFuncOp. Collects argument and result types and passes them to 2084 // helper functions. Drops "void" result since it cannot be parsed back. Skips 2085 // the external linkage since it is the default value. 2086 void LLVMFuncOp::print(OpAsmPrinter &p) { 2087 p << ' '; 2088 if (getLinkage() != LLVM::Linkage::External) 2089 p << stringifyLinkage(getLinkage()) << ' '; 2090 p.printSymbolName(getName()); 2091 2092 LLVMFunctionType fnType = getType(); 2093 SmallVector<Type, 8> argTypes; 2094 SmallVector<Type, 1> resTypes; 2095 argTypes.reserve(fnType.getNumParams()); 2096 for (unsigned i = 0, e = fnType.getNumParams(); i < e; ++i) 2097 argTypes.push_back(fnType.getParamType(i)); 2098 2099 Type returnType = fnType.getReturnType(); 2100 if (!returnType.isa<LLVMVoidType>()) 2101 resTypes.push_back(returnType); 2102 2103 function_interface_impl::printFunctionSignature(p, *this, argTypes, 2104 isVarArg(), resTypes); 2105 function_interface_impl::printFunctionAttributes( 2106 p, *this, argTypes.size(), resTypes.size(), {getLinkageAttrName()}); 2107 2108 // Print the body if this is not an external function. 2109 Region &body = getBody(); 2110 if (!body.empty()) { 2111 p << ' '; 2112 p.printRegion(body, /*printEntryBlockArgs=*/false, 2113 /*printBlockTerminators=*/true); 2114 } 2115 } 2116 2117 LogicalResult LLVMFuncOp::verifyType() { 2118 auto llvmType = getTypeAttr().getValue().dyn_cast_or_null<LLVMFunctionType>(); 2119 if (!llvmType) 2120 return emitOpError("requires '" + getTypeAttrName() + 2121 "' attribute of wrapped LLVM function type"); 2122 2123 return success(); 2124 } 2125 2126 // Verifies LLVM- and implementation-specific properties of the LLVM func Op: 2127 // - functions don't have 'common' linkage 2128 // - external functions have 'external' or 'extern_weak' linkage; 2129 // - vararg is (currently) only supported for external functions; 2130 LogicalResult LLVMFuncOp::verify() { 2131 if (getLinkage() == LLVM::Linkage::Common) 2132 return emitOpError() << "functions cannot have '" 2133 << stringifyLinkage(LLVM::Linkage::Common) 2134 << "' linkage"; 2135 2136 // Check to see if this function has a void return with a result attribute to 2137 // it. It isn't clear what semantics we would assign to that. 2138 if (getType().getReturnType().isa<LLVMVoidType>() && 2139 !getResultAttrs(0).empty()) { 2140 return emitOpError() 2141 << "cannot attach result attributes to functions with a void return"; 2142 } 2143 2144 if (isExternal()) { 2145 if (getLinkage() != LLVM::Linkage::External && 2146 getLinkage() != LLVM::Linkage::ExternWeak) 2147 return emitOpError() << "external functions must have '" 2148 << stringifyLinkage(LLVM::Linkage::External) 2149 << "' or '" 2150 << stringifyLinkage(LLVM::Linkage::ExternWeak) 2151 << "' linkage"; 2152 return success(); 2153 } 2154 2155 if (isVarArg()) 2156 return emitOpError("only external functions can be variadic"); 2157 2158 return success(); 2159 } 2160 2161 /// Verifies LLVM- and implementation-specific properties of the LLVM func Op: 2162 /// - entry block arguments are of LLVM types. 2163 LogicalResult LLVMFuncOp::verifyRegions() { 2164 if (isExternal()) 2165 return success(); 2166 2167 unsigned numArguments = getType().getNumParams(); 2168 Block &entryBlock = front(); 2169 for (unsigned i = 0; i < numArguments; ++i) { 2170 Type argType = entryBlock.getArgument(i).getType(); 2171 if (!isCompatibleType(argType)) 2172 return emitOpError("entry block argument #") 2173 << i << " is not of LLVM type"; 2174 } 2175 2176 return success(); 2177 } 2178 2179 //===----------------------------------------------------------------------===// 2180 // Verification for LLVM::ConstantOp. 2181 //===----------------------------------------------------------------------===// 2182 2183 LogicalResult LLVM::ConstantOp::verify() { 2184 if (StringAttr sAttr = getValue().dyn_cast<StringAttr>()) { 2185 auto arrayType = getType().dyn_cast<LLVMArrayType>(); 2186 if (!arrayType || arrayType.getNumElements() != sAttr.getValue().size() || 2187 !arrayType.getElementType().isInteger(8)) { 2188 return emitOpError() << "expected array type of " 2189 << sAttr.getValue().size() 2190 << " i8 elements for the string constant"; 2191 } 2192 return success(); 2193 } 2194 if (auto structType = getType().dyn_cast<LLVMStructType>()) { 2195 if (structType.getBody().size() != 2 || 2196 structType.getBody()[0] != structType.getBody()[1]) { 2197 return emitError() << "expected struct type with two elements of the " 2198 "same type, the type of a complex constant"; 2199 } 2200 2201 auto arrayAttr = getValue().dyn_cast<ArrayAttr>(); 2202 if (!arrayAttr || arrayAttr.size() != 2 || 2203 arrayAttr[0].getType() != arrayAttr[1].getType()) { 2204 return emitOpError() << "expected array attribute with two elements, " 2205 "representing a complex constant"; 2206 } 2207 2208 Type elementType = structType.getBody()[0]; 2209 if (!elementType 2210 .isa<IntegerType, Float16Type, Float32Type, Float64Type>()) { 2211 return emitError() 2212 << "expected struct element types to be floating point type or " 2213 "integer type"; 2214 } 2215 return success(); 2216 } 2217 if (!getValue().isa<IntegerAttr, ArrayAttr, FloatAttr, ElementsAttr>()) 2218 return emitOpError() 2219 << "only supports integer, float, string or elements attributes"; 2220 return success(); 2221 } 2222 2223 // Constant op constant-folds to its value. 2224 OpFoldResult LLVM::ConstantOp::fold(ArrayRef<Attribute>) { return getValue(); } 2225 2226 //===----------------------------------------------------------------------===// 2227 // Utility functions for parsing atomic ops 2228 //===----------------------------------------------------------------------===// 2229 2230 // Helper function to parse a keyword into the specified attribute named by 2231 // `attrName`. The keyword must match one of the string values defined by the 2232 // AtomicBinOp enum. The resulting I64 attribute is added to the `result` 2233 // state. 2234 static ParseResult parseAtomicBinOp(OpAsmParser &parser, OperationState &result, 2235 StringRef attrName) { 2236 SMLoc loc; 2237 StringRef keyword; 2238 if (parser.getCurrentLocation(&loc) || parser.parseKeyword(&keyword)) 2239 return failure(); 2240 2241 // Replace the keyword `keyword` with an integer attribute. 2242 auto kind = symbolizeAtomicBinOp(keyword); 2243 if (!kind) { 2244 return parser.emitError(loc) 2245 << "'" << keyword << "' is an incorrect value of the '" << attrName 2246 << "' attribute"; 2247 } 2248 2249 auto value = static_cast<int64_t>(kind.getValue()); 2250 auto attr = parser.getBuilder().getI64IntegerAttr(value); 2251 result.addAttribute(attrName, attr); 2252 2253 return success(); 2254 } 2255 2256 // Helper function to parse a keyword into the specified attribute named by 2257 // `attrName`. The keyword must match one of the string values defined by the 2258 // AtomicOrdering enum. The resulting I64 attribute is added to the `result` 2259 // state. 2260 static ParseResult parseAtomicOrdering(OpAsmParser &parser, 2261 OperationState &result, 2262 StringRef attrName) { 2263 SMLoc loc; 2264 StringRef ordering; 2265 if (parser.getCurrentLocation(&loc) || parser.parseKeyword(&ordering)) 2266 return failure(); 2267 2268 // Replace the keyword `ordering` with an integer attribute. 2269 auto kind = symbolizeAtomicOrdering(ordering); 2270 if (!kind) { 2271 return parser.emitError(loc) 2272 << "'" << ordering << "' is an incorrect value of the '" << attrName 2273 << "' attribute"; 2274 } 2275 2276 auto value = static_cast<int64_t>(kind.getValue()); 2277 auto attr = parser.getBuilder().getI64IntegerAttr(value); 2278 result.addAttribute(attrName, attr); 2279 2280 return success(); 2281 } 2282 2283 //===----------------------------------------------------------------------===// 2284 // Printer, parser and verifier for LLVM::AtomicRMWOp. 2285 //===----------------------------------------------------------------------===// 2286 2287 void AtomicRMWOp::print(OpAsmPrinter &p) { 2288 p << ' ' << stringifyAtomicBinOp(getBinOp()) << ' ' << getPtr() << ", " 2289 << getVal() << ' ' << stringifyAtomicOrdering(getOrdering()) << ' '; 2290 p.printOptionalAttrDict((*this)->getAttrs(), {"bin_op", "ordering"}); 2291 p << " : " << getRes().getType(); 2292 } 2293 2294 // <operation> ::= `llvm.atomicrmw` keyword ssa-use `,` ssa-use keyword 2295 // attribute-dict? `:` type 2296 ParseResult AtomicRMWOp::parse(OpAsmParser &parser, OperationState &result) { 2297 Type type; 2298 OpAsmParser::OperandType ptr, val; 2299 if (parseAtomicBinOp(parser, result, "bin_op") || parser.parseOperand(ptr) || 2300 parser.parseComma() || parser.parseOperand(val) || 2301 parseAtomicOrdering(parser, result, "ordering") || 2302 parser.parseOptionalAttrDict(result.attributes) || 2303 parser.parseColonType(type) || 2304 parser.resolveOperand(ptr, LLVM::LLVMPointerType::get(type), 2305 result.operands) || 2306 parser.resolveOperand(val, type, result.operands)) 2307 return failure(); 2308 2309 result.addTypes(type); 2310 return success(); 2311 } 2312 2313 LogicalResult AtomicRMWOp::verify() { 2314 auto ptrType = getPtr().getType().cast<LLVM::LLVMPointerType>(); 2315 auto valType = getVal().getType(); 2316 if (valType != ptrType.getElementType()) 2317 return emitOpError("expected LLVM IR element type for operand #0 to " 2318 "match type for operand #1"); 2319 auto resType = getRes().getType(); 2320 if (resType != valType) 2321 return emitOpError( 2322 "expected LLVM IR result type to match type for operand #1"); 2323 if (getBinOp() == AtomicBinOp::fadd || getBinOp() == AtomicBinOp::fsub) { 2324 if (!mlir::LLVM::isCompatibleFloatingPointType(valType)) 2325 return emitOpError("expected LLVM IR floating point type"); 2326 } else if (getBinOp() == AtomicBinOp::xchg) { 2327 auto intType = valType.dyn_cast<IntegerType>(); 2328 unsigned intBitWidth = intType ? intType.getWidth() : 0; 2329 if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 && 2330 intBitWidth != 64 && !valType.isa<BFloat16Type>() && 2331 !valType.isa<Float16Type>() && !valType.isa<Float32Type>() && 2332 !valType.isa<Float64Type>()) 2333 return emitOpError("unexpected LLVM IR type for 'xchg' bin_op"); 2334 } else { 2335 auto intType = valType.dyn_cast<IntegerType>(); 2336 unsigned intBitWidth = intType ? intType.getWidth() : 0; 2337 if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 && 2338 intBitWidth != 64) 2339 return emitOpError("expected LLVM IR integer type"); 2340 } 2341 2342 if (static_cast<unsigned>(getOrdering()) < 2343 static_cast<unsigned>(AtomicOrdering::monotonic)) 2344 return emitOpError() << "expected at least '" 2345 << stringifyAtomicOrdering(AtomicOrdering::monotonic) 2346 << "' ordering"; 2347 2348 return success(); 2349 } 2350 2351 //===----------------------------------------------------------------------===// 2352 // Printer, parser and verifier for LLVM::AtomicCmpXchgOp. 2353 //===----------------------------------------------------------------------===// 2354 2355 void AtomicCmpXchgOp::print(OpAsmPrinter &p) { 2356 p << ' ' << getPtr() << ", " << getCmp() << ", " << getVal() << ' ' 2357 << stringifyAtomicOrdering(getSuccessOrdering()) << ' ' 2358 << stringifyAtomicOrdering(getFailureOrdering()); 2359 p.printOptionalAttrDict((*this)->getAttrs(), 2360 {"success_ordering", "failure_ordering"}); 2361 p << " : " << getVal().getType(); 2362 } 2363 2364 // <operation> ::= `llvm.cmpxchg` ssa-use `,` ssa-use `,` ssa-use 2365 // keyword keyword attribute-dict? `:` type 2366 ParseResult AtomicCmpXchgOp::parse(OpAsmParser &parser, 2367 OperationState &result) { 2368 auto &builder = parser.getBuilder(); 2369 Type type; 2370 OpAsmParser::OperandType ptr, cmp, val; 2371 if (parser.parseOperand(ptr) || parser.parseComma() || 2372 parser.parseOperand(cmp) || parser.parseComma() || 2373 parser.parseOperand(val) || 2374 parseAtomicOrdering(parser, result, "success_ordering") || 2375 parseAtomicOrdering(parser, result, "failure_ordering") || 2376 parser.parseOptionalAttrDict(result.attributes) || 2377 parser.parseColonType(type) || 2378 parser.resolveOperand(ptr, LLVM::LLVMPointerType::get(type), 2379 result.operands) || 2380 parser.resolveOperand(cmp, type, result.operands) || 2381 parser.resolveOperand(val, type, result.operands)) 2382 return failure(); 2383 2384 auto boolType = IntegerType::get(builder.getContext(), 1); 2385 auto resultType = 2386 LLVMStructType::getLiteral(builder.getContext(), {type, boolType}); 2387 result.addTypes(resultType); 2388 2389 return success(); 2390 } 2391 2392 LogicalResult AtomicCmpXchgOp::verify() { 2393 auto ptrType = getPtr().getType().cast<LLVM::LLVMPointerType>(); 2394 if (!ptrType) 2395 return emitOpError("expected LLVM IR pointer type for operand #0"); 2396 auto cmpType = getCmp().getType(); 2397 auto valType = getVal().getType(); 2398 if (cmpType != ptrType.getElementType() || cmpType != valType) 2399 return emitOpError("expected LLVM IR element type for operand #0 to " 2400 "match type for all other operands"); 2401 auto intType = valType.dyn_cast<IntegerType>(); 2402 unsigned intBitWidth = intType ? intType.getWidth() : 0; 2403 if (!valType.isa<LLVMPointerType>() && intBitWidth != 8 && 2404 intBitWidth != 16 && intBitWidth != 32 && intBitWidth != 64 && 2405 !valType.isa<BFloat16Type>() && !valType.isa<Float16Type>() && 2406 !valType.isa<Float32Type>() && !valType.isa<Float64Type>()) 2407 return emitOpError("unexpected LLVM IR type"); 2408 if (getSuccessOrdering() < AtomicOrdering::monotonic || 2409 getFailureOrdering() < AtomicOrdering::monotonic) 2410 return emitOpError("ordering must be at least 'monotonic'"); 2411 if (getFailureOrdering() == AtomicOrdering::release || 2412 getFailureOrdering() == AtomicOrdering::acq_rel) 2413 return emitOpError("failure ordering cannot be 'release' or 'acq_rel'"); 2414 return success(); 2415 } 2416 2417 //===----------------------------------------------------------------------===// 2418 // Printer, parser and verifier for LLVM::FenceOp. 2419 //===----------------------------------------------------------------------===// 2420 2421 // <operation> ::= `llvm.fence` (`syncscope(`strAttr`)`)? keyword 2422 // attribute-dict? 2423 ParseResult FenceOp::parse(OpAsmParser &parser, OperationState &result) { 2424 StringAttr sScope; 2425 StringRef syncscopeKeyword = "syncscope"; 2426 if (!failed(parser.parseOptionalKeyword(syncscopeKeyword))) { 2427 if (parser.parseLParen() || 2428 parser.parseAttribute(sScope, syncscopeKeyword, result.attributes) || 2429 parser.parseRParen()) 2430 return failure(); 2431 } else { 2432 result.addAttribute(syncscopeKeyword, 2433 parser.getBuilder().getStringAttr("")); 2434 } 2435 if (parseAtomicOrdering(parser, result, "ordering") || 2436 parser.parseOptionalAttrDict(result.attributes)) 2437 return failure(); 2438 return success(); 2439 } 2440 2441 void FenceOp::print(OpAsmPrinter &p) { 2442 StringRef syncscopeKeyword = "syncscope"; 2443 p << ' '; 2444 if (!(*this)->getAttr(syncscopeKeyword).cast<StringAttr>().getValue().empty()) 2445 p << "syncscope(" << (*this)->getAttr(syncscopeKeyword) << ") "; 2446 p << stringifyAtomicOrdering(getOrdering()); 2447 } 2448 2449 LogicalResult FenceOp::verify() { 2450 if (getOrdering() == AtomicOrdering::not_atomic || 2451 getOrdering() == AtomicOrdering::unordered || 2452 getOrdering() == AtomicOrdering::monotonic) 2453 return emitOpError("can be given only acquire, release, acq_rel, " 2454 "and seq_cst orderings"); 2455 return success(); 2456 } 2457 2458 //===----------------------------------------------------------------------===// 2459 // Folder for LLVM::BitcastOp 2460 //===----------------------------------------------------------------------===// 2461 2462 OpFoldResult LLVM::BitcastOp::fold(ArrayRef<Attribute> operands) { 2463 // bitcast(x : T0, T0) -> x 2464 if (getArg().getType() == getType()) 2465 return getArg(); 2466 // bitcast(bitcast(x : T0, T1), T0) -> x 2467 if (auto prev = getArg().getDefiningOp<BitcastOp>()) 2468 if (prev.getArg().getType() == getType()) 2469 return prev.getArg(); 2470 return {}; 2471 } 2472 2473 //===----------------------------------------------------------------------===// 2474 // Folder for LLVM::AddrSpaceCastOp 2475 //===----------------------------------------------------------------------===// 2476 2477 OpFoldResult LLVM::AddrSpaceCastOp::fold(ArrayRef<Attribute> operands) { 2478 // addrcast(x : T0, T0) -> x 2479 if (getArg().getType() == getType()) 2480 return getArg(); 2481 // addrcast(addrcast(x : T0, T1), T0) -> x 2482 if (auto prev = getArg().getDefiningOp<AddrSpaceCastOp>()) 2483 if (prev.getArg().getType() == getType()) 2484 return prev.getArg(); 2485 return {}; 2486 } 2487 2488 //===----------------------------------------------------------------------===// 2489 // Folder for LLVM::GEPOp 2490 //===----------------------------------------------------------------------===// 2491 2492 OpFoldResult LLVM::GEPOp::fold(ArrayRef<Attribute> operands) { 2493 // gep %x:T, 0 -> %x 2494 if (getBase().getType() == getType() && getIndices().size() == 1 && 2495 matchPattern(getIndices()[0], m_Zero())) 2496 return getBase(); 2497 return {}; 2498 } 2499 2500 //===----------------------------------------------------------------------===// 2501 // LLVMDialect initialization, type parsing, and registration. 2502 //===----------------------------------------------------------------------===// 2503 2504 void LLVMDialect::initialize() { 2505 addAttributes<FMFAttr, LinkageAttr, LoopOptionsAttr>(); 2506 2507 // clang-format off 2508 addTypes<LLVMVoidType, 2509 LLVMPPCFP128Type, 2510 LLVMX86MMXType, 2511 LLVMTokenType, 2512 LLVMLabelType, 2513 LLVMMetadataType, 2514 LLVMFunctionType, 2515 LLVMPointerType, 2516 LLVMFixedVectorType, 2517 LLVMScalableVectorType, 2518 LLVMArrayType, 2519 LLVMStructType>(); 2520 // clang-format on 2521 addOperations< 2522 #define GET_OP_LIST 2523 #include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc" 2524 >(); 2525 2526 // Support unknown operations because not all LLVM operations are registered. 2527 allowUnknownOperations(); 2528 } 2529 2530 #define GET_OP_CLASSES 2531 #include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc" 2532 2533 /// Parse a type registered to this dialect. 2534 Type LLVMDialect::parseType(DialectAsmParser &parser) const { 2535 return detail::parseType(parser); 2536 } 2537 2538 /// Print a type registered to this dialect. 2539 void LLVMDialect::printType(Type type, DialectAsmPrinter &os) const { 2540 return detail::printType(type, os); 2541 } 2542 2543 LogicalResult LLVMDialect::verifyDataLayoutString( 2544 StringRef descr, llvm::function_ref<void(const Twine &)> reportError) { 2545 llvm::Expected<llvm::DataLayout> maybeDataLayout = 2546 llvm::DataLayout::parse(descr); 2547 if (maybeDataLayout) 2548 return success(); 2549 2550 std::string message; 2551 llvm::raw_string_ostream messageStream(message); 2552 llvm::logAllUnhandledErrors(maybeDataLayout.takeError(), messageStream); 2553 reportError("invalid data layout descriptor: " + messageStream.str()); 2554 return failure(); 2555 } 2556 2557 /// Verify LLVM dialect attributes. 2558 LogicalResult LLVMDialect::verifyOperationAttribute(Operation *op, 2559 NamedAttribute attr) { 2560 // If the `llvm.loop` attribute is present, enforce the following structure, 2561 // which the module translation can assume. 2562 if (attr.getName() == LLVMDialect::getLoopAttrName()) { 2563 auto loopAttr = attr.getValue().dyn_cast<DictionaryAttr>(); 2564 if (!loopAttr) 2565 return op->emitOpError() << "expected '" << LLVMDialect::getLoopAttrName() 2566 << "' to be a dictionary attribute"; 2567 Optional<NamedAttribute> parallelAccessGroup = 2568 loopAttr.getNamed(LLVMDialect::getParallelAccessAttrName()); 2569 if (parallelAccessGroup.hasValue()) { 2570 auto accessGroups = parallelAccessGroup->getValue().dyn_cast<ArrayAttr>(); 2571 if (!accessGroups) 2572 return op->emitOpError() 2573 << "expected '" << LLVMDialect::getParallelAccessAttrName() 2574 << "' to be an array attribute"; 2575 for (Attribute attr : accessGroups) { 2576 auto accessGroupRef = attr.dyn_cast<SymbolRefAttr>(); 2577 if (!accessGroupRef) 2578 return op->emitOpError() 2579 << "expected '" << attr << "' to be a symbol reference"; 2580 StringAttr metadataName = accessGroupRef.getRootReference(); 2581 auto metadataOp = 2582 SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>( 2583 op->getParentOp(), metadataName); 2584 if (!metadataOp) 2585 return op->emitOpError() 2586 << "expected '" << attr << "' to reference a metadata op"; 2587 StringAttr accessGroupName = accessGroupRef.getLeafReference(); 2588 Operation *accessGroupOp = 2589 SymbolTable::lookupNearestSymbolFrom(metadataOp, accessGroupName); 2590 if (!accessGroupOp) 2591 return op->emitOpError() 2592 << "expected '" << attr << "' to reference an access_group op"; 2593 } 2594 } 2595 2596 Optional<NamedAttribute> loopOptions = 2597 loopAttr.getNamed(LLVMDialect::getLoopOptionsAttrName()); 2598 if (loopOptions.hasValue() && 2599 !loopOptions->getValue().isa<LoopOptionsAttr>()) 2600 return op->emitOpError() 2601 << "expected '" << LLVMDialect::getLoopOptionsAttrName() 2602 << "' to be a `loopopts` attribute"; 2603 } 2604 2605 if (attr.getName() == LLVMDialect::getStructAttrsAttrName()) { 2606 return op->emitOpError() 2607 << "'" << LLVM::LLVMDialect::getStructAttrsAttrName() 2608 << "' is permitted only in argument or result attributes"; 2609 } 2610 2611 // If the data layout attribute is present, it must use the LLVM data layout 2612 // syntax. Try parsing it and report errors in case of failure. Users of this 2613 // attribute may assume it is well-formed and can pass it to the (asserting) 2614 // llvm::DataLayout constructor. 2615 if (attr.getName() != LLVM::LLVMDialect::getDataLayoutAttrName()) 2616 return success(); 2617 if (auto stringAttr = attr.getValue().dyn_cast<StringAttr>()) 2618 return verifyDataLayoutString( 2619 stringAttr.getValue(), 2620 [op](const Twine &message) { op->emitOpError() << message.str(); }); 2621 2622 return op->emitOpError() << "expected '" 2623 << LLVM::LLVMDialect::getDataLayoutAttrName() 2624 << "' to be a string attribute"; 2625 } 2626 2627 LogicalResult LLVMDialect::verifyStructAttr(Operation *op, Attribute attr, 2628 Type annotatedType) { 2629 auto structType = annotatedType.dyn_cast<LLVMStructType>(); 2630 if (!structType) { 2631 const auto emitIncorrectAnnotatedType = [&op]() { 2632 return op->emitError() 2633 << "expected '" << LLVMDialect::getStructAttrsAttrName() 2634 << "' to annotate '!llvm.struct' or '!llvm.ptr<struct<...>>'"; 2635 }; 2636 const auto ptrType = annotatedType.dyn_cast<LLVMPointerType>(); 2637 if (!ptrType) 2638 return emitIncorrectAnnotatedType(); 2639 structType = ptrType.getElementType().dyn_cast<LLVMStructType>(); 2640 if (!structType) 2641 return emitIncorrectAnnotatedType(); 2642 } 2643 2644 const auto arrAttrs = attr.dyn_cast<ArrayAttr>(); 2645 if (!arrAttrs) 2646 return op->emitError() << "expected '" 2647 << LLVMDialect::getStructAttrsAttrName() 2648 << "' to be an array attribute"; 2649 2650 if (structType.getBody().size() != arrAttrs.size()) 2651 return op->emitError() 2652 << "size of '" << LLVMDialect::getStructAttrsAttrName() 2653 << "' must match the size of the annotated '!llvm.struct'"; 2654 return success(); 2655 } 2656 2657 static LogicalResult verifyFuncOpInterfaceStructAttr( 2658 Operation *op, Attribute attr, 2659 std::function<Type(FunctionOpInterface)> getAnnotatedType) { 2660 if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) 2661 return LLVMDialect::verifyStructAttr(op, attr, getAnnotatedType(funcOp)); 2662 return op->emitError() << "expected '" 2663 << LLVMDialect::getStructAttrsAttrName() 2664 << "' to be used on function-like operations"; 2665 } 2666 2667 /// Verify LLVMIR function argument attributes. 2668 LogicalResult LLVMDialect::verifyRegionArgAttribute(Operation *op, 2669 unsigned regionIdx, 2670 unsigned argIdx, 2671 NamedAttribute argAttr) { 2672 // Check that llvm.noalias is a unit attribute. 2673 if (argAttr.getName() == LLVMDialect::getNoAliasAttrName() && 2674 !argAttr.getValue().isa<UnitAttr>()) 2675 return op->emitError() 2676 << "expected llvm.noalias argument attribute to be a unit attribute"; 2677 // Check that llvm.align is an integer attribute. 2678 if (argAttr.getName() == LLVMDialect::getAlignAttrName() && 2679 !argAttr.getValue().isa<IntegerAttr>()) 2680 return op->emitError() 2681 << "llvm.align argument attribute of non integer type"; 2682 if (argAttr.getName() == LLVMDialect::getStructAttrsAttrName()) { 2683 return verifyFuncOpInterfaceStructAttr( 2684 op, argAttr.getValue(), [argIdx](FunctionOpInterface funcOp) { 2685 return funcOp.getArgumentTypes()[argIdx]; 2686 }); 2687 } 2688 return success(); 2689 } 2690 2691 LogicalResult LLVMDialect::verifyRegionResultAttribute(Operation *op, 2692 unsigned regionIdx, 2693 unsigned resIdx, 2694 NamedAttribute resAttr) { 2695 if (resAttr.getName() == LLVMDialect::getStructAttrsAttrName()) { 2696 return verifyFuncOpInterfaceStructAttr( 2697 op, resAttr.getValue(), [resIdx](FunctionOpInterface funcOp) { 2698 return funcOp.getResultTypes()[resIdx]; 2699 }); 2700 } 2701 return success(); 2702 } 2703 2704 //===----------------------------------------------------------------------===// 2705 // Utility functions. 2706 //===----------------------------------------------------------------------===// 2707 2708 Value mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder, 2709 StringRef name, StringRef value, 2710 LLVM::Linkage linkage) { 2711 assert(builder.getInsertionBlock() && 2712 builder.getInsertionBlock()->getParentOp() && 2713 "expected builder to point to a block constrained in an op"); 2714 auto module = 2715 builder.getInsertionBlock()->getParentOp()->getParentOfType<ModuleOp>(); 2716 assert(module && "builder points to an op outside of a module"); 2717 2718 // Create the global at the entry of the module. 2719 OpBuilder moduleBuilder(module.getBodyRegion(), builder.getListener()); 2720 MLIRContext *ctx = builder.getContext(); 2721 auto type = LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), value.size()); 2722 auto global = moduleBuilder.create<LLVM::GlobalOp>( 2723 loc, type, /*isConstant=*/true, linkage, name, 2724 builder.getStringAttr(value), /*alignment=*/0); 2725 2726 // Get the pointer to the first character in the global string. 2727 Value globalPtr = builder.create<LLVM::AddressOfOp>(loc, global); 2728 Value cst0 = builder.create<LLVM::ConstantOp>( 2729 loc, IntegerType::get(ctx, 64), 2730 builder.getIntegerAttr(builder.getIndexType(), 0)); 2731 return builder.create<LLVM::GEPOp>( 2732 loc, LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)), globalPtr, 2733 ValueRange{cst0, cst0}); 2734 } 2735 2736 bool mlir::LLVM::satisfiesLLVMModule(Operation *op) { 2737 return op->hasTrait<OpTrait::SymbolTable>() && 2738 op->hasTrait<OpTrait::IsIsolatedFromAbove>(); 2739 } 2740 2741 static constexpr const FastmathFlags fastmathFlagsList[] = { 2742 // clang-format off 2743 FastmathFlags::nnan, 2744 FastmathFlags::ninf, 2745 FastmathFlags::nsz, 2746 FastmathFlags::arcp, 2747 FastmathFlags::contract, 2748 FastmathFlags::afn, 2749 FastmathFlags::reassoc, 2750 FastmathFlags::fast, 2751 // clang-format on 2752 }; 2753 2754 void FMFAttr::print(AsmPrinter &printer) const { 2755 printer << "<"; 2756 auto flags = llvm::make_filter_range(fastmathFlagsList, [&](auto flag) { 2757 return bitEnumContains(this->getFlags(), flag); 2758 }); 2759 llvm::interleaveComma(flags, printer, 2760 [&](auto flag) { printer << stringifyEnum(flag); }); 2761 printer << ">"; 2762 } 2763 2764 Attribute FMFAttr::parse(AsmParser &parser, Type type) { 2765 if (failed(parser.parseLess())) 2766 return {}; 2767 2768 FastmathFlags flags = {}; 2769 if (failed(parser.parseOptionalGreater())) { 2770 do { 2771 StringRef elemName; 2772 if (failed(parser.parseKeyword(&elemName))) 2773 return {}; 2774 2775 auto elem = symbolizeFastmathFlags(elemName); 2776 if (!elem) { 2777 parser.emitError(parser.getNameLoc(), "Unknown fastmath flag: ") 2778 << elemName; 2779 return {}; 2780 } 2781 2782 flags = flags | *elem; 2783 } while (succeeded(parser.parseOptionalComma())); 2784 2785 if (failed(parser.parseGreater())) 2786 return {}; 2787 } 2788 2789 return FMFAttr::get(parser.getContext(), flags); 2790 } 2791 2792 void LinkageAttr::print(AsmPrinter &printer) const { 2793 printer << "<"; 2794 if (static_cast<uint64_t>(getLinkage()) <= getMaxEnumValForLinkage()) 2795 printer << stringifyEnum(getLinkage()); 2796 else 2797 printer << static_cast<uint64_t>(getLinkage()); 2798 printer << ">"; 2799 } 2800 2801 Attribute LinkageAttr::parse(AsmParser &parser, Type type) { 2802 StringRef elemName; 2803 if (parser.parseLess() || parser.parseKeyword(&elemName) || 2804 parser.parseGreater()) 2805 return {}; 2806 auto elem = linkage::symbolizeLinkage(elemName); 2807 if (!elem) { 2808 parser.emitError(parser.getNameLoc(), "Unknown linkage: ") << elemName; 2809 return {}; 2810 } 2811 Linkage linkage = *elem; 2812 return LinkageAttr::get(parser.getContext(), linkage); 2813 } 2814 2815 LoopOptionsAttrBuilder::LoopOptionsAttrBuilder(LoopOptionsAttr attr) 2816 : options(attr.getOptions().begin(), attr.getOptions().end()) {} 2817 2818 template <typename T> 2819 LoopOptionsAttrBuilder &LoopOptionsAttrBuilder::setOption(LoopOptionCase tag, 2820 Optional<T> value) { 2821 auto option = llvm::find_if( 2822 options, [tag](auto option) { return option.first == tag; }); 2823 if (option != options.end()) { 2824 if (value.hasValue()) 2825 option->second = *value; 2826 else 2827 options.erase(option); 2828 } else { 2829 options.push_back(LoopOptionsAttr::OptionValuePair(tag, *value)); 2830 } 2831 return *this; 2832 } 2833 2834 LoopOptionsAttrBuilder & 2835 LoopOptionsAttrBuilder::setDisableLICM(Optional<bool> value) { 2836 return setOption(LoopOptionCase::disable_licm, value); 2837 } 2838 2839 /// Set the `interleave_count` option to the provided value. If no value 2840 /// is provided the option is deleted. 2841 LoopOptionsAttrBuilder & 2842 LoopOptionsAttrBuilder::setInterleaveCount(Optional<uint64_t> count) { 2843 return setOption(LoopOptionCase::interleave_count, count); 2844 } 2845 2846 /// Set the `disable_unroll` option to the provided value. If no value 2847 /// is provided the option is deleted. 2848 LoopOptionsAttrBuilder & 2849 LoopOptionsAttrBuilder::setDisableUnroll(Optional<bool> value) { 2850 return setOption(LoopOptionCase::disable_unroll, value); 2851 } 2852 2853 /// Set the `disable_pipeline` option to the provided value. If no value 2854 /// is provided the option is deleted. 2855 LoopOptionsAttrBuilder & 2856 LoopOptionsAttrBuilder::setDisablePipeline(Optional<bool> value) { 2857 return setOption(LoopOptionCase::disable_pipeline, value); 2858 } 2859 2860 /// Set the `pipeline_initiation_interval` option to the provided value. 2861 /// If no value is provided the option is deleted. 2862 LoopOptionsAttrBuilder &LoopOptionsAttrBuilder::setPipelineInitiationInterval( 2863 Optional<uint64_t> count) { 2864 return setOption(LoopOptionCase::pipeline_initiation_interval, count); 2865 } 2866 2867 template <typename T> 2868 static Optional<T> 2869 getOption(ArrayRef<std::pair<LoopOptionCase, int64_t>> options, 2870 LoopOptionCase option) { 2871 auto it = 2872 lower_bound(options, option, [](auto optionPair, LoopOptionCase option) { 2873 return optionPair.first < option; 2874 }); 2875 if (it == options.end()) 2876 return {}; 2877 return static_cast<T>(it->second); 2878 } 2879 2880 Optional<bool> LoopOptionsAttr::disableUnroll() { 2881 return getOption<bool>(getOptions(), LoopOptionCase::disable_unroll); 2882 } 2883 2884 Optional<bool> LoopOptionsAttr::disableLICM() { 2885 return getOption<bool>(getOptions(), LoopOptionCase::disable_licm); 2886 } 2887 2888 Optional<int64_t> LoopOptionsAttr::interleaveCount() { 2889 return getOption<int64_t>(getOptions(), LoopOptionCase::interleave_count); 2890 } 2891 2892 /// Build the LoopOptions Attribute from a sorted array of individual options. 2893 LoopOptionsAttr LoopOptionsAttr::get( 2894 MLIRContext *context, 2895 ArrayRef<std::pair<LoopOptionCase, int64_t>> sortedOptions) { 2896 assert(llvm::is_sorted(sortedOptions, llvm::less_first()) && 2897 "LoopOptionsAttr ctor expects a sorted options array"); 2898 return Base::get(context, sortedOptions); 2899 } 2900 2901 /// Build the LoopOptions Attribute from a sorted array of individual options. 2902 LoopOptionsAttr LoopOptionsAttr::get(MLIRContext *context, 2903 LoopOptionsAttrBuilder &optionBuilders) { 2904 llvm::sort(optionBuilders.options, llvm::less_first()); 2905 return Base::get(context, optionBuilders.options); 2906 } 2907 2908 void LoopOptionsAttr::print(AsmPrinter &printer) const { 2909 printer << "<"; 2910 llvm::interleaveComma(getOptions(), printer, [&](auto option) { 2911 printer << stringifyEnum(option.first) << " = "; 2912 switch (option.first) { 2913 case LoopOptionCase::disable_licm: 2914 case LoopOptionCase::disable_unroll: 2915 case LoopOptionCase::disable_pipeline: 2916 printer << (option.second ? "true" : "false"); 2917 break; 2918 case LoopOptionCase::interleave_count: 2919 case LoopOptionCase::pipeline_initiation_interval: 2920 printer << option.second; 2921 break; 2922 } 2923 }); 2924 printer << ">"; 2925 } 2926 2927 Attribute LoopOptionsAttr::parse(AsmParser &parser, Type type) { 2928 if (failed(parser.parseLess())) 2929 return {}; 2930 2931 SmallVector<std::pair<LoopOptionCase, int64_t>> options; 2932 llvm::SmallDenseSet<LoopOptionCase> seenOptions; 2933 do { 2934 StringRef optionName; 2935 if (parser.parseKeyword(&optionName)) 2936 return {}; 2937 2938 auto option = symbolizeLoopOptionCase(optionName); 2939 if (!option) { 2940 parser.emitError(parser.getNameLoc(), "unknown loop option: ") 2941 << optionName; 2942 return {}; 2943 } 2944 if (!seenOptions.insert(*option).second) { 2945 parser.emitError(parser.getNameLoc(), "loop option present twice"); 2946 return {}; 2947 } 2948 if (failed(parser.parseEqual())) 2949 return {}; 2950 2951 int64_t value; 2952 switch (*option) { 2953 case LoopOptionCase::disable_licm: 2954 case LoopOptionCase::disable_unroll: 2955 case LoopOptionCase::disable_pipeline: 2956 if (succeeded(parser.parseOptionalKeyword("true"))) 2957 value = 1; 2958 else if (succeeded(parser.parseOptionalKeyword("false"))) 2959 value = 0; 2960 else { 2961 parser.emitError(parser.getNameLoc(), 2962 "expected boolean value 'true' or 'false'"); 2963 return {}; 2964 } 2965 break; 2966 case LoopOptionCase::interleave_count: 2967 case LoopOptionCase::pipeline_initiation_interval: 2968 if (failed(parser.parseInteger(value))) { 2969 parser.emitError(parser.getNameLoc(), "expected integer value"); 2970 return {}; 2971 } 2972 break; 2973 } 2974 options.push_back(std::make_pair(*option, value)); 2975 } while (succeeded(parser.parseOptionalComma())); 2976 if (failed(parser.parseGreater())) 2977 return {}; 2978 2979 llvm::sort(options, llvm::less_first()); 2980 return get(parser.getContext(), options); 2981 } 2982