1 //===- LLVMDialect.cpp - LLVM IR Ops and Dialect registration -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the types and operation details for the LLVM IR dialect in 10 // MLIR, and the LLVM IR dialect. It also registers the dialect. 11 // 12 //===----------------------------------------------------------------------===// 13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 14 #include "TypeDetail.h" 15 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 16 #include "mlir/IR/Builders.h" 17 #include "mlir/IR/BuiltinOps.h" 18 #include "mlir/IR/BuiltinTypes.h" 19 #include "mlir/IR/DialectImplementation.h" 20 #include "mlir/IR/FunctionImplementation.h" 21 #include "mlir/IR/MLIRContext.h" 22 #include "mlir/IR/Matchers.h" 23 24 #include "llvm/ADT/StringSwitch.h" 25 #include "llvm/ADT/TypeSwitch.h" 26 #include "llvm/AsmParser/Parser.h" 27 #include "llvm/Bitcode/BitcodeReader.h" 28 #include "llvm/Bitcode/BitcodeWriter.h" 29 #include "llvm/IR/Attributes.h" 30 #include "llvm/IR/Function.h" 31 #include "llvm/IR/Type.h" 32 #include "llvm/Support/Mutex.h" 33 #include "llvm/Support/SourceMgr.h" 34 35 #include <numeric> 36 37 using namespace mlir; 38 using namespace mlir::LLVM; 39 using mlir::LLVM::linkage::getMaxEnumValForLinkage; 40 41 #include "mlir/Dialect/LLVMIR/LLVMOpsDialect.cpp.inc" 42 43 static constexpr const char kVolatileAttrName[] = "volatile_"; 44 static constexpr const char kNonTemporalAttrName[] = "nontemporal"; 45 46 #include "mlir/Dialect/LLVMIR/LLVMOpsEnums.cpp.inc" 47 #include "mlir/Dialect/LLVMIR/LLVMOpsInterfaces.cpp.inc" 48 #define GET_ATTRDEF_CLASSES 49 #include "mlir/Dialect/LLVMIR/LLVMOpsAttrDefs.cpp.inc" 50 51 static auto processFMFAttr(ArrayRef<NamedAttribute> attrs) { 52 SmallVector<NamedAttribute, 8> filteredAttrs( 53 llvm::make_filter_range(attrs, [&](NamedAttribute attr) { 54 if (attr.getName() == "fastmathFlags") { 55 auto defAttr = FMFAttr::get(attr.getValue().getContext(), {}); 56 return defAttr != attr.getValue(); 57 } 58 return true; 59 })); 60 return filteredAttrs; 61 } 62 63 static ParseResult parseLLVMOpAttrs(OpAsmParser &parser, 64 NamedAttrList &result) { 65 return parser.parseOptionalAttrDict(result); 66 } 67 68 static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op, 69 DictionaryAttr attrs) { 70 printer.printOptionalAttrDict(processFMFAttr(attrs.getValue())); 71 } 72 73 /// Verifies `symbol`'s use in `op` to ensure the symbol is a valid and 74 /// fully defined llvm.func. 75 static LogicalResult verifySymbolAttrUse(FlatSymbolRefAttr symbol, 76 Operation *op, 77 SymbolTableCollection &symbolTable) { 78 StringRef name = symbol.getValue(); 79 auto func = 80 symbolTable.lookupNearestSymbolFrom<LLVMFuncOp>(op, symbol.getAttr()); 81 if (!func) 82 return op->emitOpError("'") 83 << name << "' does not reference a valid LLVM function"; 84 if (func.isExternal()) 85 return op->emitOpError("'") << name << "' does not have a definition"; 86 return success(); 87 } 88 89 //===----------------------------------------------------------------------===// 90 // Printing/parsing for LLVM::CmpOp. 91 //===----------------------------------------------------------------------===// 92 93 void ICmpOp::print(OpAsmPrinter &p) { 94 p << " \"" << stringifyICmpPredicate(getPredicate()) << "\" " << getOperand(0) 95 << ", " << getOperand(1); 96 p.printOptionalAttrDict((*this)->getAttrs(), {"predicate"}); 97 p << " : " << getLhs().getType(); 98 } 99 100 void FCmpOp::print(OpAsmPrinter &p) { 101 p << " \"" << stringifyFCmpPredicate(getPredicate()) << "\" " << getOperand(0) 102 << ", " << getOperand(1); 103 p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()), {"predicate"}); 104 p << " : " << getLhs().getType(); 105 } 106 107 // <operation> ::= `llvm.icmp` string-literal ssa-use `,` ssa-use 108 // attribute-dict? `:` type 109 // <operation> ::= `llvm.fcmp` string-literal ssa-use `,` ssa-use 110 // attribute-dict? `:` type 111 template <typename CmpPredicateType> 112 static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) { 113 Builder &builder = parser.getBuilder(); 114 115 StringAttr predicateAttr; 116 OpAsmParser::UnresolvedOperand lhs, rhs; 117 Type type; 118 SMLoc predicateLoc, trailingTypeLoc; 119 if (parser.getCurrentLocation(&predicateLoc) || 120 parser.parseAttribute(predicateAttr, "predicate", result.attributes) || 121 parser.parseOperand(lhs) || parser.parseComma() || 122 parser.parseOperand(rhs) || 123 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 124 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) || 125 parser.resolveOperand(lhs, type, result.operands) || 126 parser.resolveOperand(rhs, type, result.operands)) 127 return failure(); 128 129 // Replace the string attribute `predicate` with an integer attribute. 130 int64_t predicateValue = 0; 131 if (std::is_same<CmpPredicateType, ICmpPredicate>()) { 132 Optional<ICmpPredicate> predicate = 133 symbolizeICmpPredicate(predicateAttr.getValue()); 134 if (!predicate) 135 return parser.emitError(predicateLoc) 136 << "'" << predicateAttr.getValue() 137 << "' is an incorrect value of the 'predicate' attribute"; 138 predicateValue = static_cast<int64_t>(predicate.getValue()); 139 } else { 140 Optional<FCmpPredicate> predicate = 141 symbolizeFCmpPredicate(predicateAttr.getValue()); 142 if (!predicate) 143 return parser.emitError(predicateLoc) 144 << "'" << predicateAttr.getValue() 145 << "' is an incorrect value of the 'predicate' attribute"; 146 predicateValue = static_cast<int64_t>(predicate.getValue()); 147 } 148 149 result.attributes.set("predicate", 150 parser.getBuilder().getI64IntegerAttr(predicateValue)); 151 152 // The result type is either i1 or a vector type <? x i1> if the inputs are 153 // vectors. 154 Type resultType = IntegerType::get(builder.getContext(), 1); 155 if (!isCompatibleType(type)) 156 return parser.emitError(trailingTypeLoc, 157 "expected LLVM dialect-compatible type"); 158 if (LLVM::isCompatibleVectorType(type)) { 159 if (LLVM::isScalableVectorType(type)) { 160 resultType = LLVM::getVectorType( 161 resultType, LLVM::getVectorNumElements(type).getKnownMinValue(), 162 /*isScalable=*/true); 163 } else { 164 resultType = LLVM::getVectorType( 165 resultType, LLVM::getVectorNumElements(type).getFixedValue(), 166 /*isScalable=*/false); 167 } 168 } 169 170 result.addTypes({resultType}); 171 return success(); 172 } 173 174 ParseResult ICmpOp::parse(OpAsmParser &parser, OperationState &result) { 175 return parseCmpOp<ICmpPredicate>(parser, result); 176 } 177 178 ParseResult FCmpOp::parse(OpAsmParser &parser, OperationState &result) { 179 return parseCmpOp<FCmpPredicate>(parser, result); 180 } 181 182 //===----------------------------------------------------------------------===// 183 // Printing/parsing for LLVM::AllocaOp. 184 //===----------------------------------------------------------------------===// 185 186 void AllocaOp::print(OpAsmPrinter &p) { 187 auto elemTy = getType().cast<LLVM::LLVMPointerType>().getElementType(); 188 189 auto funcTy = 190 FunctionType::get(getContext(), {getArraySize().getType()}, {getType()}); 191 192 p << ' ' << getArraySize() << " x " << elemTy; 193 if (getAlignment().hasValue() && *getAlignment() != 0) 194 p.printOptionalAttrDict((*this)->getAttrs()); 195 else 196 p.printOptionalAttrDict((*this)->getAttrs(), {"alignment"}); 197 p << " : " << funcTy; 198 } 199 200 // <operation> ::= `llvm.alloca` ssa-use `x` type attribute-dict? 201 // `:` type `,` type 202 ParseResult AllocaOp::parse(OpAsmParser &parser, OperationState &result) { 203 OpAsmParser::UnresolvedOperand arraySize; 204 Type type, elemType; 205 SMLoc trailingTypeLoc; 206 if (parser.parseOperand(arraySize) || parser.parseKeyword("x") || 207 parser.parseType(elemType) || 208 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 209 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type)) 210 return failure(); 211 212 Optional<NamedAttribute> alignmentAttr = 213 result.attributes.getNamed("alignment"); 214 if (alignmentAttr.hasValue()) { 215 auto alignmentInt = 216 alignmentAttr.getValue().getValue().dyn_cast<IntegerAttr>(); 217 if (!alignmentInt) 218 return parser.emitError(parser.getNameLoc(), 219 "expected integer alignment"); 220 if (alignmentInt.getValue().isNullValue()) 221 result.attributes.erase("alignment"); 222 } 223 224 // Extract the result type from the trailing function type. 225 auto funcType = type.dyn_cast<FunctionType>(); 226 if (!funcType || funcType.getNumInputs() != 1 || 227 funcType.getNumResults() != 1) 228 return parser.emitError( 229 trailingTypeLoc, 230 "expected trailing function type with one argument and one result"); 231 232 if (parser.resolveOperand(arraySize, funcType.getInput(0), result.operands)) 233 return failure(); 234 235 result.addTypes({funcType.getResult(0)}); 236 return success(); 237 } 238 239 //===----------------------------------------------------------------------===// 240 // LLVM::BrOp 241 //===----------------------------------------------------------------------===// 242 243 SuccessorOperands BrOp::getSuccessorOperands(unsigned index) { 244 assert(index == 0 && "invalid successor index"); 245 return SuccessorOperands(getDestOperandsMutable()); 246 } 247 248 //===----------------------------------------------------------------------===// 249 // LLVM::CondBrOp 250 //===----------------------------------------------------------------------===// 251 252 SuccessorOperands CondBrOp::getSuccessorOperands(unsigned index) { 253 assert(index < getNumSuccessors() && "invalid successor index"); 254 return SuccessorOperands(index == 0 ? getTrueDestOperandsMutable() 255 : getFalseDestOperandsMutable()); 256 } 257 258 //===----------------------------------------------------------------------===// 259 // LLVM::SwitchOp 260 //===----------------------------------------------------------------------===// 261 262 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value, 263 Block *defaultDestination, ValueRange defaultOperands, 264 ArrayRef<int32_t> caseValues, BlockRange caseDestinations, 265 ArrayRef<ValueRange> caseOperands, 266 ArrayRef<int32_t> branchWeights) { 267 ElementsAttr caseValuesAttr; 268 if (!caseValues.empty()) 269 caseValuesAttr = builder.getI32VectorAttr(caseValues); 270 271 ElementsAttr weightsAttr; 272 if (!branchWeights.empty()) 273 weightsAttr = builder.getI32VectorAttr(llvm::to_vector<4>(branchWeights)); 274 275 build(builder, result, value, defaultOperands, caseOperands, caseValuesAttr, 276 weightsAttr, defaultDestination, caseDestinations); 277 } 278 279 /// <cases> ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)? 280 /// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )? 281 static ParseResult parseSwitchOpCases( 282 OpAsmParser &parser, Type flagType, ElementsAttr &caseValues, 283 SmallVectorImpl<Block *> &caseDestinations, 284 SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> &caseOperands, 285 SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) { 286 SmallVector<APInt> values; 287 unsigned bitWidth = flagType.getIntOrFloatBitWidth(); 288 do { 289 int64_t value = 0; 290 OptionalParseResult integerParseResult = parser.parseOptionalInteger(value); 291 if (values.empty() && !integerParseResult.hasValue()) 292 return success(); 293 294 if (!integerParseResult.hasValue() || integerParseResult.getValue()) 295 return failure(); 296 values.push_back(APInt(bitWidth, value)); 297 298 Block *destination; 299 SmallVector<OpAsmParser::UnresolvedOperand> operands; 300 SmallVector<Type> operandTypes; 301 if (parser.parseColon() || parser.parseSuccessor(destination)) 302 return failure(); 303 if (!parser.parseOptionalLParen()) { 304 if (parser.parseRegionArgumentList(operands) || 305 parser.parseColonTypeList(operandTypes) || parser.parseRParen()) 306 return failure(); 307 } 308 caseDestinations.push_back(destination); 309 caseOperands.emplace_back(operands); 310 caseOperandTypes.emplace_back(operandTypes); 311 } while (!parser.parseOptionalComma()); 312 313 ShapedType caseValueType = 314 VectorType::get(static_cast<int64_t>(values.size()), flagType); 315 caseValues = DenseIntElementsAttr::get(caseValueType, values); 316 return success(); 317 } 318 319 static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type flagType, 320 ElementsAttr caseValues, 321 SuccessorRange caseDestinations, 322 OperandRangeRange caseOperands, 323 const TypeRangeRange &caseOperandTypes) { 324 if (!caseValues) 325 return; 326 327 size_t index = 0; 328 llvm::interleave( 329 llvm::zip(caseValues.cast<DenseIntElementsAttr>(), caseDestinations), 330 [&](auto i) { 331 p << " "; 332 p << std::get<0>(i).getLimitedValue(); 333 p << ": "; 334 p.printSuccessorAndUseList(std::get<1>(i), caseOperands[index++]); 335 }, 336 [&] { 337 p << ','; 338 p.printNewline(); 339 }); 340 p.printNewline(); 341 } 342 343 LogicalResult SwitchOp::verify() { 344 if ((!getCaseValues() && !getCaseDestinations().empty()) || 345 (getCaseValues() && 346 getCaseValues()->size() != 347 static_cast<int64_t>(getCaseDestinations().size()))) 348 return emitOpError("expects number of case values to match number of " 349 "case destinations"); 350 if (getBranchWeights() && getBranchWeights()->size() != getNumSuccessors()) 351 return emitError("expects number of branch weights to match number of " 352 "successors: ") 353 << getBranchWeights()->size() << " vs " << getNumSuccessors(); 354 return success(); 355 } 356 357 SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) { 358 assert(index < getNumSuccessors() && "invalid successor index"); 359 return SuccessorOperands(index == 0 ? getDefaultOperandsMutable() 360 : getCaseOperandsMutable(index - 1)); 361 } 362 363 //===----------------------------------------------------------------------===// 364 // Code for LLVM::GEPOp. 365 //===----------------------------------------------------------------------===// 366 367 constexpr int GEPOp::kDynamicIndex; 368 369 /// Populates `indices` with positions of GEP indices that would correspond to 370 /// LLVMStructTypes potentially nested in the given type. The type currently 371 /// visited gets `currentIndex` and LLVM container types are visited 372 /// recursively. The recursion is bounded and takes care of recursive types by 373 /// means of the `visited` set. 374 static void recordStructIndices(Type type, unsigned currentIndex, 375 SmallVectorImpl<unsigned> &indices, 376 SmallVectorImpl<unsigned> *structSizes, 377 SmallPtrSet<Type, 4> &visited) { 378 if (visited.contains(type)) 379 return; 380 381 visited.insert(type); 382 383 llvm::TypeSwitch<Type>(type) 384 .Case<LLVMStructType>([&](LLVMStructType structType) { 385 indices.push_back(currentIndex); 386 if (structSizes) 387 structSizes->push_back(structType.getBody().size()); 388 for (Type elementType : structType.getBody()) 389 recordStructIndices(elementType, currentIndex + 1, indices, 390 structSizes, visited); 391 }) 392 .Case<VectorType, LLVMScalableVectorType, LLVMFixedVectorType, 393 LLVMArrayType>([&](auto containerType) { 394 recordStructIndices(containerType.getElementType(), currentIndex + 1, 395 indices, structSizes, visited); 396 }); 397 } 398 399 /// Populates `indices` with positions of GEP indices that correspond to 400 /// LLVMStructTypes potentially nested in the given `baseGEPType`, which must 401 /// be either an LLVMPointer type or a vector thereof. If `structSizes` is 402 /// provided, it is populated with sizes of the indexed structs for bounds 403 /// verification purposes. 404 static void 405 findKnownStructIndices(Type baseGEPType, SmallVectorImpl<unsigned> &indices, 406 SmallVectorImpl<unsigned> *structSizes = nullptr) { 407 Type type = baseGEPType; 408 if (auto vectorType = type.dyn_cast<VectorType>()) 409 type = vectorType.getElementType(); 410 if (auto scalableVectorType = type.dyn_cast<LLVMScalableVectorType>()) 411 type = scalableVectorType.getElementType(); 412 if (auto fixedVectorType = type.dyn_cast<LLVMFixedVectorType>()) 413 type = fixedVectorType.getElementType(); 414 415 Type pointeeType = type.cast<LLVMPointerType>().getElementType(); 416 SmallPtrSet<Type, 4> visited; 417 recordStructIndices(pointeeType, /*currentIndex=*/1, indices, structSizes, 418 visited); 419 } 420 421 void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType, 422 Value basePtr, ValueRange operands, 423 ArrayRef<NamedAttribute> attributes) { 424 build(builder, result, resultType, basePtr, operands, 425 SmallVector<int32_t>(operands.size(), LLVM::GEPOp::kDynamicIndex), 426 attributes); 427 } 428 429 void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType, 430 Value basePtr, ValueRange indices, 431 ArrayRef<int32_t> structIndices, 432 ArrayRef<NamedAttribute> attributes) { 433 SmallVector<Value> remainingIndices; 434 SmallVector<int32_t> updatedStructIndices(structIndices.begin(), 435 structIndices.end()); 436 SmallVector<unsigned> structRelatedPositions; 437 findKnownStructIndices(basePtr.getType(), structRelatedPositions); 438 439 SmallVector<unsigned> operandsToErase; 440 for (unsigned pos : structRelatedPositions) { 441 // GEP may not be indexing as deep as some structs are located. 442 if (pos >= structIndices.size()) 443 continue; 444 445 // If the index is already static, it's fine. 446 if (structIndices[pos] != kDynamicIndex) 447 continue; 448 449 // Find the corresponding operand. 450 unsigned operandPos = 451 std::count(structIndices.begin(), std::next(structIndices.begin(), pos), 452 kDynamicIndex); 453 454 // Extract the constant value from the operand and put it into the attribute 455 // instead. 456 APInt staticIndexValue; 457 bool matched = 458 matchPattern(indices[operandPos], m_ConstantInt(&staticIndexValue)); 459 (void)matched; 460 assert(matched && "index into a struct must be a constant"); 461 assert(staticIndexValue.sge(APInt::getSignedMinValue(/*numBits=*/32)) && 462 "struct index underflows 32-bit integer"); 463 assert(staticIndexValue.sle(APInt::getSignedMaxValue(/*numBits=*/32)) && 464 "struct index overflows 32-bit integer"); 465 auto staticIndex = static_cast<int32_t>(staticIndexValue.getSExtValue()); 466 updatedStructIndices[pos] = staticIndex; 467 operandsToErase.push_back(operandPos); 468 } 469 470 for (unsigned i = 0, e = indices.size(); i < e; ++i) { 471 if (!llvm::is_contained(operandsToErase, i)) 472 remainingIndices.push_back(indices[i]); 473 } 474 475 assert(remainingIndices.size() == static_cast<size_t>(llvm::count( 476 updatedStructIndices, kDynamicIndex)) && 477 "expected as many index operands as dynamic index attr elements"); 478 479 result.addTypes(resultType); 480 result.addAttributes(attributes); 481 result.addAttribute("structIndices", 482 builder.getI32TensorAttr(updatedStructIndices)); 483 result.addOperands(basePtr); 484 result.addOperands(remainingIndices); 485 } 486 487 static ParseResult 488 parseGEPIndices(OpAsmParser &parser, 489 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &indices, 490 DenseIntElementsAttr &structIndices) { 491 SmallVector<int32_t> constantIndices; 492 do { 493 int32_t constantIndex; 494 OptionalParseResult parsedInteger = 495 parser.parseOptionalInteger(constantIndex); 496 if (parsedInteger.hasValue()) { 497 if (failed(parsedInteger.getValue())) 498 return failure(); 499 constantIndices.push_back(constantIndex); 500 continue; 501 } 502 503 constantIndices.push_back(LLVM::GEPOp::kDynamicIndex); 504 if (failed(parser.parseOperand(indices.emplace_back()))) 505 return failure(); 506 } while (succeeded(parser.parseOptionalComma())); 507 508 structIndices = parser.getBuilder().getI32TensorAttr(constantIndices); 509 return success(); 510 } 511 512 static void printGEPIndices(OpAsmPrinter &printer, LLVM::GEPOp gepOp, 513 OperandRange indices, 514 DenseIntElementsAttr structIndices) { 515 unsigned operandIdx = 0; 516 llvm::interleaveComma(structIndices.getValues<int32_t>(), printer, 517 [&](int32_t cst) { 518 if (cst == LLVM::GEPOp::kDynamicIndex) 519 printer.printOperand(indices[operandIdx++]); 520 else 521 printer << cst; 522 }); 523 } 524 525 LogicalResult LLVM::GEPOp::verify() { 526 SmallVector<unsigned> indices; 527 SmallVector<unsigned> structSizes; 528 findKnownStructIndices(getBase().getType(), indices, &structSizes); 529 DenseIntElementsAttr structIndices = getStructIndices(); 530 for (unsigned i : llvm::seq<unsigned>(0, indices.size())) { 531 unsigned index = indices[i]; 532 // GEP may not be indexing as deep as some structs nested in the type. 533 if (index >= structIndices.getNumElements()) 534 continue; 535 536 int32_t staticIndex = structIndices.getValues<int32_t>()[index]; 537 if (staticIndex == LLVM::GEPOp::kDynamicIndex) 538 return emitOpError() << "expected index " << index 539 << " indexing a struct to be constant"; 540 if (staticIndex < 0 || static_cast<unsigned>(staticIndex) >= structSizes[i]) 541 return emitOpError() << "index " << index 542 << " indexing a struct is out of bounds"; 543 } 544 return success(); 545 } 546 547 //===----------------------------------------------------------------------===// 548 // Builder, printer and parser for for LLVM::LoadOp. 549 //===----------------------------------------------------------------------===// 550 551 LogicalResult verifySymbolAttribute( 552 Operation *op, StringRef attributeName, 553 llvm::function_ref<LogicalResult(Operation *, SymbolRefAttr)> 554 verifySymbolType) { 555 if (Attribute attribute = op->getAttr(attributeName)) { 556 // The attribute is already verified to be a symbol ref array attribute via 557 // a constraint in the operation definition. 558 for (SymbolRefAttr symbolRef : 559 attribute.cast<ArrayAttr>().getAsRange<SymbolRefAttr>()) { 560 StringAttr metadataName = symbolRef.getRootReference(); 561 StringAttr symbolName = symbolRef.getLeafReference(); 562 // We want @metadata::@symbol, not just @symbol 563 if (metadataName == symbolName) { 564 return op->emitOpError() << "expected '" << symbolRef 565 << "' to specify a fully qualified reference"; 566 } 567 auto metadataOp = SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>( 568 op->getParentOp(), metadataName); 569 if (!metadataOp) 570 return op->emitOpError() 571 << "expected '" << symbolRef << "' to reference a metadata op"; 572 Operation *symbolOp = 573 SymbolTable::lookupNearestSymbolFrom(metadataOp, symbolName); 574 if (!symbolOp) 575 return op->emitOpError() 576 << "expected '" << symbolRef << "' to be a valid reference"; 577 if (failed(verifySymbolType(symbolOp, symbolRef))) { 578 return failure(); 579 } 580 } 581 } 582 return success(); 583 } 584 585 // Verifies that metadata ops are wired up properly. 586 template <typename OpTy> 587 static LogicalResult verifyOpMetadata(Operation *op, StringRef attributeName) { 588 auto verifySymbolType = [op](Operation *symbolOp, 589 SymbolRefAttr symbolRef) -> LogicalResult { 590 if (!isa<OpTy>(symbolOp)) { 591 return op->emitOpError() 592 << "expected '" << symbolRef << "' to resolve to a " 593 << OpTy::getOperationName(); 594 } 595 return success(); 596 }; 597 598 return verifySymbolAttribute(op, attributeName, verifySymbolType); 599 } 600 601 static LogicalResult verifyMemoryOpMetadata(Operation *op) { 602 // access_groups 603 if (failed(verifyOpMetadata<LLVM::AccessGroupMetadataOp>( 604 op, LLVMDialect::getAccessGroupsAttrName()))) 605 return failure(); 606 607 // alias_scopes 608 if (failed(verifyOpMetadata<LLVM::AliasScopeMetadataOp>( 609 op, LLVMDialect::getAliasScopesAttrName()))) 610 return failure(); 611 612 // noalias_scopes 613 if (failed(verifyOpMetadata<LLVM::AliasScopeMetadataOp>( 614 op, LLVMDialect::getNoAliasScopesAttrName()))) 615 return failure(); 616 617 return success(); 618 } 619 620 LogicalResult LoadOp::verify() { return verifyMemoryOpMetadata(*this); } 621 622 void LoadOp::build(OpBuilder &builder, OperationState &result, Type t, 623 Value addr, unsigned alignment, bool isVolatile, 624 bool isNonTemporal) { 625 result.addOperands(addr); 626 result.addTypes(t); 627 if (isVolatile) 628 result.addAttribute(kVolatileAttrName, builder.getUnitAttr()); 629 if (isNonTemporal) 630 result.addAttribute(kNonTemporalAttrName, builder.getUnitAttr()); 631 if (alignment != 0) 632 result.addAttribute("alignment", builder.getI64IntegerAttr(alignment)); 633 } 634 635 void LoadOp::print(OpAsmPrinter &p) { 636 p << ' '; 637 if (getVolatile_()) 638 p << "volatile "; 639 p << getAddr(); 640 p.printOptionalAttrDict((*this)->getAttrs(), {kVolatileAttrName}); 641 p << " : " << getAddr().getType(); 642 } 643 644 // Extract the pointee type from the LLVM pointer type wrapped in MLIR. Return 645 // the resulting type wrapped in MLIR, or nullptr on error. 646 static Type getLoadStoreElementType(OpAsmParser &parser, Type type, 647 SMLoc trailingTypeLoc) { 648 auto llvmTy = type.dyn_cast<LLVM::LLVMPointerType>(); 649 if (!llvmTy) 650 return parser.emitError(trailingTypeLoc, "expected LLVM pointer type"), 651 nullptr; 652 return llvmTy.getElementType(); 653 } 654 655 // <operation> ::= `llvm.load` `volatile` ssa-use attribute-dict? `:` type 656 ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) { 657 OpAsmParser::UnresolvedOperand addr; 658 Type type; 659 SMLoc trailingTypeLoc; 660 661 if (succeeded(parser.parseOptionalKeyword("volatile"))) 662 result.addAttribute(kVolatileAttrName, parser.getBuilder().getUnitAttr()); 663 664 if (parser.parseOperand(addr) || 665 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 666 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) || 667 parser.resolveOperand(addr, type, result.operands)) 668 return failure(); 669 670 Type elemTy = getLoadStoreElementType(parser, type, trailingTypeLoc); 671 672 result.addTypes(elemTy); 673 return success(); 674 } 675 676 //===----------------------------------------------------------------------===// 677 // Builder, printer and parser for LLVM::StoreOp. 678 //===----------------------------------------------------------------------===// 679 680 LogicalResult StoreOp::verify() { return verifyMemoryOpMetadata(*this); } 681 682 void StoreOp::build(OpBuilder &builder, OperationState &result, Value value, 683 Value addr, unsigned alignment, bool isVolatile, 684 bool isNonTemporal) { 685 result.addOperands({value, addr}); 686 result.addTypes({}); 687 if (isVolatile) 688 result.addAttribute(kVolatileAttrName, builder.getUnitAttr()); 689 if (isNonTemporal) 690 result.addAttribute(kNonTemporalAttrName, builder.getUnitAttr()); 691 if (alignment != 0) 692 result.addAttribute("alignment", builder.getI64IntegerAttr(alignment)); 693 } 694 695 void StoreOp::print(OpAsmPrinter &p) { 696 p << ' '; 697 if (getVolatile_()) 698 p << "volatile "; 699 p << getValue() << ", " << getAddr(); 700 p.printOptionalAttrDict((*this)->getAttrs(), {kVolatileAttrName}); 701 p << " : " << getAddr().getType(); 702 } 703 704 // <operation> ::= `llvm.store` `volatile` ssa-use `,` ssa-use 705 // attribute-dict? `:` type 706 ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) { 707 OpAsmParser::UnresolvedOperand addr, value; 708 Type type; 709 SMLoc trailingTypeLoc; 710 711 if (succeeded(parser.parseOptionalKeyword("volatile"))) 712 result.addAttribute(kVolatileAttrName, parser.getBuilder().getUnitAttr()); 713 714 if (parser.parseOperand(value) || parser.parseComma() || 715 parser.parseOperand(addr) || 716 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 717 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type)) 718 return failure(); 719 720 Type elemTy = getLoadStoreElementType(parser, type, trailingTypeLoc); 721 if (!elemTy) 722 return failure(); 723 724 if (parser.resolveOperand(value, elemTy, result.operands) || 725 parser.resolveOperand(addr, type, result.operands)) 726 return failure(); 727 728 return success(); 729 } 730 731 ///===---------------------------------------------------------------------===// 732 /// LLVM::InvokeOp 733 ///===---------------------------------------------------------------------===// 734 735 SuccessorOperands InvokeOp::getSuccessorOperands(unsigned index) { 736 assert(index < getNumSuccessors() && "invalid successor index"); 737 return SuccessorOperands(index == 0 ? getNormalDestOperandsMutable() 738 : getUnwindDestOperandsMutable()); 739 } 740 741 LogicalResult InvokeOp::verify() { 742 if (getNumResults() > 1) 743 return emitOpError("must have 0 or 1 result"); 744 745 Block *unwindDest = getUnwindDest(); 746 if (unwindDest->empty()) 747 return emitError("must have at least one operation in unwind destination"); 748 749 // In unwind destination, first operation must be LandingpadOp 750 if (!isa<LandingpadOp>(unwindDest->front())) 751 return emitError("first operation in unwind destination should be a " 752 "llvm.landingpad operation"); 753 754 return success(); 755 } 756 757 void InvokeOp::print(OpAsmPrinter &p) { 758 auto callee = getCallee(); 759 bool isDirect = callee.hasValue(); 760 761 p << ' '; 762 763 // Either function name or pointer 764 if (isDirect) 765 p.printSymbolName(callee.getValue()); 766 else 767 p << getOperand(0); 768 769 p << '(' << getOperands().drop_front(isDirect ? 0 : 1) << ')'; 770 p << " to "; 771 p.printSuccessorAndUseList(getNormalDest(), getNormalDestOperands()); 772 p << " unwind "; 773 p.printSuccessorAndUseList(getUnwindDest(), getUnwindDestOperands()); 774 775 p.printOptionalAttrDict((*this)->getAttrs(), 776 {InvokeOp::getOperandSegmentSizeAttr(), "callee"}); 777 p << " : "; 778 p.printFunctionalType(llvm::drop_begin(getOperandTypes(), isDirect ? 0 : 1), 779 getResultTypes()); 780 } 781 782 /// <operation> ::= `llvm.invoke` (function-id | ssa-use) `(` ssa-use-list `)` 783 /// `to` bb-id (`[` ssa-use-and-type-list `]`)? 784 /// `unwind` bb-id (`[` ssa-use-and-type-list `]`)? 785 /// attribute-dict? `:` function-type 786 ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) { 787 SmallVector<OpAsmParser::UnresolvedOperand, 8> operands; 788 FunctionType funcType; 789 SymbolRefAttr funcAttr; 790 SMLoc trailingTypeLoc; 791 Block *normalDest, *unwindDest; 792 SmallVector<Value, 4> normalOperands, unwindOperands; 793 Builder &builder = parser.getBuilder(); 794 795 // Parse an operand list that will, in practice, contain 0 or 1 operand. In 796 // case of an indirect call, there will be 1 operand before `(`. In case of a 797 // direct call, there will be no operands and the parser will stop at the 798 // function identifier without complaining. 799 if (parser.parseOperandList(operands)) 800 return failure(); 801 bool isDirect = operands.empty(); 802 803 // Optionally parse a function identifier. 804 if (isDirect && parser.parseAttribute(funcAttr, "callee", result.attributes)) 805 return failure(); 806 807 if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || 808 parser.parseKeyword("to") || 809 parser.parseSuccessorAndUseList(normalDest, normalOperands) || 810 parser.parseKeyword("unwind") || 811 parser.parseSuccessorAndUseList(unwindDest, unwindOperands) || 812 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 813 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(funcType)) 814 return failure(); 815 816 if (isDirect) { 817 // Make sure types match. 818 if (parser.resolveOperands(operands, funcType.getInputs(), 819 parser.getNameLoc(), result.operands)) 820 return failure(); 821 result.addTypes(funcType.getResults()); 822 } else { 823 // Construct the LLVM IR Dialect function type that the first operand 824 // should match. 825 if (funcType.getNumResults() > 1) 826 return parser.emitError(trailingTypeLoc, 827 "expected function with 0 or 1 result"); 828 829 Type llvmResultType; 830 if (funcType.getNumResults() == 0) { 831 llvmResultType = LLVM::LLVMVoidType::get(builder.getContext()); 832 } else { 833 llvmResultType = funcType.getResult(0); 834 if (!isCompatibleType(llvmResultType)) 835 return parser.emitError(trailingTypeLoc, 836 "expected result to have LLVM type"); 837 } 838 839 SmallVector<Type, 8> argTypes; 840 argTypes.reserve(funcType.getNumInputs()); 841 for (Type ty : funcType.getInputs()) { 842 if (isCompatibleType(ty)) 843 argTypes.push_back(ty); 844 else 845 return parser.emitError(trailingTypeLoc, 846 "expected LLVM types as inputs"); 847 } 848 849 auto llvmFuncType = LLVM::LLVMFunctionType::get(llvmResultType, argTypes); 850 auto wrappedFuncType = LLVM::LLVMPointerType::get(llvmFuncType); 851 852 auto funcArguments = llvm::makeArrayRef(operands).drop_front(); 853 854 // Make sure that the first operand (indirect callee) matches the wrapped 855 // LLVM IR function type, and that the types of the other call operands 856 // match the types of the function arguments. 857 if (parser.resolveOperand(operands[0], wrappedFuncType, result.operands) || 858 parser.resolveOperands(funcArguments, funcType.getInputs(), 859 parser.getNameLoc(), result.operands)) 860 return failure(); 861 862 result.addTypes(llvmResultType); 863 } 864 result.addSuccessors({normalDest, unwindDest}); 865 result.addOperands(normalOperands); 866 result.addOperands(unwindOperands); 867 868 result.addAttribute( 869 InvokeOp::getOperandSegmentSizeAttr(), 870 builder.getI32VectorAttr({static_cast<int32_t>(operands.size()), 871 static_cast<int32_t>(normalOperands.size()), 872 static_cast<int32_t>(unwindOperands.size())})); 873 return success(); 874 } 875 876 ///===----------------------------------------------------------------------===// 877 /// Verifying/Printing/Parsing for LLVM::LandingpadOp. 878 ///===----------------------------------------------------------------------===// 879 880 LogicalResult LandingpadOp::verify() { 881 Value value; 882 if (LLVMFuncOp func = (*this)->getParentOfType<LLVMFuncOp>()) { 883 if (!func.getPersonality().hasValue()) 884 return emitError( 885 "llvm.landingpad needs to be in a function with a personality"); 886 } 887 888 if (!getCleanup() && getOperands().empty()) 889 return emitError("landingpad instruction expects at least one clause or " 890 "cleanup attribute"); 891 892 for (unsigned idx = 0, ie = getNumOperands(); idx < ie; idx++) { 893 value = getOperand(idx); 894 bool isFilter = value.getType().isa<LLVMArrayType>(); 895 if (isFilter) { 896 // FIXME: Verify filter clauses when arrays are appropriately handled 897 } else { 898 // catch - global addresses only. 899 // Bitcast ops should have global addresses as their args. 900 if (auto bcOp = value.getDefiningOp<BitcastOp>()) { 901 if (auto addrOp = bcOp.getArg().getDefiningOp<AddressOfOp>()) 902 continue; 903 return emitError("constant clauses expected").attachNote(bcOp.getLoc()) 904 << "global addresses expected as operand to " 905 "bitcast used in clauses for landingpad"; 906 } 907 // NullOp and AddressOfOp allowed 908 if (value.getDefiningOp<NullOp>()) 909 continue; 910 if (value.getDefiningOp<AddressOfOp>()) 911 continue; 912 return emitError("clause #") 913 << idx << " is not a known constant - null, addressof, bitcast"; 914 } 915 } 916 return success(); 917 } 918 919 void LandingpadOp::print(OpAsmPrinter &p) { 920 p << (getCleanup() ? " cleanup " : " "); 921 922 // Clauses 923 for (auto value : getOperands()) { 924 // Similar to llvm - if clause is an array type then it is filter 925 // clause else catch clause 926 bool isArrayTy = value.getType().isa<LLVMArrayType>(); 927 p << '(' << (isArrayTy ? "filter " : "catch ") << value << " : " 928 << value.getType() << ") "; 929 } 930 931 p.printOptionalAttrDict((*this)->getAttrs(), {"cleanup"}); 932 933 p << ": " << getType(); 934 } 935 936 /// <operation> ::= `llvm.landingpad` `cleanup`? 937 /// ((`catch` | `filter`) operand-type ssa-use)* attribute-dict? 938 ParseResult LandingpadOp::parse(OpAsmParser &parser, OperationState &result) { 939 // Check for cleanup 940 if (succeeded(parser.parseOptionalKeyword("cleanup"))) 941 result.addAttribute("cleanup", parser.getBuilder().getUnitAttr()); 942 943 // Parse clauses with types 944 while (succeeded(parser.parseOptionalLParen()) && 945 (succeeded(parser.parseOptionalKeyword("filter")) || 946 succeeded(parser.parseOptionalKeyword("catch")))) { 947 OpAsmParser::UnresolvedOperand operand; 948 Type ty; 949 if (parser.parseOperand(operand) || parser.parseColon() || 950 parser.parseType(ty) || 951 parser.resolveOperand(operand, ty, result.operands) || 952 parser.parseRParen()) 953 return failure(); 954 } 955 956 Type type; 957 if (parser.parseColon() || parser.parseType(type)) 958 return failure(); 959 960 result.addTypes(type); 961 return success(); 962 } 963 964 //===----------------------------------------------------------------------===// 965 // Verifying/Printing/parsing for LLVM::CallOp. 966 //===----------------------------------------------------------------------===// 967 968 LogicalResult CallOp::verify() { 969 if (getNumResults() > 1) 970 return emitOpError("must have 0 or 1 result"); 971 972 // Type for the callee, we'll get it differently depending if it is a direct 973 // or indirect call. 974 Type fnType; 975 976 bool isIndirect = false; 977 978 // If this is an indirect call, the callee attribute is missing. 979 FlatSymbolRefAttr calleeName = getCalleeAttr(); 980 if (!calleeName) { 981 isIndirect = true; 982 if (!getNumOperands()) 983 return emitOpError( 984 "must have either a `callee` attribute or at least an operand"); 985 auto ptrType = getOperand(0).getType().dyn_cast<LLVMPointerType>(); 986 if (!ptrType) 987 return emitOpError("indirect call expects a pointer as callee: ") 988 << ptrType; 989 fnType = ptrType.getElementType(); 990 } else { 991 Operation *callee = 992 SymbolTable::lookupNearestSymbolFrom(*this, calleeName.getAttr()); 993 if (!callee) 994 return emitOpError() 995 << "'" << calleeName.getValue() 996 << "' does not reference a symbol in the current scope"; 997 auto fn = dyn_cast<LLVMFuncOp>(callee); 998 if (!fn) 999 return emitOpError() << "'" << calleeName.getValue() 1000 << "' does not reference a valid LLVM function"; 1001 1002 fnType = fn.getFunctionType(); 1003 } 1004 1005 LLVMFunctionType funcType = fnType.dyn_cast<LLVMFunctionType>(); 1006 if (!funcType) 1007 return emitOpError("callee does not have a functional type: ") << fnType; 1008 1009 // Verify that the operand and result types match the callee. 1010 1011 if (!funcType.isVarArg() && 1012 funcType.getNumParams() != (getNumOperands() - isIndirect)) 1013 return emitOpError() << "incorrect number of operands (" 1014 << (getNumOperands() - isIndirect) 1015 << ") for callee (expecting: " 1016 << funcType.getNumParams() << ")"; 1017 1018 if (funcType.getNumParams() > (getNumOperands() - isIndirect)) 1019 return emitOpError() << "incorrect number of operands (" 1020 << (getNumOperands() - isIndirect) 1021 << ") for varargs callee (expecting at least: " 1022 << funcType.getNumParams() << ")"; 1023 1024 for (unsigned i = 0, e = funcType.getNumParams(); i != e; ++i) 1025 if (getOperand(i + isIndirect).getType() != funcType.getParamType(i)) 1026 return emitOpError() << "operand type mismatch for operand " << i << ": " 1027 << getOperand(i + isIndirect).getType() 1028 << " != " << funcType.getParamType(i); 1029 1030 if (getNumResults() == 0 && 1031 !funcType.getReturnType().isa<LLVM::LLVMVoidType>()) 1032 return emitOpError() << "expected function call to produce a value"; 1033 1034 if (getNumResults() != 0 && 1035 funcType.getReturnType().isa<LLVM::LLVMVoidType>()) 1036 return emitOpError() 1037 << "calling function with void result must not produce values"; 1038 1039 if (getNumResults() > 1) 1040 return emitOpError() 1041 << "expected LLVM function call to produce 0 or 1 result"; 1042 1043 if (getNumResults() && getResult(0).getType() != funcType.getReturnType()) 1044 return emitOpError() << "result type mismatch: " << getResult(0).getType() 1045 << " != " << funcType.getReturnType(); 1046 1047 return success(); 1048 } 1049 1050 void CallOp::print(OpAsmPrinter &p) { 1051 auto callee = getCallee(); 1052 bool isDirect = callee.hasValue(); 1053 1054 // Print the direct callee if present as a function attribute, or an indirect 1055 // callee (first operand) otherwise. 1056 p << ' '; 1057 if (isDirect) 1058 p.printSymbolName(callee.getValue()); 1059 else 1060 p << getOperand(0); 1061 1062 auto args = getOperands().drop_front(isDirect ? 0 : 1); 1063 p << '(' << args << ')'; 1064 p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()), {"callee"}); 1065 1066 // Reconstruct the function MLIR function type from operand and result types. 1067 p << " : "; 1068 p.printFunctionalType(args.getTypes(), getResultTypes()); 1069 } 1070 1071 // <operation> ::= `llvm.call` (function-id | ssa-use) `(` ssa-use-list `)` 1072 // attribute-dict? `:` function-type 1073 ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) { 1074 SmallVector<OpAsmParser::UnresolvedOperand, 8> operands; 1075 Type type; 1076 SymbolRefAttr funcAttr; 1077 SMLoc trailingTypeLoc; 1078 1079 // Parse an operand list that will, in practice, contain 0 or 1 operand. In 1080 // case of an indirect call, there will be 1 operand before `(`. In case of a 1081 // direct call, there will be no operands and the parser will stop at the 1082 // function identifier without complaining. 1083 if (parser.parseOperandList(operands)) 1084 return failure(); 1085 bool isDirect = operands.empty(); 1086 1087 // Optionally parse a function identifier. 1088 if (isDirect) 1089 if (parser.parseAttribute(funcAttr, "callee", result.attributes)) 1090 return failure(); 1091 1092 if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || 1093 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 1094 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type)) 1095 return failure(); 1096 1097 auto funcType = type.dyn_cast<FunctionType>(); 1098 if (!funcType) 1099 return parser.emitError(trailingTypeLoc, "expected function type"); 1100 if (funcType.getNumResults() > 1) 1101 return parser.emitError(trailingTypeLoc, 1102 "expected function with 0 or 1 result"); 1103 if (isDirect) { 1104 // Make sure types match. 1105 if (parser.resolveOperands(operands, funcType.getInputs(), 1106 parser.getNameLoc(), result.operands)) 1107 return failure(); 1108 if (funcType.getNumResults() != 0 && 1109 !funcType.getResult(0).isa<LLVM::LLVMVoidType>()) 1110 result.addTypes(funcType.getResults()); 1111 } else { 1112 Builder &builder = parser.getBuilder(); 1113 Type llvmResultType; 1114 if (funcType.getNumResults() == 0) { 1115 llvmResultType = LLVM::LLVMVoidType::get(builder.getContext()); 1116 } else { 1117 llvmResultType = funcType.getResult(0); 1118 if (!isCompatibleType(llvmResultType)) 1119 return parser.emitError(trailingTypeLoc, 1120 "expected result to have LLVM type"); 1121 } 1122 1123 SmallVector<Type, 8> argTypes; 1124 argTypes.reserve(funcType.getNumInputs()); 1125 for (int i = 0, e = funcType.getNumInputs(); i < e; ++i) { 1126 auto argType = funcType.getInput(i); 1127 if (!isCompatibleType(argType)) 1128 return parser.emitError(trailingTypeLoc, 1129 "expected LLVM types as inputs"); 1130 argTypes.push_back(argType); 1131 } 1132 auto llvmFuncType = LLVM::LLVMFunctionType::get(llvmResultType, argTypes); 1133 auto wrappedFuncType = LLVM::LLVMPointerType::get(llvmFuncType); 1134 1135 auto funcArguments = 1136 ArrayRef<OpAsmParser::UnresolvedOperand>(operands).drop_front(); 1137 1138 // Make sure that the first operand (indirect callee) matches the wrapped 1139 // LLVM IR function type, and that the types of the other call operands 1140 // match the types of the function arguments. 1141 if (parser.resolveOperand(operands[0], wrappedFuncType, result.operands) || 1142 parser.resolveOperands(funcArguments, funcType.getInputs(), 1143 parser.getNameLoc(), result.operands)) 1144 return failure(); 1145 1146 if (!llvmResultType.isa<LLVM::LLVMVoidType>()) 1147 result.addTypes(llvmResultType); 1148 } 1149 1150 return success(); 1151 } 1152 1153 //===----------------------------------------------------------------------===// 1154 // Printing/parsing for LLVM::ExtractElementOp. 1155 //===----------------------------------------------------------------------===// 1156 // Expects vector to be of wrapped LLVM vector type and position to be of 1157 // wrapped LLVM i32 type. 1158 void LLVM::ExtractElementOp::build(OpBuilder &b, OperationState &result, 1159 Value vector, Value position, 1160 ArrayRef<NamedAttribute> attrs) { 1161 auto vectorType = vector.getType(); 1162 auto llvmType = LLVM::getVectorElementType(vectorType); 1163 build(b, result, llvmType, vector, position); 1164 result.addAttributes(attrs); 1165 } 1166 1167 void ExtractElementOp::print(OpAsmPrinter &p) { 1168 p << ' ' << getVector() << "[" << getPosition() << " : " 1169 << getPosition().getType() << "]"; 1170 p.printOptionalAttrDict((*this)->getAttrs()); 1171 p << " : " << getVector().getType(); 1172 } 1173 1174 // <operation> ::= `llvm.extractelement` ssa-use `, ` ssa-use 1175 // attribute-dict? `:` type 1176 ParseResult ExtractElementOp::parse(OpAsmParser &parser, 1177 OperationState &result) { 1178 SMLoc loc; 1179 OpAsmParser::UnresolvedOperand vector, position; 1180 Type type, positionType; 1181 if (parser.getCurrentLocation(&loc) || parser.parseOperand(vector) || 1182 parser.parseLSquare() || parser.parseOperand(position) || 1183 parser.parseColonType(positionType) || parser.parseRSquare() || 1184 parser.parseOptionalAttrDict(result.attributes) || 1185 parser.parseColonType(type) || 1186 parser.resolveOperand(vector, type, result.operands) || 1187 parser.resolveOperand(position, positionType, result.operands)) 1188 return failure(); 1189 if (!LLVM::isCompatibleVectorType(type)) 1190 return parser.emitError( 1191 loc, "expected LLVM dialect-compatible vector type for operand #1"); 1192 result.addTypes(LLVM::getVectorElementType(type)); 1193 return success(); 1194 } 1195 1196 LogicalResult ExtractElementOp::verify() { 1197 Type vectorType = getVector().getType(); 1198 if (!LLVM::isCompatibleVectorType(vectorType)) 1199 return emitOpError("expected LLVM dialect-compatible vector type for " 1200 "operand #1, got") 1201 << vectorType; 1202 Type valueType = LLVM::getVectorElementType(vectorType); 1203 if (valueType != getRes().getType()) 1204 return emitOpError() << "Type mismatch: extracting from " << vectorType 1205 << " should produce " << valueType 1206 << " but this op returns " << getRes().getType(); 1207 return success(); 1208 } 1209 1210 //===----------------------------------------------------------------------===// 1211 // Printing/parsing for LLVM::ExtractValueOp. 1212 //===----------------------------------------------------------------------===// 1213 1214 void ExtractValueOp::print(OpAsmPrinter &p) { 1215 p << ' ' << getContainer() << getPosition(); 1216 p.printOptionalAttrDict((*this)->getAttrs(), {"position"}); 1217 p << " : " << getContainer().getType(); 1218 } 1219 1220 // Extract the type at `position` in the wrapped LLVM IR aggregate type 1221 // `containerType`. Position is an integer array attribute where each value 1222 // is a zero-based position of the element in the aggregate type. Return the 1223 // resulting type wrapped in MLIR, or nullptr on error. 1224 static Type getInsertExtractValueElementType(OpAsmParser &parser, 1225 Type containerType, 1226 ArrayAttr positionAttr, 1227 SMLoc attributeLoc, 1228 SMLoc typeLoc) { 1229 Type llvmType = containerType; 1230 if (!isCompatibleType(containerType)) 1231 return parser.emitError(typeLoc, "expected LLVM IR Dialect type"), nullptr; 1232 1233 // Infer the element type from the structure type: iteratively step inside the 1234 // type by taking the element type, indexed by the position attribute for 1235 // structures. Check the position index before accessing, it is supposed to 1236 // be in bounds. 1237 for (Attribute subAttr : positionAttr) { 1238 auto positionElementAttr = subAttr.dyn_cast<IntegerAttr>(); 1239 if (!positionElementAttr) 1240 return parser.emitError(attributeLoc, 1241 "expected an array of integer literals"), 1242 nullptr; 1243 int position = positionElementAttr.getInt(); 1244 if (auto arrayType = llvmType.dyn_cast<LLVMArrayType>()) { 1245 if (position < 0 || 1246 static_cast<unsigned>(position) >= arrayType.getNumElements()) 1247 return parser.emitError(attributeLoc, "position out of bounds"), 1248 nullptr; 1249 llvmType = arrayType.getElementType(); 1250 } else if (auto structType = llvmType.dyn_cast<LLVMStructType>()) { 1251 if (position < 0 || 1252 static_cast<unsigned>(position) >= structType.getBody().size()) 1253 return parser.emitError(attributeLoc, "position out of bounds"), 1254 nullptr; 1255 llvmType = structType.getBody()[position]; 1256 } else { 1257 return parser.emitError(typeLoc, "expected LLVM IR structure/array type"), 1258 nullptr; 1259 } 1260 } 1261 return llvmType; 1262 } 1263 1264 // Extract the type at `position` in the wrapped LLVM IR aggregate type 1265 // `containerType`. Returns null on failure. 1266 static Type getInsertExtractValueElementType(Type containerType, 1267 ArrayAttr positionAttr, 1268 Operation *op) { 1269 Type llvmType = containerType; 1270 if (!isCompatibleType(containerType)) { 1271 op->emitError("expected LLVM IR Dialect type, got ") << containerType; 1272 return {}; 1273 } 1274 1275 // Infer the element type from the structure type: iteratively step inside the 1276 // type by taking the element type, indexed by the position attribute for 1277 // structures. Check the position index before accessing, it is supposed to 1278 // be in bounds. 1279 for (Attribute subAttr : positionAttr) { 1280 auto positionElementAttr = subAttr.dyn_cast<IntegerAttr>(); 1281 if (!positionElementAttr) { 1282 op->emitOpError("expected an array of integer literals, got: ") 1283 << subAttr; 1284 return {}; 1285 } 1286 int position = positionElementAttr.getInt(); 1287 if (auto arrayType = llvmType.dyn_cast<LLVMArrayType>()) { 1288 if (position < 0 || 1289 static_cast<unsigned>(position) >= arrayType.getNumElements()) { 1290 op->emitOpError("position out of bounds: ") << position; 1291 return {}; 1292 } 1293 llvmType = arrayType.getElementType(); 1294 } else if (auto structType = llvmType.dyn_cast<LLVMStructType>()) { 1295 if (position < 0 || 1296 static_cast<unsigned>(position) >= structType.getBody().size()) { 1297 op->emitOpError("position out of bounds") << position; 1298 return {}; 1299 } 1300 llvmType = structType.getBody()[position]; 1301 } else { 1302 op->emitOpError("expected LLVM IR structure/array type, got: ") 1303 << llvmType; 1304 return {}; 1305 } 1306 } 1307 return llvmType; 1308 } 1309 1310 // <operation> ::= `llvm.extractvalue` ssa-use 1311 // `[` integer-literal (`,` integer-literal)* `]` 1312 // attribute-dict? `:` type 1313 ParseResult ExtractValueOp::parse(OpAsmParser &parser, OperationState &result) { 1314 OpAsmParser::UnresolvedOperand container; 1315 Type containerType; 1316 ArrayAttr positionAttr; 1317 SMLoc attributeLoc, trailingTypeLoc; 1318 1319 if (parser.parseOperand(container) || 1320 parser.getCurrentLocation(&attributeLoc) || 1321 parser.parseAttribute(positionAttr, "position", result.attributes) || 1322 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 1323 parser.getCurrentLocation(&trailingTypeLoc) || 1324 parser.parseType(containerType) || 1325 parser.resolveOperand(container, containerType, result.operands)) 1326 return failure(); 1327 1328 auto elementType = getInsertExtractValueElementType( 1329 parser, containerType, positionAttr, attributeLoc, trailingTypeLoc); 1330 if (!elementType) 1331 return failure(); 1332 1333 result.addTypes(elementType); 1334 return success(); 1335 } 1336 1337 OpFoldResult LLVM::ExtractValueOp::fold(ArrayRef<Attribute> operands) { 1338 auto insertValueOp = getContainer().getDefiningOp<InsertValueOp>(); 1339 OpFoldResult result = {}; 1340 while (insertValueOp) { 1341 if (getPosition() == insertValueOp.getPosition()) 1342 return insertValueOp.getValue(); 1343 unsigned min = 1344 std::min(getPosition().size(), insertValueOp.getPosition().size()); 1345 // If one is fully prefix of the other, stop propagating back as it will 1346 // miss dependencies. For instance, %3 should not fold to %f0 in the 1347 // following example: 1348 // ``` 1349 // %1 = llvm.insertvalue %f0, %0[0, 0] : 1350 // !llvm.array<4 x !llvm.array<4xf32>> 1351 // %2 = llvm.insertvalue %arr, %1[0] : 1352 // !llvm.array<4 x !llvm.array<4xf32>> 1353 // %3 = llvm.extractvalue %2[0, 0] : !llvm.array<4 x !llvm.array<4xf32>> 1354 // ``` 1355 if (getPosition().getValue().take_front(min) == 1356 insertValueOp.getPosition().getValue().take_front(min)) 1357 return result; 1358 1359 // If neither a prefix, nor the exact position, we can extract out of the 1360 // value being inserted into. Moreover, we can try again if that operand 1361 // is itself an insertvalue expression. 1362 getContainerMutable().assign(insertValueOp.getContainer()); 1363 result = getResult(); 1364 insertValueOp = insertValueOp.getContainer().getDefiningOp<InsertValueOp>(); 1365 } 1366 return result; 1367 } 1368 1369 LogicalResult ExtractValueOp::verify() { 1370 Type valueType = getInsertExtractValueElementType(getContainer().getType(), 1371 getPositionAttr(), *this); 1372 if (!valueType) 1373 return failure(); 1374 1375 if (getRes().getType() != valueType) 1376 return emitOpError() << "Type mismatch: extracting from " 1377 << getContainer().getType() << " should produce " 1378 << valueType << " but this op returns " 1379 << getRes().getType(); 1380 return success(); 1381 } 1382 1383 //===----------------------------------------------------------------------===// 1384 // Printing/parsing for LLVM::InsertElementOp. 1385 //===----------------------------------------------------------------------===// 1386 1387 void InsertElementOp::print(OpAsmPrinter &p) { 1388 p << ' ' << getValue() << ", " << getVector() << "[" << getPosition() << " : " 1389 << getPosition().getType() << "]"; 1390 p.printOptionalAttrDict((*this)->getAttrs()); 1391 p << " : " << getVector().getType(); 1392 } 1393 1394 // <operation> ::= `llvm.insertelement` ssa-use `,` ssa-use `,` ssa-use 1395 // attribute-dict? `:` type 1396 ParseResult InsertElementOp::parse(OpAsmParser &parser, 1397 OperationState &result) { 1398 SMLoc loc; 1399 OpAsmParser::UnresolvedOperand vector, value, position; 1400 Type vectorType, positionType; 1401 if (parser.getCurrentLocation(&loc) || parser.parseOperand(value) || 1402 parser.parseComma() || parser.parseOperand(vector) || 1403 parser.parseLSquare() || parser.parseOperand(position) || 1404 parser.parseColonType(positionType) || parser.parseRSquare() || 1405 parser.parseOptionalAttrDict(result.attributes) || 1406 parser.parseColonType(vectorType)) 1407 return failure(); 1408 1409 if (!LLVM::isCompatibleVectorType(vectorType)) 1410 return parser.emitError( 1411 loc, "expected LLVM dialect-compatible vector type for operand #1"); 1412 Type valueType = LLVM::getVectorElementType(vectorType); 1413 if (!valueType) 1414 return failure(); 1415 1416 if (parser.resolveOperand(vector, vectorType, result.operands) || 1417 parser.resolveOperand(value, valueType, result.operands) || 1418 parser.resolveOperand(position, positionType, result.operands)) 1419 return failure(); 1420 1421 result.addTypes(vectorType); 1422 return success(); 1423 } 1424 1425 LogicalResult InsertElementOp::verify() { 1426 Type valueType = LLVM::getVectorElementType(getVector().getType()); 1427 if (valueType != getValue().getType()) 1428 return emitOpError() << "Type mismatch: cannot insert " 1429 << getValue().getType() << " into " 1430 << getVector().getType(); 1431 return success(); 1432 } 1433 1434 //===----------------------------------------------------------------------===// 1435 // Printing/parsing for LLVM::InsertValueOp. 1436 //===----------------------------------------------------------------------===// 1437 1438 void InsertValueOp::print(OpAsmPrinter &p) { 1439 p << ' ' << getValue() << ", " << getContainer() << getPosition(); 1440 p.printOptionalAttrDict((*this)->getAttrs(), {"position"}); 1441 p << " : " << getContainer().getType(); 1442 } 1443 1444 // <operation> ::= `llvm.insertvaluevalue` ssa-use `,` ssa-use 1445 // `[` integer-literal (`,` integer-literal)* `]` 1446 // attribute-dict? `:` type 1447 ParseResult InsertValueOp::parse(OpAsmParser &parser, OperationState &result) { 1448 OpAsmParser::UnresolvedOperand container, value; 1449 Type containerType; 1450 ArrayAttr positionAttr; 1451 SMLoc attributeLoc, trailingTypeLoc; 1452 1453 if (parser.parseOperand(value) || parser.parseComma() || 1454 parser.parseOperand(container) || 1455 parser.getCurrentLocation(&attributeLoc) || 1456 parser.parseAttribute(positionAttr, "position", result.attributes) || 1457 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 1458 parser.getCurrentLocation(&trailingTypeLoc) || 1459 parser.parseType(containerType)) 1460 return failure(); 1461 1462 auto valueType = getInsertExtractValueElementType( 1463 parser, containerType, positionAttr, attributeLoc, trailingTypeLoc); 1464 if (!valueType) 1465 return failure(); 1466 1467 if (parser.resolveOperand(container, containerType, result.operands) || 1468 parser.resolveOperand(value, valueType, result.operands)) 1469 return failure(); 1470 1471 result.addTypes(containerType); 1472 return success(); 1473 } 1474 1475 LogicalResult InsertValueOp::verify() { 1476 Type valueType = getInsertExtractValueElementType(getContainer().getType(), 1477 getPositionAttr(), *this); 1478 if (!valueType) 1479 return failure(); 1480 1481 if (getValue().getType() != valueType) 1482 return emitOpError() << "Type mismatch: cannot insert " 1483 << getValue().getType() << " into " 1484 << getContainer().getType(); 1485 1486 return success(); 1487 } 1488 1489 //===----------------------------------------------------------------------===// 1490 // Printing, parsing and verification for LLVM::ReturnOp. 1491 //===----------------------------------------------------------------------===// 1492 1493 LogicalResult ReturnOp::verify() { 1494 if (getNumOperands() > 1) 1495 return emitOpError("expected at most 1 operand"); 1496 1497 if (auto parent = (*this)->getParentOfType<LLVMFuncOp>()) { 1498 Type expectedType = parent.getFunctionType().getReturnType(); 1499 if (expectedType.isa<LLVMVoidType>()) { 1500 if (getNumOperands() == 0) 1501 return success(); 1502 InFlightDiagnostic diag = emitOpError("expected no operands"); 1503 diag.attachNote(parent->getLoc()) << "when returning from function"; 1504 return diag; 1505 } 1506 if (getNumOperands() == 0) { 1507 if (expectedType.isa<LLVMVoidType>()) 1508 return success(); 1509 InFlightDiagnostic diag = emitOpError("expected 1 operand"); 1510 diag.attachNote(parent->getLoc()) << "when returning from function"; 1511 return diag; 1512 } 1513 if (expectedType != getOperand(0).getType()) { 1514 InFlightDiagnostic diag = emitOpError("mismatching result types"); 1515 diag.attachNote(parent->getLoc()) << "when returning from function"; 1516 return diag; 1517 } 1518 } 1519 return success(); 1520 } 1521 1522 //===----------------------------------------------------------------------===// 1523 // ResumeOp 1524 //===----------------------------------------------------------------------===// 1525 1526 LogicalResult ResumeOp::verify() { 1527 if (!getValue().getDefiningOp<LandingpadOp>()) 1528 return emitOpError("expects landingpad value as operand"); 1529 // No check for personality of function - landingpad op verifies it. 1530 return success(); 1531 } 1532 1533 //===----------------------------------------------------------------------===// 1534 // Verifier for LLVM::AddressOfOp. 1535 //===----------------------------------------------------------------------===// 1536 1537 template <typename OpTy> 1538 static OpTy lookupSymbolInModule(Operation *parent, StringRef name) { 1539 Operation *module = parent; 1540 while (module && !satisfiesLLVMModule(module)) 1541 module = module->getParentOp(); 1542 assert(module && "unexpected operation outside of a module"); 1543 return dyn_cast_or_null<OpTy>( 1544 mlir::SymbolTable::lookupSymbolIn(module, name)); 1545 } 1546 1547 GlobalOp AddressOfOp::getGlobal() { 1548 return lookupSymbolInModule<LLVM::GlobalOp>((*this)->getParentOp(), 1549 getGlobalName()); 1550 } 1551 1552 LLVMFuncOp AddressOfOp::getFunction() { 1553 return lookupSymbolInModule<LLVM::LLVMFuncOp>((*this)->getParentOp(), 1554 getGlobalName()); 1555 } 1556 1557 LogicalResult AddressOfOp::verify() { 1558 auto global = getGlobal(); 1559 auto function = getFunction(); 1560 if (!global && !function) 1561 return emitOpError( 1562 "must reference a global defined by 'llvm.mlir.global' or 'llvm.func'"); 1563 1564 if (global && 1565 LLVM::LLVMPointerType::get(global.getType(), global.getAddrSpace()) != 1566 getResult().getType()) 1567 return emitOpError( 1568 "the type must be a pointer to the type of the referenced global"); 1569 1570 if (function && LLVM::LLVMPointerType::get(function.getFunctionType()) != 1571 getResult().getType()) 1572 return emitOpError( 1573 "the type must be a pointer to the type of the referenced function"); 1574 1575 return success(); 1576 } 1577 1578 //===----------------------------------------------------------------------===// 1579 // Builder, printer and verifier for LLVM::GlobalOp. 1580 //===----------------------------------------------------------------------===// 1581 1582 /// Returns the name used for the linkage attribute. This *must* correspond to 1583 /// the name of the attribute in ODS. 1584 static StringRef getLinkageAttrName() { return "linkage"; } 1585 1586 /// Returns the name used for the unnamed_addr attribute. This *must* correspond 1587 /// to the name of the attribute in ODS. 1588 static StringRef getUnnamedAddrAttrName() { return "unnamed_addr"; } 1589 1590 void GlobalOp::build(OpBuilder &builder, OperationState &result, Type type, 1591 bool isConstant, Linkage linkage, StringRef name, 1592 Attribute value, uint64_t alignment, unsigned addrSpace, 1593 bool dsoLocal, ArrayRef<NamedAttribute> attrs) { 1594 result.addAttribute(SymbolTable::getSymbolAttrName(), 1595 builder.getStringAttr(name)); 1596 result.addAttribute("global_type", TypeAttr::get(type)); 1597 if (isConstant) 1598 result.addAttribute("constant", builder.getUnitAttr()); 1599 if (value) 1600 result.addAttribute("value", value); 1601 if (dsoLocal) 1602 result.addAttribute("dso_local", builder.getUnitAttr()); 1603 1604 // Only add an alignment attribute if the "alignment" input 1605 // is different from 0. The value must also be a power of two, but 1606 // this is tested in GlobalOp::verify, not here. 1607 if (alignment != 0) 1608 result.addAttribute("alignment", builder.getI64IntegerAttr(alignment)); 1609 1610 result.addAttribute(::getLinkageAttrName(), 1611 LinkageAttr::get(builder.getContext(), linkage)); 1612 if (addrSpace != 0) 1613 result.addAttribute("addr_space", builder.getI32IntegerAttr(addrSpace)); 1614 result.attributes.append(attrs.begin(), attrs.end()); 1615 result.addRegion(); 1616 } 1617 1618 void GlobalOp::print(OpAsmPrinter &p) { 1619 p << ' ' << stringifyLinkage(getLinkage()) << ' '; 1620 if (auto unnamedAddr = getUnnamedAddr()) { 1621 StringRef str = stringifyUnnamedAddr(*unnamedAddr); 1622 if (!str.empty()) 1623 p << str << ' '; 1624 } 1625 if (getConstant()) 1626 p << "constant "; 1627 p.printSymbolName(getSymName()); 1628 p << '('; 1629 if (auto value = getValueOrNull()) 1630 p.printAttribute(value); 1631 p << ')'; 1632 // Note that the alignment attribute is printed using the 1633 // default syntax here, even though it is an inherent attribute 1634 // (as defined in https://mlir.llvm.org/docs/LangRef/#attributes) 1635 p.printOptionalAttrDict((*this)->getAttrs(), 1636 {SymbolTable::getSymbolAttrName(), "global_type", 1637 "constant", "value", getLinkageAttrName(), 1638 getUnnamedAddrAttrName()}); 1639 1640 // Print the trailing type unless it's a string global. 1641 if (getValueOrNull().dyn_cast_or_null<StringAttr>()) 1642 return; 1643 p << " : " << getType(); 1644 1645 Region &initializer = getInitializerRegion(); 1646 if (!initializer.empty()) { 1647 p << ' '; 1648 p.printRegion(initializer, /*printEntryBlockArgs=*/false); 1649 } 1650 } 1651 1652 // Parses one of the keywords provided in the list `keywords` and returns the 1653 // position of the parsed keyword in the list. If none of the keywords from the 1654 // list is parsed, returns -1. 1655 static int parseOptionalKeywordAlternative(OpAsmParser &parser, 1656 ArrayRef<StringRef> keywords) { 1657 for (const auto &en : llvm::enumerate(keywords)) { 1658 if (succeeded(parser.parseOptionalKeyword(en.value()))) 1659 return en.index(); 1660 } 1661 return -1; 1662 } 1663 1664 namespace { 1665 template <typename Ty> 1666 struct EnumTraits {}; 1667 1668 #define REGISTER_ENUM_TYPE(Ty) \ 1669 template <> \ 1670 struct EnumTraits<Ty> { \ 1671 static StringRef stringify(Ty value) { return stringify##Ty(value); } \ 1672 static unsigned getMaxEnumVal() { return getMaxEnumValFor##Ty(); } \ 1673 } 1674 1675 REGISTER_ENUM_TYPE(Linkage); 1676 REGISTER_ENUM_TYPE(UnnamedAddr); 1677 } // namespace 1678 1679 /// Parse an enum from the keyword, or default to the provided default value. 1680 /// The return type is the enum type by default, unless overriden with the 1681 /// second template argument. 1682 template <typename EnumTy, typename RetTy = EnumTy> 1683 static RetTy parseOptionalLLVMKeyword(OpAsmParser &parser, 1684 OperationState &result, 1685 EnumTy defaultValue) { 1686 SmallVector<StringRef, 10> names; 1687 for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i) 1688 names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i))); 1689 1690 int index = parseOptionalKeywordAlternative(parser, names); 1691 if (index == -1) 1692 return static_cast<RetTy>(defaultValue); 1693 return static_cast<RetTy>(index); 1694 } 1695 1696 // operation ::= `llvm.mlir.global` linkage? `constant`? `@` identifier 1697 // `(` attribute? `)` align? attribute-list? (`:` type)? region? 1698 // align ::= `align` `=` UINT64 1699 // 1700 // The type can be omitted for string attributes, in which case it will be 1701 // inferred from the value of the string as [strlen(value) x i8]. 1702 ParseResult GlobalOp::parse(OpAsmParser &parser, OperationState &result) { 1703 MLIRContext *ctx = parser.getContext(); 1704 // Parse optional linkage, default to External. 1705 result.addAttribute(::getLinkageAttrName(), 1706 LLVM::LinkageAttr::get( 1707 ctx, parseOptionalLLVMKeyword<Linkage>( 1708 parser, result, LLVM::Linkage::External))); 1709 // Parse optional UnnamedAddr, default to None. 1710 result.addAttribute(::getUnnamedAddrAttrName(), 1711 parser.getBuilder().getI64IntegerAttr( 1712 parseOptionalLLVMKeyword<UnnamedAddr, int64_t>( 1713 parser, result, LLVM::UnnamedAddr::None))); 1714 1715 if (succeeded(parser.parseOptionalKeyword("constant"))) 1716 result.addAttribute("constant", parser.getBuilder().getUnitAttr()); 1717 1718 StringAttr name; 1719 if (parser.parseSymbolName(name, SymbolTable::getSymbolAttrName(), 1720 result.attributes) || 1721 parser.parseLParen()) 1722 return failure(); 1723 1724 Attribute value; 1725 if (parser.parseOptionalRParen()) { 1726 if (parser.parseAttribute(value, "value", result.attributes) || 1727 parser.parseRParen()) 1728 return failure(); 1729 } 1730 1731 SmallVector<Type, 1> types; 1732 if (parser.parseOptionalAttrDict(result.attributes) || 1733 parser.parseOptionalColonTypeList(types)) 1734 return failure(); 1735 1736 if (types.size() > 1) 1737 return parser.emitError(parser.getNameLoc(), "expected zero or one type"); 1738 1739 Region &initRegion = *result.addRegion(); 1740 if (types.empty()) { 1741 if (auto strAttr = value.dyn_cast_or_null<StringAttr>()) { 1742 MLIRContext *context = parser.getContext(); 1743 auto arrayType = LLVM::LLVMArrayType::get(IntegerType::get(context, 8), 1744 strAttr.getValue().size()); 1745 types.push_back(arrayType); 1746 } else { 1747 return parser.emitError(parser.getNameLoc(), 1748 "type can only be omitted for string globals"); 1749 } 1750 } else { 1751 OptionalParseResult parseResult = 1752 parser.parseOptionalRegion(initRegion, /*arguments=*/{}, 1753 /*argTypes=*/{}); 1754 if (parseResult.hasValue() && failed(*parseResult)) 1755 return failure(); 1756 } 1757 1758 result.addAttribute("global_type", TypeAttr::get(types[0])); 1759 return success(); 1760 } 1761 1762 static bool isZeroAttribute(Attribute value) { 1763 if (auto intValue = value.dyn_cast<IntegerAttr>()) 1764 return intValue.getValue().isNullValue(); 1765 if (auto fpValue = value.dyn_cast<FloatAttr>()) 1766 return fpValue.getValue().isZero(); 1767 if (auto splatValue = value.dyn_cast<SplatElementsAttr>()) 1768 return isZeroAttribute(splatValue.getSplatValue<Attribute>()); 1769 if (auto elementsValue = value.dyn_cast<ElementsAttr>()) 1770 return llvm::all_of(elementsValue.getValues<Attribute>(), isZeroAttribute); 1771 if (auto arrayValue = value.dyn_cast<ArrayAttr>()) 1772 return llvm::all_of(arrayValue.getValue(), isZeroAttribute); 1773 return false; 1774 } 1775 1776 LogicalResult GlobalOp::verify() { 1777 if (!LLVMPointerType::isValidElementType(getType())) 1778 return emitOpError( 1779 "expects type to be a valid element type for an LLVM pointer"); 1780 if ((*this)->getParentOp() && !satisfiesLLVMModule((*this)->getParentOp())) 1781 return emitOpError("must appear at the module level"); 1782 1783 if (auto strAttr = getValueOrNull().dyn_cast_or_null<StringAttr>()) { 1784 auto type = getType().dyn_cast<LLVMArrayType>(); 1785 IntegerType elementType = 1786 type ? type.getElementType().dyn_cast<IntegerType>() : nullptr; 1787 if (!elementType || elementType.getWidth() != 8 || 1788 type.getNumElements() != strAttr.getValue().size()) 1789 return emitOpError( 1790 "requires an i8 array type of the length equal to that of the string " 1791 "attribute"); 1792 } 1793 1794 if (getLinkage() == Linkage::Common) { 1795 if (Attribute value = getValueOrNull()) { 1796 if (!isZeroAttribute(value)) { 1797 return emitOpError() 1798 << "expected zero value for '" 1799 << stringifyLinkage(Linkage::Common) << "' linkage"; 1800 } 1801 } 1802 } 1803 1804 if (getLinkage() == Linkage::Appending) { 1805 if (!getType().isa<LLVMArrayType>()) { 1806 return emitOpError() << "expected array type for '" 1807 << stringifyLinkage(Linkage::Appending) 1808 << "' linkage"; 1809 } 1810 } 1811 1812 Optional<uint64_t> alignAttr = getAlignment(); 1813 if (alignAttr.hasValue()) { 1814 uint64_t value = alignAttr.getValue(); 1815 if (!llvm::isPowerOf2_64(value)) 1816 return emitError() << "alignment attribute is not a power of 2"; 1817 } 1818 1819 return success(); 1820 } 1821 1822 LogicalResult GlobalOp::verifyRegions() { 1823 if (Block *b = getInitializerBlock()) { 1824 ReturnOp ret = cast<ReturnOp>(b->getTerminator()); 1825 if (ret.operand_type_begin() == ret.operand_type_end()) 1826 return emitOpError("initializer region cannot return void"); 1827 if (*ret.operand_type_begin() != getType()) 1828 return emitOpError("initializer region type ") 1829 << *ret.operand_type_begin() << " does not match global type " 1830 << getType(); 1831 1832 for (Operation &op : *b) { 1833 auto iface = dyn_cast<MemoryEffectOpInterface>(op); 1834 if (!iface || !iface.hasNoEffect()) 1835 return op.emitError() 1836 << "ops with side effects not allowed in global initializers"; 1837 } 1838 1839 if (getValueOrNull()) 1840 return emitOpError("cannot have both initializer value and region"); 1841 } 1842 1843 return success(); 1844 } 1845 1846 //===----------------------------------------------------------------------===// 1847 // LLVM::GlobalCtorsOp 1848 //===----------------------------------------------------------------------===// 1849 1850 LogicalResult 1851 GlobalCtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 1852 for (Attribute ctor : getCtors()) { 1853 if (failed(verifySymbolAttrUse(ctor.cast<FlatSymbolRefAttr>(), *this, 1854 symbolTable))) 1855 return failure(); 1856 } 1857 return success(); 1858 } 1859 1860 LogicalResult GlobalCtorsOp::verify() { 1861 if (getCtors().size() != getPriorities().size()) 1862 return emitError( 1863 "mismatch between the number of ctors and the number of priorities"); 1864 return success(); 1865 } 1866 1867 //===----------------------------------------------------------------------===// 1868 // LLVM::GlobalDtorsOp 1869 //===----------------------------------------------------------------------===// 1870 1871 LogicalResult 1872 GlobalDtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 1873 for (Attribute dtor : getDtors()) { 1874 if (failed(verifySymbolAttrUse(dtor.cast<FlatSymbolRefAttr>(), *this, 1875 symbolTable))) 1876 return failure(); 1877 } 1878 return success(); 1879 } 1880 1881 LogicalResult GlobalDtorsOp::verify() { 1882 if (getDtors().size() != getPriorities().size()) 1883 return emitError( 1884 "mismatch between the number of dtors and the number of priorities"); 1885 return success(); 1886 } 1887 1888 //===----------------------------------------------------------------------===// 1889 // Printing/parsing for LLVM::ShuffleVectorOp. 1890 //===----------------------------------------------------------------------===// 1891 // Expects vector to be of wrapped LLVM vector type and position to be of 1892 // wrapped LLVM i32 type. 1893 void LLVM::ShuffleVectorOp::build(OpBuilder &b, OperationState &result, 1894 Value v1, Value v2, ArrayAttr mask, 1895 ArrayRef<NamedAttribute> attrs) { 1896 auto containerType = v1.getType(); 1897 auto vType = LLVM::getVectorType( 1898 LLVM::getVectorElementType(containerType), mask.size(), 1899 containerType.cast<VectorType>().isScalable()); 1900 build(b, result, vType, v1, v2, mask); 1901 result.addAttributes(attrs); 1902 } 1903 1904 void ShuffleVectorOp::print(OpAsmPrinter &p) { 1905 p << ' ' << getV1() << ", " << getV2() << " " << getMask(); 1906 p.printOptionalAttrDict((*this)->getAttrs(), {"mask"}); 1907 p << " : " << getV1().getType() << ", " << getV2().getType(); 1908 } 1909 1910 // <operation> ::= `llvm.shufflevector` ssa-use `, ` ssa-use 1911 // `[` integer-literal (`,` integer-literal)* `]` 1912 // attribute-dict? `:` type 1913 ParseResult ShuffleVectorOp::parse(OpAsmParser &parser, 1914 OperationState &result) { 1915 SMLoc loc; 1916 OpAsmParser::UnresolvedOperand v1, v2; 1917 ArrayAttr maskAttr; 1918 Type typeV1, typeV2; 1919 if (parser.getCurrentLocation(&loc) || parser.parseOperand(v1) || 1920 parser.parseComma() || parser.parseOperand(v2) || 1921 parser.parseAttribute(maskAttr, "mask", result.attributes) || 1922 parser.parseOptionalAttrDict(result.attributes) || 1923 parser.parseColonType(typeV1) || parser.parseComma() || 1924 parser.parseType(typeV2) || 1925 parser.resolveOperand(v1, typeV1, result.operands) || 1926 parser.resolveOperand(v2, typeV2, result.operands)) 1927 return failure(); 1928 if (!LLVM::isCompatibleVectorType(typeV1)) 1929 return parser.emitError( 1930 loc, "expected LLVM IR dialect vector type for operand #1"); 1931 auto vType = 1932 LLVM::getVectorType(LLVM::getVectorElementType(typeV1), maskAttr.size(), 1933 typeV1.cast<VectorType>().isScalable()); 1934 result.addTypes(vType); 1935 return success(); 1936 } 1937 1938 LogicalResult ShuffleVectorOp::verify() { 1939 Type type1 = getV1().getType(); 1940 Type type2 = getV2().getType(); 1941 if (LLVM::getVectorElementType(type1) != LLVM::getVectorElementType(type2)) 1942 return emitOpError("expected matching LLVM IR Dialect element types"); 1943 if (LLVM::isScalableVectorType(type1)) 1944 if (llvm::any_of(getMask(), [](Attribute attr) { 1945 return attr.cast<IntegerAttr>().getInt() != 0; 1946 })) 1947 return emitOpError("expected a splat operation for scalable vectors"); 1948 return success(); 1949 } 1950 1951 //===----------------------------------------------------------------------===// 1952 // Implementations for LLVM::LLVMFuncOp. 1953 //===----------------------------------------------------------------------===// 1954 1955 // Add the entry block to the function. 1956 Block *LLVMFuncOp::addEntryBlock() { 1957 assert(empty() && "function already has an entry block"); 1958 assert(!isVarArg() && "unimplemented: non-external variadic functions"); 1959 1960 auto *entry = new Block; 1961 push_back(entry); 1962 1963 // FIXME: Allow passing in proper locations for the entry arguments. 1964 LLVMFunctionType type = getFunctionType(); 1965 for (unsigned i = 0, e = type.getNumParams(); i < e; ++i) 1966 entry->addArgument(type.getParamType(i), getLoc()); 1967 return entry; 1968 } 1969 1970 void LLVMFuncOp::build(OpBuilder &builder, OperationState &result, 1971 StringRef name, Type type, LLVM::Linkage linkage, 1972 bool dsoLocal, ArrayRef<NamedAttribute> attrs, 1973 ArrayRef<DictionaryAttr> argAttrs) { 1974 result.addRegion(); 1975 result.addAttribute(SymbolTable::getSymbolAttrName(), 1976 builder.getStringAttr(name)); 1977 result.addAttribute(getFunctionTypeAttrName(result.name), 1978 TypeAttr::get(type)); 1979 result.addAttribute(::getLinkageAttrName(), 1980 LinkageAttr::get(builder.getContext(), linkage)); 1981 result.attributes.append(attrs.begin(), attrs.end()); 1982 if (dsoLocal) 1983 result.addAttribute("dso_local", builder.getUnitAttr()); 1984 if (argAttrs.empty()) 1985 return; 1986 1987 assert(type.cast<LLVMFunctionType>().getNumParams() == argAttrs.size() && 1988 "expected as many argument attribute lists as arguments"); 1989 function_interface_impl::addArgAndResultAttrs(builder, result, argAttrs, 1990 /*resultAttrs=*/llvm::None); 1991 } 1992 1993 // Builds an LLVM function type from the given lists of input and output types. 1994 // Returns a null type if any of the types provided are non-LLVM types, or if 1995 // there is more than one output type. 1996 static Type 1997 buildLLVMFunctionType(OpAsmParser &parser, SMLoc loc, ArrayRef<Type> inputs, 1998 ArrayRef<Type> outputs, 1999 function_interface_impl::VariadicFlag variadicFlag) { 2000 Builder &b = parser.getBuilder(); 2001 if (outputs.size() > 1) { 2002 parser.emitError(loc, "failed to construct function type: expected zero or " 2003 "one function result"); 2004 return {}; 2005 } 2006 2007 // Convert inputs to LLVM types, exit early on error. 2008 SmallVector<Type, 4> llvmInputs; 2009 for (auto t : inputs) { 2010 if (!isCompatibleType(t)) { 2011 parser.emitError(loc, "failed to construct function type: expected LLVM " 2012 "type for function arguments"); 2013 return {}; 2014 } 2015 llvmInputs.push_back(t); 2016 } 2017 2018 // No output is denoted as "void" in LLVM type system. 2019 Type llvmOutput = 2020 outputs.empty() ? LLVMVoidType::get(b.getContext()) : outputs.front(); 2021 if (!isCompatibleType(llvmOutput)) { 2022 parser.emitError(loc, "failed to construct function type: expected LLVM " 2023 "type for function results") 2024 << llvmOutput; 2025 return {}; 2026 } 2027 return LLVMFunctionType::get(llvmOutput, llvmInputs, 2028 variadicFlag.isVariadic()); 2029 } 2030 2031 // Parses an LLVM function. 2032 // 2033 // operation ::= `llvm.func` linkage? function-signature function-attributes? 2034 // function-body 2035 // 2036 ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) { 2037 // Default to external linkage if no keyword is provided. 2038 result.addAttribute( 2039 ::getLinkageAttrName(), 2040 LinkageAttr::get(parser.getContext(), 2041 parseOptionalLLVMKeyword<Linkage>( 2042 parser, result, LLVM::Linkage::External))); 2043 2044 StringAttr nameAttr; 2045 SmallVector<OpAsmParser::UnresolvedOperand> entryArgs; 2046 SmallVector<NamedAttrList> argAttrs; 2047 SmallVector<NamedAttrList> resultAttrs; 2048 SmallVector<Type> argTypes; 2049 SmallVector<Type> resultTypes; 2050 SmallVector<Location> argLocations; 2051 bool isVariadic; 2052 2053 auto signatureLocation = parser.getCurrentLocation(); 2054 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), 2055 result.attributes) || 2056 function_interface_impl::parseFunctionSignature( 2057 parser, /*allowVariadic=*/true, entryArgs, argTypes, argAttrs, 2058 argLocations, isVariadic, resultTypes, resultAttrs)) 2059 return failure(); 2060 2061 auto type = 2062 buildLLVMFunctionType(parser, signatureLocation, argTypes, resultTypes, 2063 function_interface_impl::VariadicFlag(isVariadic)); 2064 if (!type) 2065 return failure(); 2066 result.addAttribute(FunctionOpInterface::getTypeAttrName(), 2067 TypeAttr::get(type)); 2068 2069 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) 2070 return failure(); 2071 function_interface_impl::addArgAndResultAttrs(parser.getBuilder(), result, 2072 argAttrs, resultAttrs); 2073 2074 auto *body = result.addRegion(); 2075 OptionalParseResult parseResult = parser.parseOptionalRegion( 2076 *body, entryArgs, entryArgs.empty() ? ArrayRef<Type>() : argTypes); 2077 return failure(parseResult.hasValue() && failed(*parseResult)); 2078 } 2079 2080 // Print the LLVMFuncOp. Collects argument and result types and passes them to 2081 // helper functions. Drops "void" result since it cannot be parsed back. Skips 2082 // the external linkage since it is the default value. 2083 void LLVMFuncOp::print(OpAsmPrinter &p) { 2084 p << ' '; 2085 if (getLinkage() != LLVM::Linkage::External) 2086 p << stringifyLinkage(getLinkage()) << ' '; 2087 p.printSymbolName(getName()); 2088 2089 LLVMFunctionType fnType = getFunctionType(); 2090 SmallVector<Type, 8> argTypes; 2091 SmallVector<Type, 1> resTypes; 2092 argTypes.reserve(fnType.getNumParams()); 2093 for (unsigned i = 0, e = fnType.getNumParams(); i < e; ++i) 2094 argTypes.push_back(fnType.getParamType(i)); 2095 2096 Type returnType = fnType.getReturnType(); 2097 if (!returnType.isa<LLVMVoidType>()) 2098 resTypes.push_back(returnType); 2099 2100 function_interface_impl::printFunctionSignature(p, *this, argTypes, 2101 isVarArg(), resTypes); 2102 function_interface_impl::printFunctionAttributes( 2103 p, *this, argTypes.size(), resTypes.size(), {getLinkageAttrName()}); 2104 2105 // Print the body if this is not an external function. 2106 Region &body = getBody(); 2107 if (!body.empty()) { 2108 p << ' '; 2109 p.printRegion(body, /*printEntryBlockArgs=*/false, 2110 /*printBlockTerminators=*/true); 2111 } 2112 } 2113 2114 // Verifies LLVM- and implementation-specific properties of the LLVM func Op: 2115 // - functions don't have 'common' linkage 2116 // - external functions have 'external' or 'extern_weak' linkage; 2117 // - vararg is (currently) only supported for external functions; 2118 LogicalResult LLVMFuncOp::verify() { 2119 if (getLinkage() == LLVM::Linkage::Common) 2120 return emitOpError() << "functions cannot have '" 2121 << stringifyLinkage(LLVM::Linkage::Common) 2122 << "' linkage"; 2123 2124 // Check to see if this function has a void return with a result attribute to 2125 // it. It isn't clear what semantics we would assign to that. 2126 if (getFunctionType().getReturnType().isa<LLVMVoidType>() && 2127 !getResultAttrs(0).empty()) { 2128 return emitOpError() 2129 << "cannot attach result attributes to functions with a void return"; 2130 } 2131 2132 if (isExternal()) { 2133 if (getLinkage() != LLVM::Linkage::External && 2134 getLinkage() != LLVM::Linkage::ExternWeak) 2135 return emitOpError() << "external functions must have '" 2136 << stringifyLinkage(LLVM::Linkage::External) 2137 << "' or '" 2138 << stringifyLinkage(LLVM::Linkage::ExternWeak) 2139 << "' linkage"; 2140 return success(); 2141 } 2142 2143 if (isVarArg()) 2144 return emitOpError("only external functions can be variadic"); 2145 2146 return success(); 2147 } 2148 2149 /// Verifies LLVM- and implementation-specific properties of the LLVM func Op: 2150 /// - entry block arguments are of LLVM types. 2151 LogicalResult LLVMFuncOp::verifyRegions() { 2152 if (isExternal()) 2153 return success(); 2154 2155 unsigned numArguments = getFunctionType().getNumParams(); 2156 Block &entryBlock = front(); 2157 for (unsigned i = 0; i < numArguments; ++i) { 2158 Type argType = entryBlock.getArgument(i).getType(); 2159 if (!isCompatibleType(argType)) 2160 return emitOpError("entry block argument #") 2161 << i << " is not of LLVM type"; 2162 } 2163 2164 return success(); 2165 } 2166 2167 //===----------------------------------------------------------------------===// 2168 // Verification for LLVM::ConstantOp. 2169 //===----------------------------------------------------------------------===// 2170 2171 LogicalResult LLVM::ConstantOp::verify() { 2172 if (StringAttr sAttr = getValue().dyn_cast<StringAttr>()) { 2173 auto arrayType = getType().dyn_cast<LLVMArrayType>(); 2174 if (!arrayType || arrayType.getNumElements() != sAttr.getValue().size() || 2175 !arrayType.getElementType().isInteger(8)) { 2176 return emitOpError() << "expected array type of " 2177 << sAttr.getValue().size() 2178 << " i8 elements for the string constant"; 2179 } 2180 return success(); 2181 } 2182 if (auto structType = getType().dyn_cast<LLVMStructType>()) { 2183 if (structType.getBody().size() != 2 || 2184 structType.getBody()[0] != structType.getBody()[1]) { 2185 return emitError() << "expected struct type with two elements of the " 2186 "same type, the type of a complex constant"; 2187 } 2188 2189 auto arrayAttr = getValue().dyn_cast<ArrayAttr>(); 2190 if (!arrayAttr || arrayAttr.size() != 2 || 2191 arrayAttr[0].getType() != arrayAttr[1].getType()) { 2192 return emitOpError() << "expected array attribute with two elements, " 2193 "representing a complex constant"; 2194 } 2195 2196 Type elementType = structType.getBody()[0]; 2197 if (!elementType 2198 .isa<IntegerType, Float16Type, Float32Type, Float64Type>()) { 2199 return emitError() 2200 << "expected struct element types to be floating point type or " 2201 "integer type"; 2202 } 2203 return success(); 2204 } 2205 if (!getValue().isa<IntegerAttr, ArrayAttr, FloatAttr, ElementsAttr>()) 2206 return emitOpError() 2207 << "only supports integer, float, string or elements attributes"; 2208 return success(); 2209 } 2210 2211 // Constant op constant-folds to its value. 2212 OpFoldResult LLVM::ConstantOp::fold(ArrayRef<Attribute>) { return getValue(); } 2213 2214 //===----------------------------------------------------------------------===// 2215 // Utility functions for parsing atomic ops 2216 //===----------------------------------------------------------------------===// 2217 2218 // Helper function to parse a keyword into the specified attribute named by 2219 // `attrName`. The keyword must match one of the string values defined by the 2220 // AtomicBinOp enum. The resulting I64 attribute is added to the `result` 2221 // state. 2222 static ParseResult parseAtomicBinOp(OpAsmParser &parser, OperationState &result, 2223 StringRef attrName) { 2224 SMLoc loc; 2225 StringRef keyword; 2226 if (parser.getCurrentLocation(&loc) || parser.parseKeyword(&keyword)) 2227 return failure(); 2228 2229 // Replace the keyword `keyword` with an integer attribute. 2230 auto kind = symbolizeAtomicBinOp(keyword); 2231 if (!kind) { 2232 return parser.emitError(loc) 2233 << "'" << keyword << "' is an incorrect value of the '" << attrName 2234 << "' attribute"; 2235 } 2236 2237 auto value = static_cast<int64_t>(kind.getValue()); 2238 auto attr = parser.getBuilder().getI64IntegerAttr(value); 2239 result.addAttribute(attrName, attr); 2240 2241 return success(); 2242 } 2243 2244 // Helper function to parse a keyword into the specified attribute named by 2245 // `attrName`. The keyword must match one of the string values defined by the 2246 // AtomicOrdering enum. The resulting I64 attribute is added to the `result` 2247 // state. 2248 static ParseResult parseAtomicOrdering(OpAsmParser &parser, 2249 OperationState &result, 2250 StringRef attrName) { 2251 SMLoc loc; 2252 StringRef ordering; 2253 if (parser.getCurrentLocation(&loc) || parser.parseKeyword(&ordering)) 2254 return failure(); 2255 2256 // Replace the keyword `ordering` with an integer attribute. 2257 auto kind = symbolizeAtomicOrdering(ordering); 2258 if (!kind) { 2259 return parser.emitError(loc) 2260 << "'" << ordering << "' is an incorrect value of the '" << attrName 2261 << "' attribute"; 2262 } 2263 2264 auto value = static_cast<int64_t>(kind.getValue()); 2265 auto attr = parser.getBuilder().getI64IntegerAttr(value); 2266 result.addAttribute(attrName, attr); 2267 2268 return success(); 2269 } 2270 2271 //===----------------------------------------------------------------------===// 2272 // Printer, parser and verifier for LLVM::AtomicRMWOp. 2273 //===----------------------------------------------------------------------===// 2274 2275 void AtomicRMWOp::print(OpAsmPrinter &p) { 2276 p << ' ' << stringifyAtomicBinOp(getBinOp()) << ' ' << getPtr() << ", " 2277 << getVal() << ' ' << stringifyAtomicOrdering(getOrdering()) << ' '; 2278 p.printOptionalAttrDict((*this)->getAttrs(), {"bin_op", "ordering"}); 2279 p << " : " << getRes().getType(); 2280 } 2281 2282 // <operation> ::= `llvm.atomicrmw` keyword ssa-use `,` ssa-use keyword 2283 // attribute-dict? `:` type 2284 ParseResult AtomicRMWOp::parse(OpAsmParser &parser, OperationState &result) { 2285 Type type; 2286 OpAsmParser::UnresolvedOperand ptr, val; 2287 if (parseAtomicBinOp(parser, result, "bin_op") || parser.parseOperand(ptr) || 2288 parser.parseComma() || parser.parseOperand(val) || 2289 parseAtomicOrdering(parser, result, "ordering") || 2290 parser.parseOptionalAttrDict(result.attributes) || 2291 parser.parseColonType(type) || 2292 parser.resolveOperand(ptr, LLVM::LLVMPointerType::get(type), 2293 result.operands) || 2294 parser.resolveOperand(val, type, result.operands)) 2295 return failure(); 2296 2297 result.addTypes(type); 2298 return success(); 2299 } 2300 2301 LogicalResult AtomicRMWOp::verify() { 2302 auto ptrType = getPtr().getType().cast<LLVM::LLVMPointerType>(); 2303 auto valType = getVal().getType(); 2304 if (valType != ptrType.getElementType()) 2305 return emitOpError("expected LLVM IR element type for operand #0 to " 2306 "match type for operand #1"); 2307 auto resType = getRes().getType(); 2308 if (resType != valType) 2309 return emitOpError( 2310 "expected LLVM IR result type to match type for operand #1"); 2311 if (getBinOp() == AtomicBinOp::fadd || getBinOp() == AtomicBinOp::fsub) { 2312 if (!mlir::LLVM::isCompatibleFloatingPointType(valType)) 2313 return emitOpError("expected LLVM IR floating point type"); 2314 } else if (getBinOp() == AtomicBinOp::xchg) { 2315 auto intType = valType.dyn_cast<IntegerType>(); 2316 unsigned intBitWidth = intType ? intType.getWidth() : 0; 2317 if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 && 2318 intBitWidth != 64 && !valType.isa<BFloat16Type>() && 2319 !valType.isa<Float16Type>() && !valType.isa<Float32Type>() && 2320 !valType.isa<Float64Type>()) 2321 return emitOpError("unexpected LLVM IR type for 'xchg' bin_op"); 2322 } else { 2323 auto intType = valType.dyn_cast<IntegerType>(); 2324 unsigned intBitWidth = intType ? intType.getWidth() : 0; 2325 if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 && 2326 intBitWidth != 64) 2327 return emitOpError("expected LLVM IR integer type"); 2328 } 2329 2330 if (static_cast<unsigned>(getOrdering()) < 2331 static_cast<unsigned>(AtomicOrdering::monotonic)) 2332 return emitOpError() << "expected at least '" 2333 << stringifyAtomicOrdering(AtomicOrdering::monotonic) 2334 << "' ordering"; 2335 2336 return success(); 2337 } 2338 2339 //===----------------------------------------------------------------------===// 2340 // Printer, parser and verifier for LLVM::AtomicCmpXchgOp. 2341 //===----------------------------------------------------------------------===// 2342 2343 void AtomicCmpXchgOp::print(OpAsmPrinter &p) { 2344 p << ' ' << getPtr() << ", " << getCmp() << ", " << getVal() << ' ' 2345 << stringifyAtomicOrdering(getSuccessOrdering()) << ' ' 2346 << stringifyAtomicOrdering(getFailureOrdering()); 2347 p.printOptionalAttrDict((*this)->getAttrs(), 2348 {"success_ordering", "failure_ordering"}); 2349 p << " : " << getVal().getType(); 2350 } 2351 2352 // <operation> ::= `llvm.cmpxchg` ssa-use `,` ssa-use `,` ssa-use 2353 // keyword keyword attribute-dict? `:` type 2354 ParseResult AtomicCmpXchgOp::parse(OpAsmParser &parser, 2355 OperationState &result) { 2356 auto &builder = parser.getBuilder(); 2357 Type type; 2358 OpAsmParser::UnresolvedOperand ptr, cmp, val; 2359 if (parser.parseOperand(ptr) || parser.parseComma() || 2360 parser.parseOperand(cmp) || parser.parseComma() || 2361 parser.parseOperand(val) || 2362 parseAtomicOrdering(parser, result, "success_ordering") || 2363 parseAtomicOrdering(parser, result, "failure_ordering") || 2364 parser.parseOptionalAttrDict(result.attributes) || 2365 parser.parseColonType(type) || 2366 parser.resolveOperand(ptr, LLVM::LLVMPointerType::get(type), 2367 result.operands) || 2368 parser.resolveOperand(cmp, type, result.operands) || 2369 parser.resolveOperand(val, type, result.operands)) 2370 return failure(); 2371 2372 auto boolType = IntegerType::get(builder.getContext(), 1); 2373 auto resultType = 2374 LLVMStructType::getLiteral(builder.getContext(), {type, boolType}); 2375 result.addTypes(resultType); 2376 2377 return success(); 2378 } 2379 2380 LogicalResult AtomicCmpXchgOp::verify() { 2381 auto ptrType = getPtr().getType().cast<LLVM::LLVMPointerType>(); 2382 if (!ptrType) 2383 return emitOpError("expected LLVM IR pointer type for operand #0"); 2384 auto cmpType = getCmp().getType(); 2385 auto valType = getVal().getType(); 2386 if (cmpType != ptrType.getElementType() || cmpType != valType) 2387 return emitOpError("expected LLVM IR element type for operand #0 to " 2388 "match type for all other operands"); 2389 auto intType = valType.dyn_cast<IntegerType>(); 2390 unsigned intBitWidth = intType ? intType.getWidth() : 0; 2391 if (!valType.isa<LLVMPointerType>() && intBitWidth != 8 && 2392 intBitWidth != 16 && intBitWidth != 32 && intBitWidth != 64 && 2393 !valType.isa<BFloat16Type>() && !valType.isa<Float16Type>() && 2394 !valType.isa<Float32Type>() && !valType.isa<Float64Type>()) 2395 return emitOpError("unexpected LLVM IR type"); 2396 if (getSuccessOrdering() < AtomicOrdering::monotonic || 2397 getFailureOrdering() < AtomicOrdering::monotonic) 2398 return emitOpError("ordering must be at least 'monotonic'"); 2399 if (getFailureOrdering() == AtomicOrdering::release || 2400 getFailureOrdering() == AtomicOrdering::acq_rel) 2401 return emitOpError("failure ordering cannot be 'release' or 'acq_rel'"); 2402 return success(); 2403 } 2404 2405 //===----------------------------------------------------------------------===// 2406 // Printer, parser and verifier for LLVM::FenceOp. 2407 //===----------------------------------------------------------------------===// 2408 2409 // <operation> ::= `llvm.fence` (`syncscope(`strAttr`)`)? keyword 2410 // attribute-dict? 2411 ParseResult FenceOp::parse(OpAsmParser &parser, OperationState &result) { 2412 StringAttr sScope; 2413 StringRef syncscopeKeyword = "syncscope"; 2414 if (!failed(parser.parseOptionalKeyword(syncscopeKeyword))) { 2415 if (parser.parseLParen() || 2416 parser.parseAttribute(sScope, syncscopeKeyword, result.attributes) || 2417 parser.parseRParen()) 2418 return failure(); 2419 } else { 2420 result.addAttribute(syncscopeKeyword, 2421 parser.getBuilder().getStringAttr("")); 2422 } 2423 if (parseAtomicOrdering(parser, result, "ordering") || 2424 parser.parseOptionalAttrDict(result.attributes)) 2425 return failure(); 2426 return success(); 2427 } 2428 2429 void FenceOp::print(OpAsmPrinter &p) { 2430 StringRef syncscopeKeyword = "syncscope"; 2431 p << ' '; 2432 if (!(*this)->getAttr(syncscopeKeyword).cast<StringAttr>().getValue().empty()) 2433 p << "syncscope(" << (*this)->getAttr(syncscopeKeyword) << ") "; 2434 p << stringifyAtomicOrdering(getOrdering()); 2435 } 2436 2437 LogicalResult FenceOp::verify() { 2438 if (getOrdering() == AtomicOrdering::not_atomic || 2439 getOrdering() == AtomicOrdering::unordered || 2440 getOrdering() == AtomicOrdering::monotonic) 2441 return emitOpError("can be given only acquire, release, acq_rel, " 2442 "and seq_cst orderings"); 2443 return success(); 2444 } 2445 2446 //===----------------------------------------------------------------------===// 2447 // Folder for LLVM::BitcastOp 2448 //===----------------------------------------------------------------------===// 2449 2450 OpFoldResult LLVM::BitcastOp::fold(ArrayRef<Attribute> operands) { 2451 // bitcast(x : T0, T0) -> x 2452 if (getArg().getType() == getType()) 2453 return getArg(); 2454 // bitcast(bitcast(x : T0, T1), T0) -> x 2455 if (auto prev = getArg().getDefiningOp<BitcastOp>()) 2456 if (prev.getArg().getType() == getType()) 2457 return prev.getArg(); 2458 return {}; 2459 } 2460 2461 //===----------------------------------------------------------------------===// 2462 // Folder for LLVM::AddrSpaceCastOp 2463 //===----------------------------------------------------------------------===// 2464 2465 OpFoldResult LLVM::AddrSpaceCastOp::fold(ArrayRef<Attribute> operands) { 2466 // addrcast(x : T0, T0) -> x 2467 if (getArg().getType() == getType()) 2468 return getArg(); 2469 // addrcast(addrcast(x : T0, T1), T0) -> x 2470 if (auto prev = getArg().getDefiningOp<AddrSpaceCastOp>()) 2471 if (prev.getArg().getType() == getType()) 2472 return prev.getArg(); 2473 return {}; 2474 } 2475 2476 //===----------------------------------------------------------------------===// 2477 // Folder for LLVM::GEPOp 2478 //===----------------------------------------------------------------------===// 2479 2480 OpFoldResult LLVM::GEPOp::fold(ArrayRef<Attribute> operands) { 2481 // gep %x:T, 0 -> %x 2482 if (getBase().getType() == getType() && getIndices().size() == 1 && 2483 matchPattern(getIndices()[0], m_Zero())) 2484 return getBase(); 2485 return {}; 2486 } 2487 2488 //===----------------------------------------------------------------------===// 2489 // LLVMDialect initialization, type parsing, and registration. 2490 //===----------------------------------------------------------------------===// 2491 2492 void LLVMDialect::initialize() { 2493 addAttributes<FMFAttr, LinkageAttr, LoopOptionsAttr>(); 2494 2495 // clang-format off 2496 addTypes<LLVMVoidType, 2497 LLVMPPCFP128Type, 2498 LLVMX86MMXType, 2499 LLVMTokenType, 2500 LLVMLabelType, 2501 LLVMMetadataType, 2502 LLVMFunctionType, 2503 LLVMPointerType, 2504 LLVMFixedVectorType, 2505 LLVMScalableVectorType, 2506 LLVMArrayType, 2507 LLVMStructType>(); 2508 // clang-format on 2509 addOperations< 2510 #define GET_OP_LIST 2511 #include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc" 2512 >(); 2513 2514 // Support unknown operations because not all LLVM operations are registered. 2515 allowUnknownOperations(); 2516 } 2517 2518 #define GET_OP_CLASSES 2519 #include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc" 2520 2521 /// Parse a type registered to this dialect. 2522 Type LLVMDialect::parseType(DialectAsmParser &parser) const { 2523 return detail::parseType(parser); 2524 } 2525 2526 /// Print a type registered to this dialect. 2527 void LLVMDialect::printType(Type type, DialectAsmPrinter &os) const { 2528 return detail::printType(type, os); 2529 } 2530 2531 LogicalResult LLVMDialect::verifyDataLayoutString( 2532 StringRef descr, llvm::function_ref<void(const Twine &)> reportError) { 2533 llvm::Expected<llvm::DataLayout> maybeDataLayout = 2534 llvm::DataLayout::parse(descr); 2535 if (maybeDataLayout) 2536 return success(); 2537 2538 std::string message; 2539 llvm::raw_string_ostream messageStream(message); 2540 llvm::logAllUnhandledErrors(maybeDataLayout.takeError(), messageStream); 2541 reportError("invalid data layout descriptor: " + messageStream.str()); 2542 return failure(); 2543 } 2544 2545 /// Verify LLVM dialect attributes. 2546 LogicalResult LLVMDialect::verifyOperationAttribute(Operation *op, 2547 NamedAttribute attr) { 2548 // If the `llvm.loop` attribute is present, enforce the following structure, 2549 // which the module translation can assume. 2550 if (attr.getName() == LLVMDialect::getLoopAttrName()) { 2551 auto loopAttr = attr.getValue().dyn_cast<DictionaryAttr>(); 2552 if (!loopAttr) 2553 return op->emitOpError() << "expected '" << LLVMDialect::getLoopAttrName() 2554 << "' to be a dictionary attribute"; 2555 Optional<NamedAttribute> parallelAccessGroup = 2556 loopAttr.getNamed(LLVMDialect::getParallelAccessAttrName()); 2557 if (parallelAccessGroup.hasValue()) { 2558 auto accessGroups = parallelAccessGroup->getValue().dyn_cast<ArrayAttr>(); 2559 if (!accessGroups) 2560 return op->emitOpError() 2561 << "expected '" << LLVMDialect::getParallelAccessAttrName() 2562 << "' to be an array attribute"; 2563 for (Attribute attr : accessGroups) { 2564 auto accessGroupRef = attr.dyn_cast<SymbolRefAttr>(); 2565 if (!accessGroupRef) 2566 return op->emitOpError() 2567 << "expected '" << attr << "' to be a symbol reference"; 2568 StringAttr metadataName = accessGroupRef.getRootReference(); 2569 auto metadataOp = 2570 SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>( 2571 op->getParentOp(), metadataName); 2572 if (!metadataOp) 2573 return op->emitOpError() 2574 << "expected '" << attr << "' to reference a metadata op"; 2575 StringAttr accessGroupName = accessGroupRef.getLeafReference(); 2576 Operation *accessGroupOp = 2577 SymbolTable::lookupNearestSymbolFrom(metadataOp, accessGroupName); 2578 if (!accessGroupOp) 2579 return op->emitOpError() 2580 << "expected '" << attr << "' to reference an access_group op"; 2581 } 2582 } 2583 2584 Optional<NamedAttribute> loopOptions = 2585 loopAttr.getNamed(LLVMDialect::getLoopOptionsAttrName()); 2586 if (loopOptions.hasValue() && 2587 !loopOptions->getValue().isa<LoopOptionsAttr>()) 2588 return op->emitOpError() 2589 << "expected '" << LLVMDialect::getLoopOptionsAttrName() 2590 << "' to be a `loopopts` attribute"; 2591 } 2592 2593 if (attr.getName() == LLVMDialect::getStructAttrsAttrName()) { 2594 return op->emitOpError() 2595 << "'" << LLVM::LLVMDialect::getStructAttrsAttrName() 2596 << "' is permitted only in argument or result attributes"; 2597 } 2598 2599 // If the data layout attribute is present, it must use the LLVM data layout 2600 // syntax. Try parsing it and report errors in case of failure. Users of this 2601 // attribute may assume it is well-formed and can pass it to the (asserting) 2602 // llvm::DataLayout constructor. 2603 if (attr.getName() != LLVM::LLVMDialect::getDataLayoutAttrName()) 2604 return success(); 2605 if (auto stringAttr = attr.getValue().dyn_cast<StringAttr>()) 2606 return verifyDataLayoutString( 2607 stringAttr.getValue(), 2608 [op](const Twine &message) { op->emitOpError() << message.str(); }); 2609 2610 return op->emitOpError() << "expected '" 2611 << LLVM::LLVMDialect::getDataLayoutAttrName() 2612 << "' to be a string attribute"; 2613 } 2614 2615 LogicalResult LLVMDialect::verifyStructAttr(Operation *op, Attribute attr, 2616 Type annotatedType) { 2617 auto structType = annotatedType.dyn_cast<LLVMStructType>(); 2618 if (!structType) { 2619 const auto emitIncorrectAnnotatedType = [&op]() { 2620 return op->emitError() 2621 << "expected '" << LLVMDialect::getStructAttrsAttrName() 2622 << "' to annotate '!llvm.struct' or '!llvm.ptr<struct<...>>'"; 2623 }; 2624 const auto ptrType = annotatedType.dyn_cast<LLVMPointerType>(); 2625 if (!ptrType) 2626 return emitIncorrectAnnotatedType(); 2627 structType = ptrType.getElementType().dyn_cast<LLVMStructType>(); 2628 if (!structType) 2629 return emitIncorrectAnnotatedType(); 2630 } 2631 2632 const auto arrAttrs = attr.dyn_cast<ArrayAttr>(); 2633 if (!arrAttrs) 2634 return op->emitError() << "expected '" 2635 << LLVMDialect::getStructAttrsAttrName() 2636 << "' to be an array attribute"; 2637 2638 if (structType.getBody().size() != arrAttrs.size()) 2639 return op->emitError() 2640 << "size of '" << LLVMDialect::getStructAttrsAttrName() 2641 << "' must match the size of the annotated '!llvm.struct'"; 2642 return success(); 2643 } 2644 2645 static LogicalResult verifyFuncOpInterfaceStructAttr( 2646 Operation *op, Attribute attr, 2647 std::function<Type(FunctionOpInterface)> getAnnotatedType) { 2648 if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) 2649 return LLVMDialect::verifyStructAttr(op, attr, getAnnotatedType(funcOp)); 2650 return op->emitError() << "expected '" 2651 << LLVMDialect::getStructAttrsAttrName() 2652 << "' to be used on function-like operations"; 2653 } 2654 2655 /// Verify LLVMIR function argument attributes. 2656 LogicalResult LLVMDialect::verifyRegionArgAttribute(Operation *op, 2657 unsigned regionIdx, 2658 unsigned argIdx, 2659 NamedAttribute argAttr) { 2660 // Check that llvm.noalias is a unit attribute. 2661 if (argAttr.getName() == LLVMDialect::getNoAliasAttrName() && 2662 !argAttr.getValue().isa<UnitAttr>()) 2663 return op->emitError() 2664 << "expected llvm.noalias argument attribute to be a unit attribute"; 2665 // Check that llvm.align is an integer attribute. 2666 if (argAttr.getName() == LLVMDialect::getAlignAttrName() && 2667 !argAttr.getValue().isa<IntegerAttr>()) 2668 return op->emitError() 2669 << "llvm.align argument attribute of non integer type"; 2670 if (argAttr.getName() == LLVMDialect::getStructAttrsAttrName()) { 2671 return verifyFuncOpInterfaceStructAttr( 2672 op, argAttr.getValue(), [argIdx](FunctionOpInterface funcOp) { 2673 return funcOp.getArgumentTypes()[argIdx]; 2674 }); 2675 } 2676 return success(); 2677 } 2678 2679 LogicalResult LLVMDialect::verifyRegionResultAttribute(Operation *op, 2680 unsigned regionIdx, 2681 unsigned resIdx, 2682 NamedAttribute resAttr) { 2683 if (resAttr.getName() == LLVMDialect::getStructAttrsAttrName()) { 2684 return verifyFuncOpInterfaceStructAttr( 2685 op, resAttr.getValue(), [resIdx](FunctionOpInterface funcOp) { 2686 return funcOp.getResultTypes()[resIdx]; 2687 }); 2688 } 2689 return success(); 2690 } 2691 2692 //===----------------------------------------------------------------------===// 2693 // Utility functions. 2694 //===----------------------------------------------------------------------===// 2695 2696 Value mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder, 2697 StringRef name, StringRef value, 2698 LLVM::Linkage linkage) { 2699 assert(builder.getInsertionBlock() && 2700 builder.getInsertionBlock()->getParentOp() && 2701 "expected builder to point to a block constrained in an op"); 2702 auto module = 2703 builder.getInsertionBlock()->getParentOp()->getParentOfType<ModuleOp>(); 2704 assert(module && "builder points to an op outside of a module"); 2705 2706 // Create the global at the entry of the module. 2707 OpBuilder moduleBuilder(module.getBodyRegion(), builder.getListener()); 2708 MLIRContext *ctx = builder.getContext(); 2709 auto type = LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), value.size()); 2710 auto global = moduleBuilder.create<LLVM::GlobalOp>( 2711 loc, type, /*isConstant=*/true, linkage, name, 2712 builder.getStringAttr(value), /*alignment=*/0); 2713 2714 // Get the pointer to the first character in the global string. 2715 Value globalPtr = builder.create<LLVM::AddressOfOp>(loc, global); 2716 Value cst0 = builder.create<LLVM::ConstantOp>( 2717 loc, IntegerType::get(ctx, 64), 2718 builder.getIntegerAttr(builder.getIndexType(), 0)); 2719 return builder.create<LLVM::GEPOp>( 2720 loc, LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)), globalPtr, 2721 ValueRange{cst0, cst0}); 2722 } 2723 2724 bool mlir::LLVM::satisfiesLLVMModule(Operation *op) { 2725 return op->hasTrait<OpTrait::SymbolTable>() && 2726 op->hasTrait<OpTrait::IsIsolatedFromAbove>(); 2727 } 2728 2729 static constexpr const FastmathFlags fastmathFlagsList[] = { 2730 // clang-format off 2731 FastmathFlags::nnan, 2732 FastmathFlags::ninf, 2733 FastmathFlags::nsz, 2734 FastmathFlags::arcp, 2735 FastmathFlags::contract, 2736 FastmathFlags::afn, 2737 FastmathFlags::reassoc, 2738 FastmathFlags::fast, 2739 // clang-format on 2740 }; 2741 2742 void FMFAttr::print(AsmPrinter &printer) const { 2743 printer << "<"; 2744 auto flags = llvm::make_filter_range(fastmathFlagsList, [&](auto flag) { 2745 return bitEnumContains(this->getFlags(), flag); 2746 }); 2747 llvm::interleaveComma(flags, printer, 2748 [&](auto flag) { printer << stringifyEnum(flag); }); 2749 printer << ">"; 2750 } 2751 2752 Attribute FMFAttr::parse(AsmParser &parser, Type type) { 2753 if (failed(parser.parseLess())) 2754 return {}; 2755 2756 FastmathFlags flags = {}; 2757 if (failed(parser.parseOptionalGreater())) { 2758 do { 2759 StringRef elemName; 2760 if (failed(parser.parseKeyword(&elemName))) 2761 return {}; 2762 2763 auto elem = symbolizeFastmathFlags(elemName); 2764 if (!elem) { 2765 parser.emitError(parser.getNameLoc(), "Unknown fastmath flag: ") 2766 << elemName; 2767 return {}; 2768 } 2769 2770 flags = flags | *elem; 2771 } while (succeeded(parser.parseOptionalComma())); 2772 2773 if (failed(parser.parseGreater())) 2774 return {}; 2775 } 2776 2777 return FMFAttr::get(parser.getContext(), flags); 2778 } 2779 2780 void LinkageAttr::print(AsmPrinter &printer) const { 2781 printer << "<"; 2782 if (static_cast<uint64_t>(getLinkage()) <= getMaxEnumValForLinkage()) 2783 printer << stringifyEnum(getLinkage()); 2784 else 2785 printer << static_cast<uint64_t>(getLinkage()); 2786 printer << ">"; 2787 } 2788 2789 Attribute LinkageAttr::parse(AsmParser &parser, Type type) { 2790 StringRef elemName; 2791 if (parser.parseLess() || parser.parseKeyword(&elemName) || 2792 parser.parseGreater()) 2793 return {}; 2794 auto elem = linkage::symbolizeLinkage(elemName); 2795 if (!elem) { 2796 parser.emitError(parser.getNameLoc(), "Unknown linkage: ") << elemName; 2797 return {}; 2798 } 2799 Linkage linkage = *elem; 2800 return LinkageAttr::get(parser.getContext(), linkage); 2801 } 2802 2803 LoopOptionsAttrBuilder::LoopOptionsAttrBuilder(LoopOptionsAttr attr) 2804 : options(attr.getOptions().begin(), attr.getOptions().end()) {} 2805 2806 template <typename T> 2807 LoopOptionsAttrBuilder &LoopOptionsAttrBuilder::setOption(LoopOptionCase tag, 2808 Optional<T> value) { 2809 auto option = llvm::find_if( 2810 options, [tag](auto option) { return option.first == tag; }); 2811 if (option != options.end()) { 2812 if (value.hasValue()) 2813 option->second = *value; 2814 else 2815 options.erase(option); 2816 } else { 2817 options.push_back(LoopOptionsAttr::OptionValuePair(tag, *value)); 2818 } 2819 return *this; 2820 } 2821 2822 LoopOptionsAttrBuilder & 2823 LoopOptionsAttrBuilder::setDisableLICM(Optional<bool> value) { 2824 return setOption(LoopOptionCase::disable_licm, value); 2825 } 2826 2827 /// Set the `interleave_count` option to the provided value. If no value 2828 /// is provided the option is deleted. 2829 LoopOptionsAttrBuilder & 2830 LoopOptionsAttrBuilder::setInterleaveCount(Optional<uint64_t> count) { 2831 return setOption(LoopOptionCase::interleave_count, count); 2832 } 2833 2834 /// Set the `disable_unroll` option to the provided value. If no value 2835 /// is provided the option is deleted. 2836 LoopOptionsAttrBuilder & 2837 LoopOptionsAttrBuilder::setDisableUnroll(Optional<bool> value) { 2838 return setOption(LoopOptionCase::disable_unroll, value); 2839 } 2840 2841 /// Set the `disable_pipeline` option to the provided value. If no value 2842 /// is provided the option is deleted. 2843 LoopOptionsAttrBuilder & 2844 LoopOptionsAttrBuilder::setDisablePipeline(Optional<bool> value) { 2845 return setOption(LoopOptionCase::disable_pipeline, value); 2846 } 2847 2848 /// Set the `pipeline_initiation_interval` option to the provided value. 2849 /// If no value is provided the option is deleted. 2850 LoopOptionsAttrBuilder &LoopOptionsAttrBuilder::setPipelineInitiationInterval( 2851 Optional<uint64_t> count) { 2852 return setOption(LoopOptionCase::pipeline_initiation_interval, count); 2853 } 2854 2855 template <typename T> 2856 static Optional<T> 2857 getOption(ArrayRef<std::pair<LoopOptionCase, int64_t>> options, 2858 LoopOptionCase option) { 2859 auto it = 2860 lower_bound(options, option, [](auto optionPair, LoopOptionCase option) { 2861 return optionPair.first < option; 2862 }); 2863 if (it == options.end()) 2864 return {}; 2865 return static_cast<T>(it->second); 2866 } 2867 2868 Optional<bool> LoopOptionsAttr::disableUnroll() { 2869 return getOption<bool>(getOptions(), LoopOptionCase::disable_unroll); 2870 } 2871 2872 Optional<bool> LoopOptionsAttr::disableLICM() { 2873 return getOption<bool>(getOptions(), LoopOptionCase::disable_licm); 2874 } 2875 2876 Optional<int64_t> LoopOptionsAttr::interleaveCount() { 2877 return getOption<int64_t>(getOptions(), LoopOptionCase::interleave_count); 2878 } 2879 2880 /// Build the LoopOptions Attribute from a sorted array of individual options. 2881 LoopOptionsAttr LoopOptionsAttr::get( 2882 MLIRContext *context, 2883 ArrayRef<std::pair<LoopOptionCase, int64_t>> sortedOptions) { 2884 assert(llvm::is_sorted(sortedOptions, llvm::less_first()) && 2885 "LoopOptionsAttr ctor expects a sorted options array"); 2886 return Base::get(context, sortedOptions); 2887 } 2888 2889 /// Build the LoopOptions Attribute from a sorted array of individual options. 2890 LoopOptionsAttr LoopOptionsAttr::get(MLIRContext *context, 2891 LoopOptionsAttrBuilder &optionBuilders) { 2892 llvm::sort(optionBuilders.options, llvm::less_first()); 2893 return Base::get(context, optionBuilders.options); 2894 } 2895 2896 void LoopOptionsAttr::print(AsmPrinter &printer) const { 2897 printer << "<"; 2898 llvm::interleaveComma(getOptions(), printer, [&](auto option) { 2899 printer << stringifyEnum(option.first) << " = "; 2900 switch (option.first) { 2901 case LoopOptionCase::disable_licm: 2902 case LoopOptionCase::disable_unroll: 2903 case LoopOptionCase::disable_pipeline: 2904 printer << (option.second ? "true" : "false"); 2905 break; 2906 case LoopOptionCase::interleave_count: 2907 case LoopOptionCase::pipeline_initiation_interval: 2908 printer << option.second; 2909 break; 2910 } 2911 }); 2912 printer << ">"; 2913 } 2914 2915 Attribute LoopOptionsAttr::parse(AsmParser &parser, Type type) { 2916 if (failed(parser.parseLess())) 2917 return {}; 2918 2919 SmallVector<std::pair<LoopOptionCase, int64_t>> options; 2920 llvm::SmallDenseSet<LoopOptionCase> seenOptions; 2921 do { 2922 StringRef optionName; 2923 if (parser.parseKeyword(&optionName)) 2924 return {}; 2925 2926 auto option = symbolizeLoopOptionCase(optionName); 2927 if (!option) { 2928 parser.emitError(parser.getNameLoc(), "unknown loop option: ") 2929 << optionName; 2930 return {}; 2931 } 2932 if (!seenOptions.insert(*option).second) { 2933 parser.emitError(parser.getNameLoc(), "loop option present twice"); 2934 return {}; 2935 } 2936 if (failed(parser.parseEqual())) 2937 return {}; 2938 2939 int64_t value; 2940 switch (*option) { 2941 case LoopOptionCase::disable_licm: 2942 case LoopOptionCase::disable_unroll: 2943 case LoopOptionCase::disable_pipeline: 2944 if (succeeded(parser.parseOptionalKeyword("true"))) 2945 value = 1; 2946 else if (succeeded(parser.parseOptionalKeyword("false"))) 2947 value = 0; 2948 else { 2949 parser.emitError(parser.getNameLoc(), 2950 "expected boolean value 'true' or 'false'"); 2951 return {}; 2952 } 2953 break; 2954 case LoopOptionCase::interleave_count: 2955 case LoopOptionCase::pipeline_initiation_interval: 2956 if (failed(parser.parseInteger(value))) { 2957 parser.emitError(parser.getNameLoc(), "expected integer value"); 2958 return {}; 2959 } 2960 break; 2961 } 2962 options.push_back(std::make_pair(*option, value)); 2963 } while (succeeded(parser.parseOptionalComma())); 2964 if (failed(parser.parseGreater())) 2965 return {}; 2966 2967 llvm::sort(options, llvm::less_first()); 2968 return get(parser.getContext(), options); 2969 } 2970