1 //===- LLVMDialect.cpp - LLVM IR Ops and Dialect registration -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the types and operation details for the LLVM IR dialect in 10 // MLIR, and the LLVM IR dialect. It also registers the dialect. 11 // 12 //===----------------------------------------------------------------------===// 13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 14 #include "TypeDetail.h" 15 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 16 #include "mlir/IR/Builders.h" 17 #include "mlir/IR/BuiltinOps.h" 18 #include "mlir/IR/BuiltinTypes.h" 19 #include "mlir/IR/DialectImplementation.h" 20 #include "mlir/IR/FunctionImplementation.h" 21 #include "mlir/IR/MLIRContext.h" 22 #include "mlir/IR/Matchers.h" 23 24 #include "llvm/ADT/StringSwitch.h" 25 #include "llvm/ADT/TypeSwitch.h" 26 #include "llvm/AsmParser/Parser.h" 27 #include "llvm/Bitcode/BitcodeReader.h" 28 #include "llvm/Bitcode/BitcodeWriter.h" 29 #include "llvm/IR/Attributes.h" 30 #include "llvm/IR/Function.h" 31 #include "llvm/IR/Type.h" 32 #include "llvm/Support/Mutex.h" 33 #include "llvm/Support/SourceMgr.h" 34 35 #include <iostream> 36 #include <numeric> 37 38 using namespace mlir; 39 using namespace mlir::LLVM; 40 using mlir::LLVM::linkage::getMaxEnumValForLinkage; 41 42 #include "mlir/Dialect/LLVMIR/LLVMOpsDialect.cpp.inc" 43 44 static constexpr const char kVolatileAttrName[] = "volatile_"; 45 static constexpr const char kNonTemporalAttrName[] = "nontemporal"; 46 47 #include "mlir/Dialect/LLVMIR/LLVMOpsEnums.cpp.inc" 48 #include "mlir/Dialect/LLVMIR/LLVMOpsInterfaces.cpp.inc" 49 #define GET_ATTRDEF_CLASSES 50 #include "mlir/Dialect/LLVMIR/LLVMOpsAttrDefs.cpp.inc" 51 52 static auto processFMFAttr(ArrayRef<NamedAttribute> attrs) { 53 SmallVector<NamedAttribute, 8> filteredAttrs( 54 llvm::make_filter_range(attrs, [&](NamedAttribute attr) { 55 if (attr.getName() == "fastmathFlags") { 56 auto defAttr = FMFAttr::get(attr.getValue().getContext(), {}); 57 return defAttr != attr.getValue(); 58 } 59 return true; 60 })); 61 return filteredAttrs; 62 } 63 64 static ParseResult parseLLVMOpAttrs(OpAsmParser &parser, 65 NamedAttrList &result) { 66 return parser.parseOptionalAttrDict(result); 67 } 68 69 static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op, 70 DictionaryAttr attrs) { 71 printer.printOptionalAttrDict(processFMFAttr(attrs.getValue())); 72 } 73 74 /// Verifies `symbol`'s use in `op` to ensure the symbol is a valid and 75 /// fully defined llvm.func. 76 static LogicalResult verifySymbolAttrUse(FlatSymbolRefAttr symbol, 77 Operation *op, 78 SymbolTableCollection &symbolTable) { 79 StringRef name = symbol.getValue(); 80 auto func = 81 symbolTable.lookupNearestSymbolFrom<LLVMFuncOp>(op, symbol.getAttr()); 82 if (!func) 83 return op->emitOpError("'") 84 << name << "' does not reference a valid LLVM function"; 85 if (func.isExternal()) 86 return op->emitOpError("'") << name << "' does not have a definition"; 87 return success(); 88 } 89 90 //===----------------------------------------------------------------------===// 91 // Printing/parsing for LLVM::CmpOp. 92 //===----------------------------------------------------------------------===// 93 static void printICmpOp(OpAsmPrinter &p, ICmpOp &op) { 94 p << " \"" << stringifyICmpPredicate(op.getPredicate()) << "\" " 95 << op.getOperand(0) << ", " << op.getOperand(1); 96 p.printOptionalAttrDict(op->getAttrs(), {"predicate"}); 97 p << " : " << op.getLhs().getType(); 98 } 99 100 static void printFCmpOp(OpAsmPrinter &p, FCmpOp &op) { 101 p << " \"" << stringifyFCmpPredicate(op.getPredicate()) << "\" " 102 << op.getOperand(0) << ", " << op.getOperand(1); 103 p.printOptionalAttrDict(processFMFAttr(op->getAttrs()), {"predicate"}); 104 p << " : " << op.getLhs().getType(); 105 } 106 107 // <operation> ::= `llvm.icmp` string-literal ssa-use `,` ssa-use 108 // attribute-dict? `:` type 109 // <operation> ::= `llvm.fcmp` string-literal ssa-use `,` ssa-use 110 // attribute-dict? `:` type 111 template <typename CmpPredicateType> 112 static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) { 113 Builder &builder = parser.getBuilder(); 114 115 StringAttr predicateAttr; 116 OpAsmParser::OperandType lhs, rhs; 117 Type type; 118 llvm::SMLoc predicateLoc, trailingTypeLoc; 119 if (parser.getCurrentLocation(&predicateLoc) || 120 parser.parseAttribute(predicateAttr, "predicate", result.attributes) || 121 parser.parseOperand(lhs) || parser.parseComma() || 122 parser.parseOperand(rhs) || 123 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 124 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) || 125 parser.resolveOperand(lhs, type, result.operands) || 126 parser.resolveOperand(rhs, type, result.operands)) 127 return failure(); 128 129 // Replace the string attribute `predicate` with an integer attribute. 130 int64_t predicateValue = 0; 131 if (std::is_same<CmpPredicateType, ICmpPredicate>()) { 132 Optional<ICmpPredicate> predicate = 133 symbolizeICmpPredicate(predicateAttr.getValue()); 134 if (!predicate) 135 return parser.emitError(predicateLoc) 136 << "'" << predicateAttr.getValue() 137 << "' is an incorrect value of the 'predicate' attribute"; 138 predicateValue = static_cast<int64_t>(predicate.getValue()); 139 } else { 140 Optional<FCmpPredicate> predicate = 141 symbolizeFCmpPredicate(predicateAttr.getValue()); 142 if (!predicate) 143 return parser.emitError(predicateLoc) 144 << "'" << predicateAttr.getValue() 145 << "' is an incorrect value of the 'predicate' attribute"; 146 predicateValue = static_cast<int64_t>(predicate.getValue()); 147 } 148 149 result.attributes.set("predicate", 150 parser.getBuilder().getI64IntegerAttr(predicateValue)); 151 152 // The result type is either i1 or a vector type <? x i1> if the inputs are 153 // vectors. 154 Type resultType = IntegerType::get(builder.getContext(), 1); 155 if (!isCompatibleType(type)) 156 return parser.emitError(trailingTypeLoc, 157 "expected LLVM dialect-compatible type"); 158 if (LLVM::isCompatibleVectorType(type)) { 159 if (LLVM::isScalableVectorType(type)) { 160 resultType = LLVM::getVectorType( 161 resultType, LLVM::getVectorNumElements(type).getKnownMinValue(), 162 /*isScalable=*/true); 163 } else { 164 resultType = LLVM::getVectorType( 165 resultType, LLVM::getVectorNumElements(type).getFixedValue(), 166 /*isScalable=*/false); 167 } 168 } 169 170 result.addTypes({resultType}); 171 return success(); 172 } 173 174 //===----------------------------------------------------------------------===// 175 // Printing/parsing for LLVM::AllocaOp. 176 //===----------------------------------------------------------------------===// 177 178 static void printAllocaOp(OpAsmPrinter &p, AllocaOp &op) { 179 auto elemTy = op.getType().cast<LLVM::LLVMPointerType>().getElementType(); 180 181 auto funcTy = FunctionType::get( 182 op.getContext(), {op.getArraySize().getType()}, {op.getType()}); 183 184 p << ' ' << op.getArraySize() << " x " << elemTy; 185 if (op.getAlignment().hasValue() && *op.getAlignment() != 0) 186 p.printOptionalAttrDict(op->getAttrs()); 187 else 188 p.printOptionalAttrDict(op->getAttrs(), {"alignment"}); 189 p << " : " << funcTy; 190 } 191 192 // <operation> ::= `llvm.alloca` ssa-use `x` type attribute-dict? 193 // `:` type `,` type 194 static ParseResult parseAllocaOp(OpAsmParser &parser, OperationState &result) { 195 OpAsmParser::OperandType arraySize; 196 Type type, elemType; 197 llvm::SMLoc trailingTypeLoc; 198 if (parser.parseOperand(arraySize) || parser.parseKeyword("x") || 199 parser.parseType(elemType) || 200 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 201 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type)) 202 return failure(); 203 204 Optional<NamedAttribute> alignmentAttr = 205 result.attributes.getNamed("alignment"); 206 if (alignmentAttr.hasValue()) { 207 auto alignmentInt = 208 alignmentAttr.getValue().getValue().dyn_cast<IntegerAttr>(); 209 if (!alignmentInt) 210 return parser.emitError(parser.getNameLoc(), 211 "expected integer alignment"); 212 if (alignmentInt.getValue().isNullValue()) 213 result.attributes.erase("alignment"); 214 } 215 216 // Extract the result type from the trailing function type. 217 auto funcType = type.dyn_cast<FunctionType>(); 218 if (!funcType || funcType.getNumInputs() != 1 || 219 funcType.getNumResults() != 1) 220 return parser.emitError( 221 trailingTypeLoc, 222 "expected trailing function type with one argument and one result"); 223 224 if (parser.resolveOperand(arraySize, funcType.getInput(0), result.operands)) 225 return failure(); 226 227 result.addTypes({funcType.getResult(0)}); 228 return success(); 229 } 230 231 //===----------------------------------------------------------------------===// 232 // LLVM::BrOp 233 //===----------------------------------------------------------------------===// 234 235 Optional<MutableOperandRange> 236 BrOp::getMutableSuccessorOperands(unsigned index) { 237 assert(index == 0 && "invalid successor index"); 238 return getDestOperandsMutable(); 239 } 240 241 //===----------------------------------------------------------------------===// 242 // LLVM::CondBrOp 243 //===----------------------------------------------------------------------===// 244 245 Optional<MutableOperandRange> 246 CondBrOp::getMutableSuccessorOperands(unsigned index) { 247 assert(index < getNumSuccessors() && "invalid successor index"); 248 return index == 0 ? getTrueDestOperandsMutable() 249 : getFalseDestOperandsMutable(); 250 } 251 252 //===----------------------------------------------------------------------===// 253 // LLVM::SwitchOp 254 //===----------------------------------------------------------------------===// 255 256 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value, 257 Block *defaultDestination, ValueRange defaultOperands, 258 ArrayRef<int32_t> caseValues, BlockRange caseDestinations, 259 ArrayRef<ValueRange> caseOperands, 260 ArrayRef<int32_t> branchWeights) { 261 ElementsAttr caseValuesAttr; 262 if (!caseValues.empty()) 263 caseValuesAttr = builder.getI32VectorAttr(caseValues); 264 265 ElementsAttr weightsAttr; 266 if (!branchWeights.empty()) 267 weightsAttr = builder.getI32VectorAttr(llvm::to_vector<4>(branchWeights)); 268 269 build(builder, result, value, defaultOperands, caseOperands, caseValuesAttr, 270 weightsAttr, defaultDestination, caseDestinations); 271 } 272 273 /// <cases> ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)? 274 /// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )? 275 static ParseResult parseSwitchOpCases( 276 OpAsmParser &parser, Type flagType, ElementsAttr &caseValues, 277 SmallVectorImpl<Block *> &caseDestinations, 278 SmallVectorImpl<SmallVector<OpAsmParser::OperandType>> &caseOperands, 279 SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) { 280 SmallVector<APInt> values; 281 unsigned bitWidth = flagType.getIntOrFloatBitWidth(); 282 do { 283 int64_t value = 0; 284 OptionalParseResult integerParseResult = parser.parseOptionalInteger(value); 285 if (values.empty() && !integerParseResult.hasValue()) 286 return success(); 287 288 if (!integerParseResult.hasValue() || integerParseResult.getValue()) 289 return failure(); 290 values.push_back(APInt(bitWidth, value)); 291 292 Block *destination; 293 SmallVector<OpAsmParser::OperandType> operands; 294 SmallVector<Type> operandTypes; 295 if (parser.parseColon() || parser.parseSuccessor(destination)) 296 return failure(); 297 if (!parser.parseOptionalLParen()) { 298 if (parser.parseRegionArgumentList(operands) || 299 parser.parseColonTypeList(operandTypes) || parser.parseRParen()) 300 return failure(); 301 } 302 caseDestinations.push_back(destination); 303 caseOperands.emplace_back(operands); 304 caseOperandTypes.emplace_back(operandTypes); 305 } while (!parser.parseOptionalComma()); 306 307 ShapedType caseValueType = 308 VectorType::get(static_cast<int64_t>(values.size()), flagType); 309 caseValues = DenseIntElementsAttr::get(caseValueType, values); 310 return success(); 311 } 312 313 static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type flagType, 314 ElementsAttr caseValues, 315 SuccessorRange caseDestinations, 316 OperandRangeRange caseOperands, 317 const TypeRangeRange &caseOperandTypes) { 318 if (!caseValues) 319 return; 320 321 size_t index = 0; 322 llvm::interleave( 323 llvm::zip(caseValues.cast<DenseIntElementsAttr>(), caseDestinations), 324 [&](auto i) { 325 p << " "; 326 p << std::get<0>(i).getLimitedValue(); 327 p << ": "; 328 p.printSuccessorAndUseList(std::get<1>(i), caseOperands[index++]); 329 }, 330 [&] { 331 p << ','; 332 p.printNewline(); 333 }); 334 p.printNewline(); 335 } 336 337 static LogicalResult verify(SwitchOp op) { 338 if ((!op.getCaseValues() && !op.getCaseDestinations().empty()) || 339 (op.getCaseValues() && 340 op.getCaseValues()->size() != 341 static_cast<int64_t>(op.getCaseDestinations().size()))) 342 return op.emitOpError("expects number of case values to match number of " 343 "case destinations"); 344 if (op.getBranchWeights() && 345 op.getBranchWeights()->size() != op.getNumSuccessors()) 346 return op.emitError("expects number of branch weights to match number of " 347 "successors: ") 348 << op.getBranchWeights()->size() << " vs " << op.getNumSuccessors(); 349 return success(); 350 } 351 352 Optional<MutableOperandRange> 353 SwitchOp::getMutableSuccessorOperands(unsigned index) { 354 assert(index < getNumSuccessors() && "invalid successor index"); 355 return index == 0 ? getDefaultOperandsMutable() 356 : getCaseOperandsMutable(index - 1); 357 } 358 359 //===----------------------------------------------------------------------===// 360 // Code for LLVM::GEPOp. 361 //===----------------------------------------------------------------------===// 362 363 constexpr int GEPOp::kDynamicIndex; 364 365 /// Populates `indices` with positions of GEP indices that would correspond to 366 /// LLVMStructTypes potentially nested in the given type. The type currently 367 /// visited gets `currentIndex` and LLVM container types are visited 368 /// recursively. The recursion is bounded and takes care of recursive types by 369 /// means of the `visited` set. 370 static void recordStructIndices(Type type, unsigned currentIndex, 371 SmallVectorImpl<unsigned> &indices, 372 SmallVectorImpl<unsigned> *structSizes, 373 SmallPtrSet<Type, 4> &visited) { 374 if (visited.contains(type)) 375 return; 376 377 visited.insert(type); 378 379 llvm::TypeSwitch<Type>(type) 380 .Case<LLVMStructType>([&](LLVMStructType structType) { 381 indices.push_back(currentIndex); 382 if (structSizes) 383 structSizes->push_back(structType.getBody().size()); 384 for (Type elementType : structType.getBody()) 385 recordStructIndices(elementType, currentIndex + 1, indices, 386 structSizes, visited); 387 }) 388 .Case<VectorType, LLVMScalableVectorType, LLVMFixedVectorType, 389 LLVMArrayType>([&](auto containerType) { 390 recordStructIndices(containerType.getElementType(), currentIndex + 1, 391 indices, structSizes, visited); 392 }); 393 } 394 395 /// Populates `indices` with positions of GEP indices that correspond to 396 /// LLVMStructTypes potentially nested in the given `baseGEPType`, which must 397 /// be either an LLVMPointer type or a vector thereof. If `structSizes` is 398 /// provided, it is populated with sizes of the indexed structs for bounds 399 /// verification purposes. 400 static void 401 findKnownStructIndices(Type baseGEPType, SmallVectorImpl<unsigned> &indices, 402 SmallVectorImpl<unsigned> *structSizes = nullptr) { 403 Type type = baseGEPType; 404 if (auto vectorType = type.dyn_cast<VectorType>()) 405 type = vectorType.getElementType(); 406 if (auto scalableVectorType = type.dyn_cast<LLVMScalableVectorType>()) 407 type = scalableVectorType.getElementType(); 408 if (auto fixedVectorType = type.dyn_cast<LLVMFixedVectorType>()) 409 type = fixedVectorType.getElementType(); 410 411 Type pointeeType = type.cast<LLVMPointerType>().getElementType(); 412 SmallPtrSet<Type, 4> visited; 413 recordStructIndices(pointeeType, /*currentIndex=*/1, indices, structSizes, 414 visited); 415 } 416 417 void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType, 418 Value basePtr, ValueRange operands, 419 ArrayRef<NamedAttribute> attributes) { 420 build(builder, result, resultType, basePtr, operands, 421 SmallVector<int32_t>(operands.size(), LLVM::GEPOp::kDynamicIndex), 422 attributes); 423 } 424 425 void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType, 426 Value basePtr, ValueRange indices, 427 ArrayRef<int32_t> structIndices, 428 ArrayRef<NamedAttribute> attributes) { 429 SmallVector<Value> remainingIndices; 430 SmallVector<int32_t> updatedStructIndices(structIndices.begin(), 431 structIndices.end()); 432 SmallVector<unsigned> structRelatedPositions; 433 findKnownStructIndices(basePtr.getType(), structRelatedPositions); 434 435 SmallVector<unsigned> operandsToErase; 436 for (unsigned pos : structRelatedPositions) { 437 // GEP may not be indexing as deep as some structs are located. 438 if (pos >= structIndices.size()) 439 continue; 440 441 // If the index is already static, it's fine. 442 if (structIndices[pos] != kDynamicIndex) 443 continue; 444 445 // Find the corresponding operand. 446 unsigned operandPos = 447 std::count(structIndices.begin(), std::next(structIndices.begin(), pos), 448 kDynamicIndex); 449 450 // Extract the constant value from the operand and put it into the attribute 451 // instead. 452 APInt staticIndexValue; 453 bool matched = 454 matchPattern(indices[operandPos], m_ConstantInt(&staticIndexValue)); 455 (void)matched; 456 assert(matched && "index into a struct must be a constant"); 457 assert(staticIndexValue.sge(APInt::getSignedMinValue(/*numBits=*/32)) && 458 "struct index underflows 32-bit integer"); 459 assert(staticIndexValue.sle(APInt::getSignedMaxValue(/*numBits=*/32)) && 460 "struct index overflows 32-bit integer"); 461 auto staticIndex = static_cast<int32_t>(staticIndexValue.getSExtValue()); 462 updatedStructIndices[pos] = staticIndex; 463 operandsToErase.push_back(operandPos); 464 } 465 466 for (unsigned i = 0, e = indices.size(); i < e; ++i) { 467 if (!llvm::is_contained(operandsToErase, i)) 468 remainingIndices.push_back(indices[i]); 469 } 470 471 assert(remainingIndices.size() == static_cast<size_t>(llvm::count( 472 updatedStructIndices, kDynamicIndex)) && 473 "expected as many index operands as dynamic index attr elements"); 474 475 result.addTypes(resultType); 476 result.addAttributes(attributes); 477 result.addAttribute("structIndices", 478 builder.getI32TensorAttr(updatedStructIndices)); 479 result.addOperands(basePtr); 480 result.addOperands(remainingIndices); 481 } 482 483 static ParseResult 484 parseGEPIndices(OpAsmParser &parser, 485 SmallVectorImpl<OpAsmParser::OperandType> &indices, 486 DenseIntElementsAttr &structIndices) { 487 SmallVector<int32_t> constantIndices; 488 do { 489 int32_t constantIndex; 490 OptionalParseResult parsedInteger = 491 parser.parseOptionalInteger(constantIndex); 492 if (parsedInteger.hasValue()) { 493 if (failed(parsedInteger.getValue())) 494 return failure(); 495 constantIndices.push_back(constantIndex); 496 continue; 497 } 498 499 constantIndices.push_back(LLVM::GEPOp::kDynamicIndex); 500 if (failed(parser.parseOperand(indices.emplace_back()))) 501 return failure(); 502 } while (succeeded(parser.parseOptionalComma())); 503 504 structIndices = parser.getBuilder().getI32TensorAttr(constantIndices); 505 return success(); 506 } 507 508 static void printGEPIndices(OpAsmPrinter &printer, LLVM::GEPOp gepOp, 509 OperandRange indices, 510 DenseIntElementsAttr structIndices) { 511 unsigned operandIdx = 0; 512 llvm::interleaveComma(structIndices.getValues<int32_t>(), printer, 513 [&](int32_t cst) { 514 if (cst == LLVM::GEPOp::kDynamicIndex) 515 printer.printOperand(indices[operandIdx++]); 516 else 517 printer << cst; 518 }); 519 } 520 521 LogicalResult verify(LLVM::GEPOp gepOp) { 522 SmallVector<unsigned> indices; 523 SmallVector<unsigned> structSizes; 524 findKnownStructIndices(gepOp.getBase().getType(), indices, &structSizes); 525 DenseIntElementsAttr structIndices = gepOp.getStructIndices(); 526 for (unsigned i : llvm::seq<unsigned>(0, indices.size())) { 527 unsigned index = indices[i]; 528 // GEP may not be indexing as deep as some structs nested in the type. 529 if (index >= structIndices.getNumElements()) 530 continue; 531 532 int32_t staticIndex = structIndices.getValues<int32_t>()[index]; 533 if (staticIndex == LLVM::GEPOp::kDynamicIndex) 534 return gepOp.emitOpError() << "expected index " << index 535 << " indexing a struct to be constant"; 536 if (staticIndex < 0 || static_cast<unsigned>(staticIndex) >= structSizes[i]) 537 return gepOp.emitOpError() 538 << "index " << index << " indexing a struct is out of bounds"; 539 } 540 return success(); 541 } 542 543 //===----------------------------------------------------------------------===// 544 // Builder, printer and parser for for LLVM::LoadOp. 545 //===----------------------------------------------------------------------===// 546 547 LogicalResult verifySymbolAttribute( 548 Operation *op, StringRef attributeName, 549 llvm::function_ref<LogicalResult(Operation *, SymbolRefAttr)> 550 verifySymbolType) { 551 if (Attribute attribute = op->getAttr(attributeName)) { 552 // The attribute is already verified to be a symbol ref array attribute via 553 // a constraint in the operation definition. 554 for (SymbolRefAttr symbolRef : 555 attribute.cast<ArrayAttr>().getAsRange<SymbolRefAttr>()) { 556 StringAttr metadataName = symbolRef.getRootReference(); 557 StringAttr symbolName = symbolRef.getLeafReference(); 558 // We want @metadata::@symbol, not just @symbol 559 if (metadataName == symbolName) { 560 return op->emitOpError() << "expected '" << symbolRef 561 << "' to specify a fully qualified reference"; 562 } 563 auto metadataOp = SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>( 564 op->getParentOp(), metadataName); 565 if (!metadataOp) 566 return op->emitOpError() 567 << "expected '" << symbolRef << "' to reference a metadata op"; 568 Operation *symbolOp = 569 SymbolTable::lookupNearestSymbolFrom(metadataOp, symbolName); 570 if (!symbolOp) 571 return op->emitOpError() 572 << "expected '" << symbolRef << "' to be a valid reference"; 573 if (failed(verifySymbolType(symbolOp, symbolRef))) { 574 return failure(); 575 } 576 } 577 } 578 return success(); 579 } 580 581 // Verifies that metadata ops are wired up properly. 582 template <typename OpTy> 583 static LogicalResult verifyOpMetadata(Operation *op, StringRef attributeName) { 584 auto verifySymbolType = [op](Operation *symbolOp, 585 SymbolRefAttr symbolRef) -> LogicalResult { 586 if (!isa<OpTy>(symbolOp)) { 587 return op->emitOpError() 588 << "expected '" << symbolRef << "' to resolve to a " 589 << OpTy::getOperationName(); 590 } 591 return success(); 592 }; 593 594 return verifySymbolAttribute(op, attributeName, verifySymbolType); 595 } 596 597 static LogicalResult verifyMemoryOpMetadata(Operation *op) { 598 // access_groups 599 if (failed(verifyOpMetadata<LLVM::AccessGroupMetadataOp>( 600 op, LLVMDialect::getAccessGroupsAttrName()))) 601 return failure(); 602 603 // alias_scopes 604 if (failed(verifyOpMetadata<LLVM::AliasScopeMetadataOp>( 605 op, LLVMDialect::getAliasScopesAttrName()))) 606 return failure(); 607 608 // noalias_scopes 609 if (failed(verifyOpMetadata<LLVM::AliasScopeMetadataOp>( 610 op, LLVMDialect::getNoAliasScopesAttrName()))) 611 return failure(); 612 613 return success(); 614 } 615 616 static LogicalResult verify(LoadOp op) { 617 return verifyMemoryOpMetadata(op.getOperation()); 618 } 619 620 void LoadOp::build(OpBuilder &builder, OperationState &result, Type t, 621 Value addr, unsigned alignment, bool isVolatile, 622 bool isNonTemporal) { 623 result.addOperands(addr); 624 result.addTypes(t); 625 if (isVolatile) 626 result.addAttribute(kVolatileAttrName, builder.getUnitAttr()); 627 if (isNonTemporal) 628 result.addAttribute(kNonTemporalAttrName, builder.getUnitAttr()); 629 if (alignment != 0) 630 result.addAttribute("alignment", builder.getI64IntegerAttr(alignment)); 631 } 632 633 static void printLoadOp(OpAsmPrinter &p, LoadOp &op) { 634 p << ' '; 635 if (op.getVolatile_()) 636 p << "volatile "; 637 p << op.getAddr(); 638 p.printOptionalAttrDict(op->getAttrs(), {kVolatileAttrName}); 639 p << " : " << op.getAddr().getType(); 640 } 641 642 // Extract the pointee type from the LLVM pointer type wrapped in MLIR. Return 643 // the resulting type wrapped in MLIR, or nullptr on error. 644 static Type getLoadStoreElementType(OpAsmParser &parser, Type type, 645 llvm::SMLoc trailingTypeLoc) { 646 auto llvmTy = type.dyn_cast<LLVM::LLVMPointerType>(); 647 if (!llvmTy) 648 return parser.emitError(trailingTypeLoc, "expected LLVM pointer type"), 649 nullptr; 650 return llvmTy.getElementType(); 651 } 652 653 // <operation> ::= `llvm.load` `volatile` ssa-use attribute-dict? `:` type 654 static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) { 655 OpAsmParser::OperandType addr; 656 Type type; 657 llvm::SMLoc trailingTypeLoc; 658 659 if (succeeded(parser.parseOptionalKeyword("volatile"))) 660 result.addAttribute(kVolatileAttrName, parser.getBuilder().getUnitAttr()); 661 662 if (parser.parseOperand(addr) || 663 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 664 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) || 665 parser.resolveOperand(addr, type, result.operands)) 666 return failure(); 667 668 Type elemTy = getLoadStoreElementType(parser, type, trailingTypeLoc); 669 670 result.addTypes(elemTy); 671 return success(); 672 } 673 674 //===----------------------------------------------------------------------===// 675 // Builder, printer and parser for LLVM::StoreOp. 676 //===----------------------------------------------------------------------===// 677 678 static LogicalResult verify(StoreOp op) { 679 return verifyMemoryOpMetadata(op.getOperation()); 680 } 681 682 void StoreOp::build(OpBuilder &builder, OperationState &result, Value value, 683 Value addr, unsigned alignment, bool isVolatile, 684 bool isNonTemporal) { 685 result.addOperands({value, addr}); 686 result.addTypes({}); 687 if (isVolatile) 688 result.addAttribute(kVolatileAttrName, builder.getUnitAttr()); 689 if (isNonTemporal) 690 result.addAttribute(kNonTemporalAttrName, builder.getUnitAttr()); 691 if (alignment != 0) 692 result.addAttribute("alignment", builder.getI64IntegerAttr(alignment)); 693 } 694 695 static void printStoreOp(OpAsmPrinter &p, StoreOp &op) { 696 p << ' '; 697 if (op.getVolatile_()) 698 p << "volatile "; 699 p << op.getValue() << ", " << op.getAddr(); 700 p.printOptionalAttrDict(op->getAttrs(), {kVolatileAttrName}); 701 p << " : " << op.getAddr().getType(); 702 } 703 704 // <operation> ::= `llvm.store` `volatile` ssa-use `,` ssa-use 705 // attribute-dict? `:` type 706 static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) { 707 OpAsmParser::OperandType addr, value; 708 Type type; 709 llvm::SMLoc trailingTypeLoc; 710 711 if (succeeded(parser.parseOptionalKeyword("volatile"))) 712 result.addAttribute(kVolatileAttrName, parser.getBuilder().getUnitAttr()); 713 714 if (parser.parseOperand(value) || parser.parseComma() || 715 parser.parseOperand(addr) || 716 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 717 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type)) 718 return failure(); 719 720 Type elemTy = getLoadStoreElementType(parser, type, trailingTypeLoc); 721 if (!elemTy) 722 return failure(); 723 724 if (parser.resolveOperand(value, elemTy, result.operands) || 725 parser.resolveOperand(addr, type, result.operands)) 726 return failure(); 727 728 return success(); 729 } 730 731 ///===---------------------------------------------------------------------===// 732 /// LLVM::InvokeOp 733 ///===---------------------------------------------------------------------===// 734 735 Optional<MutableOperandRange> 736 InvokeOp::getMutableSuccessorOperands(unsigned index) { 737 assert(index < getNumSuccessors() && "invalid successor index"); 738 return index == 0 ? getNormalDestOperandsMutable() 739 : getUnwindDestOperandsMutable(); 740 } 741 742 static LogicalResult verify(InvokeOp op) { 743 if (op.getNumResults() > 1) 744 return op.emitOpError("must have 0 or 1 result"); 745 746 Block *unwindDest = op.getUnwindDest(); 747 if (unwindDest->empty()) 748 return op.emitError( 749 "must have at least one operation in unwind destination"); 750 751 // In unwind destination, first operation must be LandingpadOp 752 if (!isa<LandingpadOp>(unwindDest->front())) 753 return op.emitError("first operation in unwind destination should be a " 754 "llvm.landingpad operation"); 755 756 return success(); 757 } 758 759 static void printInvokeOp(OpAsmPrinter &p, InvokeOp op) { 760 auto callee = op.getCallee(); 761 bool isDirect = callee.hasValue(); 762 763 p << ' '; 764 765 // Either function name or pointer 766 if (isDirect) 767 p.printSymbolName(callee.getValue()); 768 else 769 p << op.getOperand(0); 770 771 p << '(' << op.getOperands().drop_front(isDirect ? 0 : 1) << ')'; 772 p << " to "; 773 p.printSuccessorAndUseList(op.getNormalDest(), op.getNormalDestOperands()); 774 p << " unwind "; 775 p.printSuccessorAndUseList(op.getUnwindDest(), op.getUnwindDestOperands()); 776 777 p.printOptionalAttrDict(op->getAttrs(), 778 {InvokeOp::getOperandSegmentSizeAttr(), "callee"}); 779 p << " : "; 780 p.printFunctionalType( 781 llvm::drop_begin(op.getOperandTypes(), isDirect ? 0 : 1), 782 op.getResultTypes()); 783 } 784 785 /// <operation> ::= `llvm.invoke` (function-id | ssa-use) `(` ssa-use-list `)` 786 /// `to` bb-id (`[` ssa-use-and-type-list `]`)? 787 /// `unwind` bb-id (`[` ssa-use-and-type-list `]`)? 788 /// attribute-dict? `:` function-type 789 static ParseResult parseInvokeOp(OpAsmParser &parser, OperationState &result) { 790 SmallVector<OpAsmParser::OperandType, 8> operands; 791 FunctionType funcType; 792 SymbolRefAttr funcAttr; 793 llvm::SMLoc trailingTypeLoc; 794 Block *normalDest, *unwindDest; 795 SmallVector<Value, 4> normalOperands, unwindOperands; 796 Builder &builder = parser.getBuilder(); 797 798 // Parse an operand list that will, in practice, contain 0 or 1 operand. In 799 // case of an indirect call, there will be 1 operand before `(`. In case of a 800 // direct call, there will be no operands and the parser will stop at the 801 // function identifier without complaining. 802 if (parser.parseOperandList(operands)) 803 return failure(); 804 bool isDirect = operands.empty(); 805 806 // Optionally parse a function identifier. 807 if (isDirect && parser.parseAttribute(funcAttr, "callee", result.attributes)) 808 return failure(); 809 810 if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || 811 parser.parseKeyword("to") || 812 parser.parseSuccessorAndUseList(normalDest, normalOperands) || 813 parser.parseKeyword("unwind") || 814 parser.parseSuccessorAndUseList(unwindDest, unwindOperands) || 815 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 816 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(funcType)) 817 return failure(); 818 819 if (isDirect) { 820 // Make sure types match. 821 if (parser.resolveOperands(operands, funcType.getInputs(), 822 parser.getNameLoc(), result.operands)) 823 return failure(); 824 result.addTypes(funcType.getResults()); 825 } else { 826 // Construct the LLVM IR Dialect function type that the first operand 827 // should match. 828 if (funcType.getNumResults() > 1) 829 return parser.emitError(trailingTypeLoc, 830 "expected function with 0 or 1 result"); 831 832 Type llvmResultType; 833 if (funcType.getNumResults() == 0) { 834 llvmResultType = LLVM::LLVMVoidType::get(builder.getContext()); 835 } else { 836 llvmResultType = funcType.getResult(0); 837 if (!isCompatibleType(llvmResultType)) 838 return parser.emitError(trailingTypeLoc, 839 "expected result to have LLVM type"); 840 } 841 842 SmallVector<Type, 8> argTypes; 843 argTypes.reserve(funcType.getNumInputs()); 844 for (Type ty : funcType.getInputs()) { 845 if (isCompatibleType(ty)) 846 argTypes.push_back(ty); 847 else 848 return parser.emitError(trailingTypeLoc, 849 "expected LLVM types as inputs"); 850 } 851 852 auto llvmFuncType = LLVM::LLVMFunctionType::get(llvmResultType, argTypes); 853 auto wrappedFuncType = LLVM::LLVMPointerType::get(llvmFuncType); 854 855 auto funcArguments = llvm::makeArrayRef(operands).drop_front(); 856 857 // Make sure that the first operand (indirect callee) matches the wrapped 858 // LLVM IR function type, and that the types of the other call operands 859 // match the types of the function arguments. 860 if (parser.resolveOperand(operands[0], wrappedFuncType, result.operands) || 861 parser.resolveOperands(funcArguments, funcType.getInputs(), 862 parser.getNameLoc(), result.operands)) 863 return failure(); 864 865 result.addTypes(llvmResultType); 866 } 867 result.addSuccessors({normalDest, unwindDest}); 868 result.addOperands(normalOperands); 869 result.addOperands(unwindOperands); 870 871 result.addAttribute( 872 InvokeOp::getOperandSegmentSizeAttr(), 873 builder.getI32VectorAttr({static_cast<int32_t>(operands.size()), 874 static_cast<int32_t>(normalOperands.size()), 875 static_cast<int32_t>(unwindOperands.size())})); 876 return success(); 877 } 878 879 ///===----------------------------------------------------------------------===// 880 /// Verifying/Printing/Parsing for LLVM::LandingpadOp. 881 ///===----------------------------------------------------------------------===// 882 883 static LogicalResult verify(LandingpadOp op) { 884 Value value; 885 if (LLVMFuncOp func = op->getParentOfType<LLVMFuncOp>()) { 886 if (!func.getPersonality().hasValue()) 887 return op.emitError( 888 "llvm.landingpad needs to be in a function with a personality"); 889 } 890 891 if (!op.getCleanup() && op.getOperands().empty()) 892 return op.emitError("landingpad instruction expects at least one clause or " 893 "cleanup attribute"); 894 895 for (unsigned idx = 0, ie = op.getNumOperands(); idx < ie; idx++) { 896 value = op.getOperand(idx); 897 bool isFilter = value.getType().isa<LLVMArrayType>(); 898 if (isFilter) { 899 // FIXME: Verify filter clauses when arrays are appropriately handled 900 } else { 901 // catch - global addresses only. 902 // Bitcast ops should have global addresses as their args. 903 if (auto bcOp = value.getDefiningOp<BitcastOp>()) { 904 if (auto addrOp = bcOp.getArg().getDefiningOp<AddressOfOp>()) 905 continue; 906 return op.emitError("constant clauses expected") 907 .attachNote(bcOp.getLoc()) 908 << "global addresses expected as operand to " 909 "bitcast used in clauses for landingpad"; 910 } 911 // NullOp and AddressOfOp allowed 912 if (value.getDefiningOp<NullOp>()) 913 continue; 914 if (value.getDefiningOp<AddressOfOp>()) 915 continue; 916 return op.emitError("clause #") 917 << idx << " is not a known constant - null, addressof, bitcast"; 918 } 919 } 920 return success(); 921 } 922 923 static void printLandingpadOp(OpAsmPrinter &p, LandingpadOp &op) { 924 p << (op.getCleanup() ? " cleanup " : " "); 925 926 // Clauses 927 for (auto value : op.getOperands()) { 928 // Similar to llvm - if clause is an array type then it is filter 929 // clause else catch clause 930 bool isArrayTy = value.getType().isa<LLVMArrayType>(); 931 p << '(' << (isArrayTy ? "filter " : "catch ") << value << " : " 932 << value.getType() << ") "; 933 } 934 935 p.printOptionalAttrDict(op->getAttrs(), {"cleanup"}); 936 937 p << ": " << op.getType(); 938 } 939 940 /// <operation> ::= `llvm.landingpad` `cleanup`? 941 /// ((`catch` | `filter`) operand-type ssa-use)* attribute-dict? 942 static ParseResult parseLandingpadOp(OpAsmParser &parser, 943 OperationState &result) { 944 // Check for cleanup 945 if (succeeded(parser.parseOptionalKeyword("cleanup"))) 946 result.addAttribute("cleanup", parser.getBuilder().getUnitAttr()); 947 948 // Parse clauses with types 949 while (succeeded(parser.parseOptionalLParen()) && 950 (succeeded(parser.parseOptionalKeyword("filter")) || 951 succeeded(parser.parseOptionalKeyword("catch")))) { 952 OpAsmParser::OperandType operand; 953 Type ty; 954 if (parser.parseOperand(operand) || parser.parseColon() || 955 parser.parseType(ty) || 956 parser.resolveOperand(operand, ty, result.operands) || 957 parser.parseRParen()) 958 return failure(); 959 } 960 961 Type type; 962 if (parser.parseColon() || parser.parseType(type)) 963 return failure(); 964 965 result.addTypes(type); 966 return success(); 967 } 968 969 //===----------------------------------------------------------------------===// 970 // Verifying/Printing/parsing for LLVM::CallOp. 971 //===----------------------------------------------------------------------===// 972 973 static LogicalResult verify(CallOp &op) { 974 if (op.getNumResults() > 1) 975 return op.emitOpError("must have 0 or 1 result"); 976 977 // Type for the callee, we'll get it differently depending if it is a direct 978 // or indirect call. 979 Type fnType; 980 981 bool isIndirect = false; 982 983 // If this is an indirect call, the callee attribute is missing. 984 FlatSymbolRefAttr calleeName = op.getCalleeAttr(); 985 if (!calleeName) { 986 isIndirect = true; 987 if (!op.getNumOperands()) 988 return op.emitOpError( 989 "must have either a `callee` attribute or at least an operand"); 990 auto ptrType = op.getOperand(0).getType().dyn_cast<LLVMPointerType>(); 991 if (!ptrType) 992 return op.emitOpError("indirect call expects a pointer as callee: ") 993 << ptrType; 994 fnType = ptrType.getElementType(); 995 } else { 996 Operation *callee = 997 SymbolTable::lookupNearestSymbolFrom(op, calleeName.getAttr()); 998 if (!callee) 999 return op.emitOpError() 1000 << "'" << calleeName.getValue() 1001 << "' does not reference a symbol in the current scope"; 1002 auto fn = dyn_cast<LLVMFuncOp>(callee); 1003 if (!fn) 1004 return op.emitOpError() << "'" << calleeName.getValue() 1005 << "' does not reference a valid LLVM function"; 1006 1007 fnType = fn.getType(); 1008 } 1009 1010 LLVMFunctionType funcType = fnType.dyn_cast<LLVMFunctionType>(); 1011 if (!funcType) 1012 return op.emitOpError("callee does not have a functional type: ") << fnType; 1013 1014 // Verify that the operand and result types match the callee. 1015 1016 if (!funcType.isVarArg() && 1017 funcType.getNumParams() != (op.getNumOperands() - isIndirect)) 1018 return op.emitOpError() 1019 << "incorrect number of operands (" 1020 << (op.getNumOperands() - isIndirect) 1021 << ") for callee (expecting: " << funcType.getNumParams() << ")"; 1022 1023 if (funcType.getNumParams() > (op.getNumOperands() - isIndirect)) 1024 return op.emitOpError() << "incorrect number of operands (" 1025 << (op.getNumOperands() - isIndirect) 1026 << ") for varargs callee (expecting at least: " 1027 << funcType.getNumParams() << ")"; 1028 1029 for (unsigned i = 0, e = funcType.getNumParams(); i != e; ++i) 1030 if (op.getOperand(i + isIndirect).getType() != funcType.getParamType(i)) 1031 return op.emitOpError() << "operand type mismatch for operand " << i 1032 << ": " << op.getOperand(i + isIndirect).getType() 1033 << " != " << funcType.getParamType(i); 1034 1035 if (op.getNumResults() == 0 && 1036 !funcType.getReturnType().isa<LLVM::LLVMVoidType>()) 1037 return op.emitOpError() << "expected function call to produce a value"; 1038 1039 if (op.getNumResults() != 0 && 1040 funcType.getReturnType().isa<LLVM::LLVMVoidType>()) 1041 return op.emitOpError() 1042 << "calling function with void result must not produce values"; 1043 1044 if (op.getNumResults() > 1) 1045 return op.emitOpError() 1046 << "expected LLVM function call to produce 0 or 1 result"; 1047 1048 if (op.getNumResults() && 1049 op.getResult(0).getType() != funcType.getReturnType()) 1050 return op.emitOpError() 1051 << "result type mismatch: " << op.getResult(0).getType() 1052 << " != " << funcType.getReturnType(); 1053 1054 return success(); 1055 } 1056 1057 static void printCallOp(OpAsmPrinter &p, CallOp &op) { 1058 auto callee = op.getCallee(); 1059 bool isDirect = callee.hasValue(); 1060 1061 // Print the direct callee if present as a function attribute, or an indirect 1062 // callee (first operand) otherwise. 1063 p << ' '; 1064 if (isDirect) 1065 p.printSymbolName(callee.getValue()); 1066 else 1067 p << op.getOperand(0); 1068 1069 auto args = op.getOperands().drop_front(isDirect ? 0 : 1); 1070 p << '(' << args << ')'; 1071 p.printOptionalAttrDict(processFMFAttr(op->getAttrs()), {"callee"}); 1072 1073 // Reconstruct the function MLIR function type from operand and result types. 1074 p << " : " 1075 << FunctionType::get(op.getContext(), args.getTypes(), op.getResultTypes()); 1076 } 1077 1078 // <operation> ::= `llvm.call` (function-id | ssa-use) `(` ssa-use-list `)` 1079 // attribute-dict? `:` function-type 1080 static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) { 1081 SmallVector<OpAsmParser::OperandType, 8> operands; 1082 Type type; 1083 SymbolRefAttr funcAttr; 1084 llvm::SMLoc trailingTypeLoc; 1085 1086 // Parse an operand list that will, in practice, contain 0 or 1 operand. In 1087 // case of an indirect call, there will be 1 operand before `(`. In case of a 1088 // direct call, there will be no operands and the parser will stop at the 1089 // function identifier without complaining. 1090 if (parser.parseOperandList(operands)) 1091 return failure(); 1092 bool isDirect = operands.empty(); 1093 1094 // Optionally parse a function identifier. 1095 if (isDirect) 1096 if (parser.parseAttribute(funcAttr, "callee", result.attributes)) 1097 return failure(); 1098 1099 if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || 1100 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 1101 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type)) 1102 return failure(); 1103 1104 auto funcType = type.dyn_cast<FunctionType>(); 1105 if (!funcType) 1106 return parser.emitError(trailingTypeLoc, "expected function type"); 1107 if (funcType.getNumResults() > 1) 1108 return parser.emitError(trailingTypeLoc, 1109 "expected function with 0 or 1 result"); 1110 if (isDirect) { 1111 // Make sure types match. 1112 if (parser.resolveOperands(operands, funcType.getInputs(), 1113 parser.getNameLoc(), result.operands)) 1114 return failure(); 1115 if (funcType.getNumResults() != 0 && 1116 !funcType.getResult(0).isa<LLVM::LLVMVoidType>()) 1117 result.addTypes(funcType.getResults()); 1118 } else { 1119 Builder &builder = parser.getBuilder(); 1120 Type llvmResultType; 1121 if (funcType.getNumResults() == 0) { 1122 llvmResultType = LLVM::LLVMVoidType::get(builder.getContext()); 1123 } else { 1124 llvmResultType = funcType.getResult(0); 1125 if (!isCompatibleType(llvmResultType)) 1126 return parser.emitError(trailingTypeLoc, 1127 "expected result to have LLVM type"); 1128 } 1129 1130 SmallVector<Type, 8> argTypes; 1131 argTypes.reserve(funcType.getNumInputs()); 1132 for (int i = 0, e = funcType.getNumInputs(); i < e; ++i) { 1133 auto argType = funcType.getInput(i); 1134 if (!isCompatibleType(argType)) 1135 return parser.emitError(trailingTypeLoc, 1136 "expected LLVM types as inputs"); 1137 argTypes.push_back(argType); 1138 } 1139 auto llvmFuncType = LLVM::LLVMFunctionType::get(llvmResultType, argTypes); 1140 auto wrappedFuncType = LLVM::LLVMPointerType::get(llvmFuncType); 1141 1142 auto funcArguments = 1143 ArrayRef<OpAsmParser::OperandType>(operands).drop_front(); 1144 1145 // Make sure that the first operand (indirect callee) matches the wrapped 1146 // LLVM IR function type, and that the types of the other call operands 1147 // match the types of the function arguments. 1148 if (parser.resolveOperand(operands[0], wrappedFuncType, result.operands) || 1149 parser.resolveOperands(funcArguments, funcType.getInputs(), 1150 parser.getNameLoc(), result.operands)) 1151 return failure(); 1152 1153 if (!llvmResultType.isa<LLVM::LLVMVoidType>()) 1154 result.addTypes(llvmResultType); 1155 } 1156 1157 return success(); 1158 } 1159 1160 //===----------------------------------------------------------------------===// 1161 // Printing/parsing for LLVM::ExtractElementOp. 1162 //===----------------------------------------------------------------------===// 1163 // Expects vector to be of wrapped LLVM vector type and position to be of 1164 // wrapped LLVM i32 type. 1165 void LLVM::ExtractElementOp::build(OpBuilder &b, OperationState &result, 1166 Value vector, Value position, 1167 ArrayRef<NamedAttribute> attrs) { 1168 auto vectorType = vector.getType(); 1169 auto llvmType = LLVM::getVectorElementType(vectorType); 1170 build(b, result, llvmType, vector, position); 1171 result.addAttributes(attrs); 1172 } 1173 1174 static void printExtractElementOp(OpAsmPrinter &p, ExtractElementOp &op) { 1175 p << ' ' << op.getVector() << "[" << op.getPosition() << " : " 1176 << op.getPosition().getType() << "]"; 1177 p.printOptionalAttrDict(op->getAttrs()); 1178 p << " : " << op.getVector().getType(); 1179 } 1180 1181 // <operation> ::= `llvm.extractelement` ssa-use `, ` ssa-use 1182 // attribute-dict? `:` type 1183 static ParseResult parseExtractElementOp(OpAsmParser &parser, 1184 OperationState &result) { 1185 llvm::SMLoc loc; 1186 OpAsmParser::OperandType vector, position; 1187 Type type, positionType; 1188 if (parser.getCurrentLocation(&loc) || parser.parseOperand(vector) || 1189 parser.parseLSquare() || parser.parseOperand(position) || 1190 parser.parseColonType(positionType) || parser.parseRSquare() || 1191 parser.parseOptionalAttrDict(result.attributes) || 1192 parser.parseColonType(type) || 1193 parser.resolveOperand(vector, type, result.operands) || 1194 parser.resolveOperand(position, positionType, result.operands)) 1195 return failure(); 1196 if (!LLVM::isCompatibleVectorType(type)) 1197 return parser.emitError( 1198 loc, "expected LLVM dialect-compatible vector type for operand #1"); 1199 result.addTypes(LLVM::getVectorElementType(type)); 1200 return success(); 1201 } 1202 1203 static LogicalResult verify(ExtractElementOp op) { 1204 Type vectorType = op.getVector().getType(); 1205 if (!LLVM::isCompatibleVectorType(vectorType)) 1206 return op->emitOpError("expected LLVM dialect-compatible vector type for " 1207 "operand #1, got") 1208 << vectorType; 1209 Type valueType = LLVM::getVectorElementType(vectorType); 1210 if (valueType != op.getRes().getType()) 1211 return op.emitOpError() << "Type mismatch: extracting from " << vectorType 1212 << " should produce " << valueType 1213 << " but this op returns " << op.getRes().getType(); 1214 return success(); 1215 } 1216 1217 //===----------------------------------------------------------------------===// 1218 // Printing/parsing for LLVM::ExtractValueOp. 1219 //===----------------------------------------------------------------------===// 1220 1221 static void printExtractValueOp(OpAsmPrinter &p, ExtractValueOp &op) { 1222 p << ' ' << op.getContainer() << op.getPosition(); 1223 p.printOptionalAttrDict(op->getAttrs(), {"position"}); 1224 p << " : " << op.getContainer().getType(); 1225 } 1226 1227 // Extract the type at `position` in the wrapped LLVM IR aggregate type 1228 // `containerType`. Position is an integer array attribute where each value 1229 // is a zero-based position of the element in the aggregate type. Return the 1230 // resulting type wrapped in MLIR, or nullptr on error. 1231 static Type getInsertExtractValueElementType(OpAsmParser &parser, 1232 Type containerType, 1233 ArrayAttr positionAttr, 1234 llvm::SMLoc attributeLoc, 1235 llvm::SMLoc typeLoc) { 1236 Type llvmType = containerType; 1237 if (!isCompatibleType(containerType)) 1238 return parser.emitError(typeLoc, "expected LLVM IR Dialect type"), nullptr; 1239 1240 // Infer the element type from the structure type: iteratively step inside the 1241 // type by taking the element type, indexed by the position attribute for 1242 // structures. Check the position index before accessing, it is supposed to 1243 // be in bounds. 1244 for (Attribute subAttr : positionAttr) { 1245 auto positionElementAttr = subAttr.dyn_cast<IntegerAttr>(); 1246 if (!positionElementAttr) 1247 return parser.emitError(attributeLoc, 1248 "expected an array of integer literals"), 1249 nullptr; 1250 int position = positionElementAttr.getInt(); 1251 if (auto arrayType = llvmType.dyn_cast<LLVMArrayType>()) { 1252 if (position < 0 || 1253 static_cast<unsigned>(position) >= arrayType.getNumElements()) 1254 return parser.emitError(attributeLoc, "position out of bounds"), 1255 nullptr; 1256 llvmType = arrayType.getElementType(); 1257 } else if (auto structType = llvmType.dyn_cast<LLVMStructType>()) { 1258 if (position < 0 || 1259 static_cast<unsigned>(position) >= structType.getBody().size()) 1260 return parser.emitError(attributeLoc, "position out of bounds"), 1261 nullptr; 1262 llvmType = structType.getBody()[position]; 1263 } else { 1264 return parser.emitError(typeLoc, "expected LLVM IR structure/array type"), 1265 nullptr; 1266 } 1267 } 1268 return llvmType; 1269 } 1270 1271 // Extract the type at `position` in the wrapped LLVM IR aggregate type 1272 // `containerType`. Returns null on failure. 1273 static Type getInsertExtractValueElementType(Type containerType, 1274 ArrayAttr positionAttr, 1275 Operation *op) { 1276 Type llvmType = containerType; 1277 if (!isCompatibleType(containerType)) { 1278 op->emitError("expected LLVM IR Dialect type, got ") << containerType; 1279 return {}; 1280 } 1281 1282 // Infer the element type from the structure type: iteratively step inside the 1283 // type by taking the element type, indexed by the position attribute for 1284 // structures. Check the position index before accessing, it is supposed to 1285 // be in bounds. 1286 for (Attribute subAttr : positionAttr) { 1287 auto positionElementAttr = subAttr.dyn_cast<IntegerAttr>(); 1288 if (!positionElementAttr) { 1289 op->emitOpError("expected an array of integer literals, got: ") 1290 << subAttr; 1291 return {}; 1292 } 1293 int position = positionElementAttr.getInt(); 1294 if (auto arrayType = llvmType.dyn_cast<LLVMArrayType>()) { 1295 if (position < 0 || 1296 static_cast<unsigned>(position) >= arrayType.getNumElements()) { 1297 op->emitOpError("position out of bounds: ") << position; 1298 return {}; 1299 } 1300 llvmType = arrayType.getElementType(); 1301 } else if (auto structType = llvmType.dyn_cast<LLVMStructType>()) { 1302 if (position < 0 || 1303 static_cast<unsigned>(position) >= structType.getBody().size()) { 1304 op->emitOpError("position out of bounds") << position; 1305 return {}; 1306 } 1307 llvmType = structType.getBody()[position]; 1308 } else { 1309 op->emitOpError("expected LLVM IR structure/array type, got: ") 1310 << llvmType; 1311 return {}; 1312 } 1313 } 1314 return llvmType; 1315 } 1316 1317 // <operation> ::= `llvm.extractvalue` ssa-use 1318 // `[` integer-literal (`,` integer-literal)* `]` 1319 // attribute-dict? `:` type 1320 static ParseResult parseExtractValueOp(OpAsmParser &parser, 1321 OperationState &result) { 1322 OpAsmParser::OperandType container; 1323 Type containerType; 1324 ArrayAttr positionAttr; 1325 llvm::SMLoc attributeLoc, trailingTypeLoc; 1326 1327 if (parser.parseOperand(container) || 1328 parser.getCurrentLocation(&attributeLoc) || 1329 parser.parseAttribute(positionAttr, "position", result.attributes) || 1330 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 1331 parser.getCurrentLocation(&trailingTypeLoc) || 1332 parser.parseType(containerType) || 1333 parser.resolveOperand(container, containerType, result.operands)) 1334 return failure(); 1335 1336 auto elementType = getInsertExtractValueElementType( 1337 parser, containerType, positionAttr, attributeLoc, trailingTypeLoc); 1338 if (!elementType) 1339 return failure(); 1340 1341 result.addTypes(elementType); 1342 return success(); 1343 } 1344 1345 OpFoldResult LLVM::ExtractValueOp::fold(ArrayRef<Attribute> operands) { 1346 auto insertValueOp = getContainer().getDefiningOp<InsertValueOp>(); 1347 while (insertValueOp) { 1348 if (getPosition() == insertValueOp.getPosition()) 1349 return insertValueOp.getValue(); 1350 unsigned min = 1351 std::min(getPosition().size(), insertValueOp.getPosition().size()); 1352 // If one is fully prefix of the other, stop propagating back as it will 1353 // miss dependencies. For instance, %3 should not fold to %f0 in the 1354 // following example: 1355 // ``` 1356 // %1 = llvm.insertvalue %f0, %0[0, 0] : 1357 // !llvm.array<4 x !llvm.array<4xf32>> 1358 // %2 = llvm.insertvalue %arr, %1[0] : 1359 // !llvm.array<4 x !llvm.array<4xf32>> 1360 // %3 = llvm.extractvalue %2[0, 0] : !llvm.array<4 x !llvm.array<4xf32>> 1361 // ``` 1362 if (getPosition().getValue().take_front(min) == 1363 insertValueOp.getPosition().getValue().take_front(min)) 1364 return {}; 1365 insertValueOp = insertValueOp.getContainer().getDefiningOp<InsertValueOp>(); 1366 } 1367 return {}; 1368 } 1369 1370 static LogicalResult verify(ExtractValueOp op) { 1371 Type valueType = getInsertExtractValueElementType(op.getContainer().getType(), 1372 op.getPositionAttr(), op); 1373 if (!valueType) 1374 return failure(); 1375 1376 if (op.getRes().getType() != valueType) 1377 return op.emitOpError() 1378 << "Type mismatch: extracting from " << op.getContainer().getType() 1379 << " should produce " << valueType << " but this op returns " 1380 << op.getRes().getType(); 1381 return success(); 1382 } 1383 1384 //===----------------------------------------------------------------------===// 1385 // Printing/parsing for LLVM::InsertElementOp. 1386 //===----------------------------------------------------------------------===// 1387 1388 static void printInsertElementOp(OpAsmPrinter &p, InsertElementOp &op) { 1389 p << ' ' << op.getValue() << ", " << op.getVector() << "[" << op.getPosition() 1390 << " : " << op.getPosition().getType() << "]"; 1391 p.printOptionalAttrDict(op->getAttrs()); 1392 p << " : " << op.getVector().getType(); 1393 } 1394 1395 // <operation> ::= `llvm.insertelement` ssa-use `,` ssa-use `,` ssa-use 1396 // attribute-dict? `:` type 1397 static ParseResult parseInsertElementOp(OpAsmParser &parser, 1398 OperationState &result) { 1399 llvm::SMLoc loc; 1400 OpAsmParser::OperandType vector, value, position; 1401 Type vectorType, positionType; 1402 if (parser.getCurrentLocation(&loc) || parser.parseOperand(value) || 1403 parser.parseComma() || parser.parseOperand(vector) || 1404 parser.parseLSquare() || parser.parseOperand(position) || 1405 parser.parseColonType(positionType) || parser.parseRSquare() || 1406 parser.parseOptionalAttrDict(result.attributes) || 1407 parser.parseColonType(vectorType)) 1408 return failure(); 1409 1410 if (!LLVM::isCompatibleVectorType(vectorType)) 1411 return parser.emitError( 1412 loc, "expected LLVM dialect-compatible vector type for operand #1"); 1413 Type valueType = LLVM::getVectorElementType(vectorType); 1414 if (!valueType) 1415 return failure(); 1416 1417 if (parser.resolveOperand(vector, vectorType, result.operands) || 1418 parser.resolveOperand(value, valueType, result.operands) || 1419 parser.resolveOperand(position, positionType, result.operands)) 1420 return failure(); 1421 1422 result.addTypes(vectorType); 1423 return success(); 1424 } 1425 1426 static LogicalResult verify(InsertElementOp op) { 1427 Type valueType = LLVM::getVectorElementType(op.getVector().getType()); 1428 if (valueType != op.getValue().getType()) 1429 return op.emitOpError() 1430 << "Type mismatch: cannot insert " << op.getValue().getType() 1431 << " into " << op.getVector().getType(); 1432 return success(); 1433 } 1434 //===----------------------------------------------------------------------===// 1435 // Printing/parsing for LLVM::InsertValueOp. 1436 //===----------------------------------------------------------------------===// 1437 1438 static void printInsertValueOp(OpAsmPrinter &p, InsertValueOp &op) { 1439 p << ' ' << op.getValue() << ", " << op.getContainer() << op.getPosition(); 1440 p.printOptionalAttrDict(op->getAttrs(), {"position"}); 1441 p << " : " << op.getContainer().getType(); 1442 } 1443 1444 // <operation> ::= `llvm.insertvaluevalue` ssa-use `,` ssa-use 1445 // `[` integer-literal (`,` integer-literal)* `]` 1446 // attribute-dict? `:` type 1447 static ParseResult parseInsertValueOp(OpAsmParser &parser, 1448 OperationState &result) { 1449 OpAsmParser::OperandType container, value; 1450 Type containerType; 1451 ArrayAttr positionAttr; 1452 llvm::SMLoc attributeLoc, trailingTypeLoc; 1453 1454 if (parser.parseOperand(value) || parser.parseComma() || 1455 parser.parseOperand(container) || 1456 parser.getCurrentLocation(&attributeLoc) || 1457 parser.parseAttribute(positionAttr, "position", result.attributes) || 1458 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 1459 parser.getCurrentLocation(&trailingTypeLoc) || 1460 parser.parseType(containerType)) 1461 return failure(); 1462 1463 auto valueType = getInsertExtractValueElementType( 1464 parser, containerType, positionAttr, attributeLoc, trailingTypeLoc); 1465 if (!valueType) 1466 return failure(); 1467 1468 if (parser.resolveOperand(container, containerType, result.operands) || 1469 parser.resolveOperand(value, valueType, result.operands)) 1470 return failure(); 1471 1472 result.addTypes(containerType); 1473 return success(); 1474 } 1475 1476 static LogicalResult verify(InsertValueOp op) { 1477 Type valueType = getInsertExtractValueElementType(op.getContainer().getType(), 1478 op.getPositionAttr(), op); 1479 if (!valueType) 1480 return failure(); 1481 1482 if (op.getValue().getType() != valueType) 1483 return op.emitOpError() 1484 << "Type mismatch: cannot insert " << op.getValue().getType() 1485 << " into " << op.getContainer().getType(); 1486 1487 return success(); 1488 } 1489 1490 //===----------------------------------------------------------------------===// 1491 // Printing, parsing and verification for LLVM::ReturnOp. 1492 //===----------------------------------------------------------------------===// 1493 1494 static void printReturnOp(OpAsmPrinter &p, ReturnOp op) { 1495 p.printOptionalAttrDict(op->getAttrs()); 1496 assert(op.getNumOperands() <= 1); 1497 1498 if (op.getNumOperands() == 0) 1499 return; 1500 1501 p << ' ' << op.getOperand(0) << " : " << op.getOperand(0).getType(); 1502 } 1503 1504 // <operation> ::= `llvm.return` ssa-use-list attribute-dict? `:` 1505 // type-list-no-parens 1506 static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &result) { 1507 SmallVector<OpAsmParser::OperandType, 1> operands; 1508 Type type; 1509 1510 if (parser.parseOperandList(operands) || 1511 parser.parseOptionalAttrDict(result.attributes)) 1512 return failure(); 1513 if (operands.empty()) 1514 return success(); 1515 1516 if (parser.parseColonType(type) || 1517 parser.resolveOperand(operands[0], type, result.operands)) 1518 return failure(); 1519 return success(); 1520 } 1521 1522 static LogicalResult verify(ReturnOp op) { 1523 if (op->getNumOperands() > 1) 1524 return op->emitOpError("expected at most 1 operand"); 1525 1526 if (auto parent = op->getParentOfType<LLVMFuncOp>()) { 1527 Type expectedType = parent.getType().getReturnType(); 1528 if (expectedType.isa<LLVMVoidType>()) { 1529 if (op->getNumOperands() == 0) 1530 return success(); 1531 InFlightDiagnostic diag = op->emitOpError("expected no operands"); 1532 diag.attachNote(parent->getLoc()) << "when returning from function"; 1533 return diag; 1534 } 1535 if (op->getNumOperands() == 0) { 1536 if (expectedType.isa<LLVMVoidType>()) 1537 return success(); 1538 InFlightDiagnostic diag = op->emitOpError("expected 1 operand"); 1539 diag.attachNote(parent->getLoc()) << "when returning from function"; 1540 return diag; 1541 } 1542 if (expectedType != op->getOperand(0).getType()) { 1543 InFlightDiagnostic diag = op->emitOpError("mismatching result types"); 1544 diag.attachNote(parent->getLoc()) << "when returning from function"; 1545 return diag; 1546 } 1547 } 1548 return success(); 1549 } 1550 1551 //===----------------------------------------------------------------------===// 1552 // Verifier for LLVM::AddressOfOp. 1553 //===----------------------------------------------------------------------===// 1554 1555 template <typename OpTy> 1556 static OpTy lookupSymbolInModule(Operation *parent, StringRef name) { 1557 Operation *module = parent; 1558 while (module && !satisfiesLLVMModule(module)) 1559 module = module->getParentOp(); 1560 assert(module && "unexpected operation outside of a module"); 1561 return dyn_cast_or_null<OpTy>( 1562 mlir::SymbolTable::lookupSymbolIn(module, name)); 1563 } 1564 1565 GlobalOp AddressOfOp::getGlobal() { 1566 return lookupSymbolInModule<LLVM::GlobalOp>((*this)->getParentOp(), 1567 getGlobalName()); 1568 } 1569 1570 LLVMFuncOp AddressOfOp::getFunction() { 1571 return lookupSymbolInModule<LLVM::LLVMFuncOp>((*this)->getParentOp(), 1572 getGlobalName()); 1573 } 1574 1575 static LogicalResult verify(AddressOfOp op) { 1576 auto global = op.getGlobal(); 1577 auto function = op.getFunction(); 1578 if (!global && !function) 1579 return op.emitOpError( 1580 "must reference a global defined by 'llvm.mlir.global' or 'llvm.func'"); 1581 1582 if (global && 1583 LLVM::LLVMPointerType::get(global.getType(), global.getAddrSpace()) != 1584 op.getResult().getType()) 1585 return op.emitOpError( 1586 "the type must be a pointer to the type of the referenced global"); 1587 1588 if (function && LLVM::LLVMPointerType::get(function.getType()) != 1589 op.getResult().getType()) 1590 return op.emitOpError( 1591 "the type must be a pointer to the type of the referenced function"); 1592 1593 return success(); 1594 } 1595 1596 //===----------------------------------------------------------------------===// 1597 // Builder, printer and verifier for LLVM::GlobalOp. 1598 //===----------------------------------------------------------------------===// 1599 1600 /// Returns the name used for the linkage attribute. This *must* correspond to 1601 /// the name of the attribute in ODS. 1602 static StringRef getLinkageAttrName() { return "linkage"; } 1603 1604 /// Returns the name used for the unnamed_addr attribute. This *must* correspond 1605 /// to the name of the attribute in ODS. 1606 static StringRef getUnnamedAddrAttrName() { return "unnamed_addr"; } 1607 1608 void GlobalOp::build(OpBuilder &builder, OperationState &result, Type type, 1609 bool isConstant, Linkage linkage, StringRef name, 1610 Attribute value, uint64_t alignment, unsigned addrSpace, 1611 bool dsoLocal, ArrayRef<NamedAttribute> attrs) { 1612 result.addAttribute(SymbolTable::getSymbolAttrName(), 1613 builder.getStringAttr(name)); 1614 result.addAttribute("global_type", TypeAttr::get(type)); 1615 if (isConstant) 1616 result.addAttribute("constant", builder.getUnitAttr()); 1617 if (value) 1618 result.addAttribute("value", value); 1619 if (dsoLocal) 1620 result.addAttribute("dso_local", builder.getUnitAttr()); 1621 1622 // Only add an alignment attribute if the "alignment" input 1623 // is different from 0. The value must also be a power of two, but 1624 // this is tested in GlobalOp::verify, not here. 1625 if (alignment != 0) 1626 result.addAttribute("alignment", builder.getI64IntegerAttr(alignment)); 1627 1628 result.addAttribute(::getLinkageAttrName(), 1629 LinkageAttr::get(builder.getContext(), linkage)); 1630 if (addrSpace != 0) 1631 result.addAttribute("addr_space", builder.getI32IntegerAttr(addrSpace)); 1632 result.attributes.append(attrs.begin(), attrs.end()); 1633 result.addRegion(); 1634 } 1635 1636 static void printGlobalOp(OpAsmPrinter &p, GlobalOp op) { 1637 p << ' ' << stringifyLinkage(op.getLinkage()) << ' '; 1638 if (auto unnamedAddr = op.getUnnamedAddr()) { 1639 StringRef str = stringifyUnnamedAddr(*unnamedAddr); 1640 if (!str.empty()) 1641 p << str << ' '; 1642 } 1643 if (op.getConstant()) 1644 p << "constant "; 1645 p.printSymbolName(op.getSymName()); 1646 p << '('; 1647 if (auto value = op.getValueOrNull()) 1648 p.printAttribute(value); 1649 p << ')'; 1650 // Note that the alignment attribute is printed using the 1651 // default syntax here, even though it is an inherent attribute 1652 // (as defined in https://mlir.llvm.org/docs/LangRef/#attributes) 1653 p.printOptionalAttrDict(op->getAttrs(), 1654 {SymbolTable::getSymbolAttrName(), "global_type", 1655 "constant", "value", getLinkageAttrName(), 1656 getUnnamedAddrAttrName()}); 1657 1658 // Print the trailing type unless it's a string global. 1659 if (op.getValueOrNull().dyn_cast_or_null<StringAttr>()) 1660 return; 1661 p << " : " << op.getType(); 1662 1663 Region &initializer = op.getInitializerRegion(); 1664 if (!initializer.empty()) { 1665 p << ' '; 1666 p.printRegion(initializer, /*printEntryBlockArgs=*/false); 1667 } 1668 } 1669 1670 // Parses one of the keywords provided in the list `keywords` and returns the 1671 // position of the parsed keyword in the list. If none of the keywords from the 1672 // list is parsed, returns -1. 1673 static int parseOptionalKeywordAlternative(OpAsmParser &parser, 1674 ArrayRef<StringRef> keywords) { 1675 for (const auto &en : llvm::enumerate(keywords)) { 1676 if (succeeded(parser.parseOptionalKeyword(en.value()))) 1677 return en.index(); 1678 } 1679 return -1; 1680 } 1681 1682 namespace { 1683 template <typename Ty> 1684 struct EnumTraits {}; 1685 1686 #define REGISTER_ENUM_TYPE(Ty) \ 1687 template <> \ 1688 struct EnumTraits<Ty> { \ 1689 static StringRef stringify(Ty value) { return stringify##Ty(value); } \ 1690 static unsigned getMaxEnumVal() { return getMaxEnumValFor##Ty(); } \ 1691 } 1692 1693 REGISTER_ENUM_TYPE(Linkage); 1694 REGISTER_ENUM_TYPE(UnnamedAddr); 1695 } // namespace 1696 1697 /// Parse an enum from the keyword, or default to the provided default value. 1698 /// The return type is the enum type by default, unless overriden with the 1699 /// second template argument. 1700 template <typename EnumTy, typename RetTy = EnumTy> 1701 static RetTy parseOptionalLLVMKeyword(OpAsmParser &parser, 1702 OperationState &result, 1703 EnumTy defaultValue) { 1704 SmallVector<StringRef, 10> names; 1705 for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i) 1706 names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i))); 1707 1708 int index = parseOptionalKeywordAlternative(parser, names); 1709 if (index == -1) 1710 return static_cast<RetTy>(defaultValue); 1711 return static_cast<RetTy>(index); 1712 } 1713 1714 // operation ::= `llvm.mlir.global` linkage? `constant`? `@` identifier 1715 // `(` attribute? `)` align? attribute-list? (`:` type)? region? 1716 // align ::= `align` `=` UINT64 1717 // 1718 // The type can be omitted for string attributes, in which case it will be 1719 // inferred from the value of the string as [strlen(value) x i8]. 1720 static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) { 1721 MLIRContext *ctx = parser.getContext(); 1722 // Parse optional linkage, default to External. 1723 result.addAttribute(getLinkageAttrName(), 1724 LLVM::LinkageAttr::get( 1725 ctx, parseOptionalLLVMKeyword<Linkage>( 1726 parser, result, LLVM::Linkage::External))); 1727 // Parse optional UnnamedAddr, default to None. 1728 result.addAttribute(getUnnamedAddrAttrName(), 1729 parser.getBuilder().getI64IntegerAttr( 1730 parseOptionalLLVMKeyword<UnnamedAddr, int64_t>( 1731 parser, result, LLVM::UnnamedAddr::None))); 1732 1733 if (succeeded(parser.parseOptionalKeyword("constant"))) 1734 result.addAttribute("constant", parser.getBuilder().getUnitAttr()); 1735 1736 StringAttr name; 1737 if (parser.parseSymbolName(name, SymbolTable::getSymbolAttrName(), 1738 result.attributes) || 1739 parser.parseLParen()) 1740 return failure(); 1741 1742 Attribute value; 1743 if (parser.parseOptionalRParen()) { 1744 if (parser.parseAttribute(value, "value", result.attributes) || 1745 parser.parseRParen()) 1746 return failure(); 1747 } 1748 1749 SmallVector<Type, 1> types; 1750 if (parser.parseOptionalAttrDict(result.attributes) || 1751 parser.parseOptionalColonTypeList(types)) 1752 return failure(); 1753 1754 if (types.size() > 1) 1755 return parser.emitError(parser.getNameLoc(), "expected zero or one type"); 1756 1757 Region &initRegion = *result.addRegion(); 1758 if (types.empty()) { 1759 if (auto strAttr = value.dyn_cast_or_null<StringAttr>()) { 1760 MLIRContext *context = parser.getContext(); 1761 auto arrayType = LLVM::LLVMArrayType::get(IntegerType::get(context, 8), 1762 strAttr.getValue().size()); 1763 types.push_back(arrayType); 1764 } else { 1765 return parser.emitError(parser.getNameLoc(), 1766 "type can only be omitted for string globals"); 1767 } 1768 } else { 1769 OptionalParseResult parseResult = 1770 parser.parseOptionalRegion(initRegion, /*arguments=*/{}, 1771 /*argTypes=*/{}); 1772 if (parseResult.hasValue() && failed(*parseResult)) 1773 return failure(); 1774 } 1775 1776 result.addAttribute("global_type", TypeAttr::get(types[0])); 1777 return success(); 1778 } 1779 1780 static bool isZeroAttribute(Attribute value) { 1781 if (auto intValue = value.dyn_cast<IntegerAttr>()) 1782 return intValue.getValue().isNullValue(); 1783 if (auto fpValue = value.dyn_cast<FloatAttr>()) 1784 return fpValue.getValue().isZero(); 1785 if (auto splatValue = value.dyn_cast<SplatElementsAttr>()) 1786 return isZeroAttribute(splatValue.getSplatValue<Attribute>()); 1787 if (auto elementsValue = value.dyn_cast<ElementsAttr>()) 1788 return llvm::all_of(elementsValue.getValues<Attribute>(), isZeroAttribute); 1789 if (auto arrayValue = value.dyn_cast<ArrayAttr>()) 1790 return llvm::all_of(arrayValue.getValue(), isZeroAttribute); 1791 return false; 1792 } 1793 1794 static LogicalResult verify(GlobalOp op) { 1795 if (!LLVMPointerType::isValidElementType(op.getType())) 1796 return op.emitOpError( 1797 "expects type to be a valid element type for an LLVM pointer"); 1798 if (op->getParentOp() && !satisfiesLLVMModule(op->getParentOp())) 1799 return op.emitOpError("must appear at the module level"); 1800 1801 if (auto strAttr = op.getValueOrNull().dyn_cast_or_null<StringAttr>()) { 1802 auto type = op.getType().dyn_cast<LLVMArrayType>(); 1803 IntegerType elementType = 1804 type ? type.getElementType().dyn_cast<IntegerType>() : nullptr; 1805 if (!elementType || elementType.getWidth() != 8 || 1806 type.getNumElements() != strAttr.getValue().size()) 1807 return op.emitOpError( 1808 "requires an i8 array type of the length equal to that of the string " 1809 "attribute"); 1810 } 1811 1812 if (Block *b = op.getInitializerBlock()) { 1813 ReturnOp ret = cast<ReturnOp>(b->getTerminator()); 1814 if (ret.operand_type_begin() == ret.operand_type_end()) 1815 return op.emitOpError("initializer region cannot return void"); 1816 if (*ret.operand_type_begin() != op.getType()) 1817 return op.emitOpError("initializer region type ") 1818 << *ret.operand_type_begin() << " does not match global type " 1819 << op.getType(); 1820 1821 if (op.getValueOrNull()) 1822 return op.emitOpError("cannot have both initializer value and region"); 1823 } 1824 1825 if (op.getLinkage() == Linkage::Common) { 1826 if (Attribute value = op.getValueOrNull()) { 1827 if (!isZeroAttribute(value)) { 1828 return op.emitOpError() 1829 << "expected zero value for '" 1830 << stringifyLinkage(Linkage::Common) << "' linkage"; 1831 } 1832 } 1833 } 1834 1835 if (op.getLinkage() == Linkage::Appending) { 1836 if (!op.getType().isa<LLVMArrayType>()) { 1837 return op.emitOpError() 1838 << "expected array type for '" 1839 << stringifyLinkage(Linkage::Appending) << "' linkage"; 1840 } 1841 } 1842 1843 Optional<uint64_t> alignAttr = op.getAlignment(); 1844 if (alignAttr.hasValue()) { 1845 uint64_t value = alignAttr.getValue(); 1846 if (!llvm::isPowerOf2_64(value)) 1847 return op->emitError() << "alignment attribute is not a power of 2"; 1848 } 1849 1850 return success(); 1851 } 1852 1853 //===----------------------------------------------------------------------===// 1854 // LLVM::GlobalCtorsOp 1855 //===----------------------------------------------------------------------===// 1856 1857 LogicalResult 1858 GlobalCtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 1859 for (Attribute ctor : getCtors()) { 1860 if (failed(verifySymbolAttrUse(ctor.cast<FlatSymbolRefAttr>(), *this, 1861 symbolTable))) 1862 return failure(); 1863 } 1864 return success(); 1865 } 1866 1867 static LogicalResult verify(GlobalCtorsOp op) { 1868 if (op.getCtors().size() != op.getPriorities().size()) 1869 return op.emitError( 1870 "mismatch between the number of ctors and the number of priorities"); 1871 return success(); 1872 } 1873 1874 //===----------------------------------------------------------------------===// 1875 // LLVM::GlobalDtorsOp 1876 //===----------------------------------------------------------------------===// 1877 1878 LogicalResult 1879 GlobalDtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 1880 for (Attribute dtor : getDtors()) { 1881 if (failed(verifySymbolAttrUse(dtor.cast<FlatSymbolRefAttr>(), *this, 1882 symbolTable))) 1883 return failure(); 1884 } 1885 return success(); 1886 } 1887 1888 static LogicalResult verify(GlobalDtorsOp op) { 1889 if (op.getDtors().size() != op.getPriorities().size()) 1890 return op.emitError( 1891 "mismatch between the number of dtors and the number of priorities"); 1892 return success(); 1893 } 1894 1895 //===----------------------------------------------------------------------===// 1896 // Printing/parsing for LLVM::ShuffleVectorOp. 1897 //===----------------------------------------------------------------------===// 1898 // Expects vector to be of wrapped LLVM vector type and position to be of 1899 // wrapped LLVM i32 type. 1900 void LLVM::ShuffleVectorOp::build(OpBuilder &b, OperationState &result, 1901 Value v1, Value v2, ArrayAttr mask, 1902 ArrayRef<NamedAttribute> attrs) { 1903 auto containerType = v1.getType(); 1904 auto vType = LLVM::getFixedVectorType( 1905 LLVM::getVectorElementType(containerType), mask.size()); 1906 build(b, result, vType, v1, v2, mask); 1907 result.addAttributes(attrs); 1908 } 1909 1910 static void printShuffleVectorOp(OpAsmPrinter &p, ShuffleVectorOp &op) { 1911 p << ' ' << op.getV1() << ", " << op.getV2() << " " << op.getMask(); 1912 p.printOptionalAttrDict(op->getAttrs(), {"mask"}); 1913 p << " : " << op.getV1().getType() << ", " << op.getV2().getType(); 1914 } 1915 1916 // <operation> ::= `llvm.shufflevector` ssa-use `, ` ssa-use 1917 // `[` integer-literal (`,` integer-literal)* `]` 1918 // attribute-dict? `:` type 1919 static ParseResult parseShuffleVectorOp(OpAsmParser &parser, 1920 OperationState &result) { 1921 llvm::SMLoc loc; 1922 OpAsmParser::OperandType v1, v2; 1923 ArrayAttr maskAttr; 1924 Type typeV1, typeV2; 1925 if (parser.getCurrentLocation(&loc) || parser.parseOperand(v1) || 1926 parser.parseComma() || parser.parseOperand(v2) || 1927 parser.parseAttribute(maskAttr, "mask", result.attributes) || 1928 parser.parseOptionalAttrDict(result.attributes) || 1929 parser.parseColonType(typeV1) || parser.parseComma() || 1930 parser.parseType(typeV2) || 1931 parser.resolveOperand(v1, typeV1, result.operands) || 1932 parser.resolveOperand(v2, typeV2, result.operands)) 1933 return failure(); 1934 if (!LLVM::isCompatibleVectorType(typeV1)) 1935 return parser.emitError( 1936 loc, "expected LLVM IR dialect vector type for operand #1"); 1937 auto vType = LLVM::getFixedVectorType(LLVM::getVectorElementType(typeV1), 1938 maskAttr.size()); 1939 result.addTypes(vType); 1940 return success(); 1941 } 1942 1943 //===----------------------------------------------------------------------===// 1944 // Implementations for LLVM::LLVMFuncOp. 1945 //===----------------------------------------------------------------------===// 1946 1947 // Add the entry block to the function. 1948 Block *LLVMFuncOp::addEntryBlock() { 1949 assert(empty() && "function already has an entry block"); 1950 assert(!isVarArg() && "unimplemented: non-external variadic functions"); 1951 1952 auto *entry = new Block; 1953 push_back(entry); 1954 1955 LLVMFunctionType type = getType(); 1956 for (unsigned i = 0, e = type.getNumParams(); i < e; ++i) 1957 entry->addArgument(type.getParamType(i)); 1958 return entry; 1959 } 1960 1961 void LLVMFuncOp::build(OpBuilder &builder, OperationState &result, 1962 StringRef name, Type type, LLVM::Linkage linkage, 1963 bool dsoLocal, ArrayRef<NamedAttribute> attrs, 1964 ArrayRef<DictionaryAttr> argAttrs) { 1965 result.addRegion(); 1966 result.addAttribute(SymbolTable::getSymbolAttrName(), 1967 builder.getStringAttr(name)); 1968 result.addAttribute("type", TypeAttr::get(type)); 1969 result.addAttribute(::getLinkageAttrName(), 1970 LinkageAttr::get(builder.getContext(), linkage)); 1971 result.attributes.append(attrs.begin(), attrs.end()); 1972 if (dsoLocal) 1973 result.addAttribute("dso_local", builder.getUnitAttr()); 1974 if (argAttrs.empty()) 1975 return; 1976 1977 assert(type.cast<LLVMFunctionType>().getNumParams() == argAttrs.size() && 1978 "expected as many argument attribute lists as arguments"); 1979 function_like_impl::addArgAndResultAttrs(builder, result, argAttrs, 1980 /*resultAttrs=*/llvm::None); 1981 } 1982 1983 // Builds an LLVM function type from the given lists of input and output types. 1984 // Returns a null type if any of the types provided are non-LLVM types, or if 1985 // there is more than one output type. 1986 static Type 1987 buildLLVMFunctionType(OpAsmParser &parser, llvm::SMLoc loc, 1988 ArrayRef<Type> inputs, ArrayRef<Type> outputs, 1989 function_like_impl::VariadicFlag variadicFlag) { 1990 Builder &b = parser.getBuilder(); 1991 if (outputs.size() > 1) { 1992 parser.emitError(loc, "failed to construct function type: expected zero or " 1993 "one function result"); 1994 return {}; 1995 } 1996 1997 // Convert inputs to LLVM types, exit early on error. 1998 SmallVector<Type, 4> llvmInputs; 1999 for (auto t : inputs) { 2000 if (!isCompatibleType(t)) { 2001 parser.emitError(loc, "failed to construct function type: expected LLVM " 2002 "type for function arguments"); 2003 return {}; 2004 } 2005 llvmInputs.push_back(t); 2006 } 2007 2008 // No output is denoted as "void" in LLVM type system. 2009 Type llvmOutput = 2010 outputs.empty() ? LLVMVoidType::get(b.getContext()) : outputs.front(); 2011 if (!isCompatibleType(llvmOutput)) { 2012 parser.emitError(loc, "failed to construct function type: expected LLVM " 2013 "type for function results") 2014 << llvmOutput; 2015 return {}; 2016 } 2017 return LLVMFunctionType::get(llvmOutput, llvmInputs, 2018 variadicFlag.isVariadic()); 2019 } 2020 2021 // Parses an LLVM function. 2022 // 2023 // operation ::= `llvm.func` linkage? function-signature function-attributes? 2024 // function-body 2025 // 2026 static ParseResult parseLLVMFuncOp(OpAsmParser &parser, 2027 OperationState &result) { 2028 // Default to external linkage if no keyword is provided. 2029 result.addAttribute( 2030 getLinkageAttrName(), 2031 LinkageAttr::get(parser.getContext(), 2032 parseOptionalLLVMKeyword<Linkage>( 2033 parser, result, LLVM::Linkage::External))); 2034 2035 StringAttr nameAttr; 2036 SmallVector<OpAsmParser::OperandType, 8> entryArgs; 2037 SmallVector<NamedAttrList, 1> argAttrs; 2038 SmallVector<NamedAttrList, 1> resultAttrs; 2039 SmallVector<Type, 8> argTypes; 2040 SmallVector<Type, 4> resultTypes; 2041 bool isVariadic; 2042 2043 auto signatureLocation = parser.getCurrentLocation(); 2044 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), 2045 result.attributes) || 2046 function_like_impl::parseFunctionSignature( 2047 parser, /*allowVariadic=*/true, entryArgs, argTypes, argAttrs, 2048 isVariadic, resultTypes, resultAttrs)) 2049 return failure(); 2050 2051 auto type = 2052 buildLLVMFunctionType(parser, signatureLocation, argTypes, resultTypes, 2053 function_like_impl::VariadicFlag(isVariadic)); 2054 if (!type) 2055 return failure(); 2056 result.addAttribute(function_like_impl::getTypeAttrName(), 2057 TypeAttr::get(type)); 2058 2059 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) 2060 return failure(); 2061 function_like_impl::addArgAndResultAttrs(parser.getBuilder(), result, 2062 argAttrs, resultAttrs); 2063 2064 auto *body = result.addRegion(); 2065 OptionalParseResult parseResult = parser.parseOptionalRegion( 2066 *body, entryArgs, entryArgs.empty() ? ArrayRef<Type>() : argTypes); 2067 return failure(parseResult.hasValue() && failed(*parseResult)); 2068 } 2069 2070 // Print the LLVMFuncOp. Collects argument and result types and passes them to 2071 // helper functions. Drops "void" result since it cannot be parsed back. Skips 2072 // the external linkage since it is the default value. 2073 static void printLLVMFuncOp(OpAsmPrinter &p, LLVMFuncOp op) { 2074 p << ' '; 2075 if (op.getLinkage() != LLVM::Linkage::External) 2076 p << stringifyLinkage(op.getLinkage()) << ' '; 2077 p.printSymbolName(op.getName()); 2078 2079 LLVMFunctionType fnType = op.getType(); 2080 SmallVector<Type, 8> argTypes; 2081 SmallVector<Type, 1> resTypes; 2082 argTypes.reserve(fnType.getNumParams()); 2083 for (unsigned i = 0, e = fnType.getNumParams(); i < e; ++i) 2084 argTypes.push_back(fnType.getParamType(i)); 2085 2086 Type returnType = fnType.getReturnType(); 2087 if (!returnType.isa<LLVMVoidType>()) 2088 resTypes.push_back(returnType); 2089 2090 function_like_impl::printFunctionSignature(p, op, argTypes, op.isVarArg(), 2091 resTypes); 2092 function_like_impl::printFunctionAttributes( 2093 p, op, argTypes.size(), resTypes.size(), {getLinkageAttrName()}); 2094 2095 // Print the body if this is not an external function. 2096 Region &body = op.getBody(); 2097 if (!body.empty()) { 2098 p << ' '; 2099 p.printRegion(body, /*printEntryBlockArgs=*/false, 2100 /*printBlockTerminators=*/true); 2101 } 2102 } 2103 2104 // Hook for OpTrait::FunctionLike, called after verifying that the 'type' 2105 // attribute is present. This can check for preconditions of the 2106 // getNumArguments hook not failing. 2107 LogicalResult LLVMFuncOp::verifyType() { 2108 auto llvmType = getTypeAttr().getValue().dyn_cast_or_null<LLVMFunctionType>(); 2109 if (!llvmType) 2110 return emitOpError("requires '" + getTypeAttrName() + 2111 "' attribute of wrapped LLVM function type"); 2112 2113 return success(); 2114 } 2115 2116 // Hook for OpTrait::FunctionLike, returns the number of function arguments. 2117 // Depends on the type attribute being correct as checked by verifyType 2118 unsigned LLVMFuncOp::getNumFuncArguments() { return getType().getNumParams(); } 2119 2120 // Hook for OpTrait::FunctionLike, returns the number of function results. 2121 // Depends on the type attribute being correct as checked by verifyType 2122 unsigned LLVMFuncOp::getNumFuncResults() { 2123 // We model LLVM functions that return void as having zero results, 2124 // and all others as having one result. 2125 // If we modeled a void return as one result, then it would be possible to 2126 // attach an MLIR result attribute to it, and it isn't clear what semantics we 2127 // would assign to that. 2128 if (getType().getReturnType().isa<LLVMVoidType>()) 2129 return 0; 2130 return 1; 2131 } 2132 2133 // Verifies LLVM- and implementation-specific properties of the LLVM func Op: 2134 // - functions don't have 'common' linkage 2135 // - external functions have 'external' or 'extern_weak' linkage; 2136 // - vararg is (currently) only supported for external functions; 2137 // - entry block arguments are of LLVM types and match the function signature. 2138 static LogicalResult verify(LLVMFuncOp op) { 2139 if (op.getLinkage() == LLVM::Linkage::Common) 2140 return op.emitOpError() 2141 << "functions cannot have '" 2142 << stringifyLinkage(LLVM::Linkage::Common) << "' linkage"; 2143 2144 if (op.isExternal()) { 2145 if (op.getLinkage() != LLVM::Linkage::External && 2146 op.getLinkage() != LLVM::Linkage::ExternWeak) 2147 return op.emitOpError() 2148 << "external functions must have '" 2149 << stringifyLinkage(LLVM::Linkage::External) << "' or '" 2150 << stringifyLinkage(LLVM::Linkage::ExternWeak) << "' linkage"; 2151 return success(); 2152 } 2153 2154 if (op.isVarArg()) 2155 return op.emitOpError("only external functions can be variadic"); 2156 2157 unsigned numArguments = op.getType().getNumParams(); 2158 Block &entryBlock = op.front(); 2159 for (unsigned i = 0; i < numArguments; ++i) { 2160 Type argType = entryBlock.getArgument(i).getType(); 2161 if (!isCompatibleType(argType)) 2162 return op.emitOpError("entry block argument #") 2163 << i << " is not of LLVM type"; 2164 if (op.getType().getParamType(i) != argType) 2165 return op.emitOpError("the type of entry block argument #") 2166 << i << " does not match the function signature"; 2167 } 2168 2169 return success(); 2170 } 2171 2172 //===----------------------------------------------------------------------===// 2173 // Verification for LLVM::ConstantOp. 2174 //===----------------------------------------------------------------------===// 2175 2176 static LogicalResult verify(LLVM::ConstantOp op) { 2177 if (StringAttr sAttr = op.getValue().dyn_cast<StringAttr>()) { 2178 auto arrayType = op.getType().dyn_cast<LLVMArrayType>(); 2179 if (!arrayType || arrayType.getNumElements() != sAttr.getValue().size() || 2180 !arrayType.getElementType().isInteger(8)) { 2181 return op->emitOpError() 2182 << "expected array type of " << sAttr.getValue().size() 2183 << " i8 elements for the string constant"; 2184 } 2185 return success(); 2186 } 2187 if (auto structType = op.getType().dyn_cast<LLVMStructType>()) { 2188 if (structType.getBody().size() != 2 || 2189 structType.getBody()[0] != structType.getBody()[1]) { 2190 return op.emitError() << "expected struct type with two elements of the " 2191 "same type, the type of a complex constant"; 2192 } 2193 2194 auto arrayAttr = op.getValue().dyn_cast<ArrayAttr>(); 2195 if (!arrayAttr || arrayAttr.size() != 2 || 2196 arrayAttr[0].getType() != arrayAttr[1].getType()) { 2197 return op.emitOpError() << "expected array attribute with two elements, " 2198 "representing a complex constant"; 2199 } 2200 2201 Type elementType = structType.getBody()[0]; 2202 if (!elementType 2203 .isa<IntegerType, Float16Type, Float32Type, Float64Type>()) { 2204 return op.emitError() 2205 << "expected struct element types to be floating point type or " 2206 "integer type"; 2207 } 2208 return success(); 2209 } 2210 if (!op.getValue().isa<IntegerAttr, ArrayAttr, FloatAttr, ElementsAttr>()) 2211 return op.emitOpError() 2212 << "only supports integer, float, string or elements attributes"; 2213 return success(); 2214 } 2215 2216 // Constant op constant-folds to its value. 2217 OpFoldResult LLVM::ConstantOp::fold(ArrayRef<Attribute>) { return getValue(); } 2218 2219 //===----------------------------------------------------------------------===// 2220 // Utility functions for parsing atomic ops 2221 //===----------------------------------------------------------------------===// 2222 2223 // Helper function to parse a keyword into the specified attribute named by 2224 // `attrName`. The keyword must match one of the string values defined by the 2225 // AtomicBinOp enum. The resulting I64 attribute is added to the `result` 2226 // state. 2227 static ParseResult parseAtomicBinOp(OpAsmParser &parser, OperationState &result, 2228 StringRef attrName) { 2229 llvm::SMLoc loc; 2230 StringRef keyword; 2231 if (parser.getCurrentLocation(&loc) || parser.parseKeyword(&keyword)) 2232 return failure(); 2233 2234 // Replace the keyword `keyword` with an integer attribute. 2235 auto kind = symbolizeAtomicBinOp(keyword); 2236 if (!kind) { 2237 return parser.emitError(loc) 2238 << "'" << keyword << "' is an incorrect value of the '" << attrName 2239 << "' attribute"; 2240 } 2241 2242 auto value = static_cast<int64_t>(kind.getValue()); 2243 auto attr = parser.getBuilder().getI64IntegerAttr(value); 2244 result.addAttribute(attrName, attr); 2245 2246 return success(); 2247 } 2248 2249 // Helper function to parse a keyword into the specified attribute named by 2250 // `attrName`. The keyword must match one of the string values defined by the 2251 // AtomicOrdering enum. The resulting I64 attribute is added to the `result` 2252 // state. 2253 static ParseResult parseAtomicOrdering(OpAsmParser &parser, 2254 OperationState &result, 2255 StringRef attrName) { 2256 llvm::SMLoc loc; 2257 StringRef ordering; 2258 if (parser.getCurrentLocation(&loc) || parser.parseKeyword(&ordering)) 2259 return failure(); 2260 2261 // Replace the keyword `ordering` with an integer attribute. 2262 auto kind = symbolizeAtomicOrdering(ordering); 2263 if (!kind) { 2264 return parser.emitError(loc) 2265 << "'" << ordering << "' is an incorrect value of the '" << attrName 2266 << "' attribute"; 2267 } 2268 2269 auto value = static_cast<int64_t>(kind.getValue()); 2270 auto attr = parser.getBuilder().getI64IntegerAttr(value); 2271 result.addAttribute(attrName, attr); 2272 2273 return success(); 2274 } 2275 2276 //===----------------------------------------------------------------------===// 2277 // Printer, parser and verifier for LLVM::AtomicRMWOp. 2278 //===----------------------------------------------------------------------===// 2279 2280 static void printAtomicRMWOp(OpAsmPrinter &p, AtomicRMWOp &op) { 2281 p << ' ' << stringifyAtomicBinOp(op.getBinOp()) << ' ' << op.getPtr() << ", " 2282 << op.getVal() << ' ' << stringifyAtomicOrdering(op.getOrdering()) << ' '; 2283 p.printOptionalAttrDict(op->getAttrs(), {"bin_op", "ordering"}); 2284 p << " : " << op.getRes().getType(); 2285 } 2286 2287 // <operation> ::= `llvm.atomicrmw` keyword ssa-use `,` ssa-use keyword 2288 // attribute-dict? `:` type 2289 static ParseResult parseAtomicRMWOp(OpAsmParser &parser, 2290 OperationState &result) { 2291 Type type; 2292 OpAsmParser::OperandType ptr, val; 2293 if (parseAtomicBinOp(parser, result, "bin_op") || parser.parseOperand(ptr) || 2294 parser.parseComma() || parser.parseOperand(val) || 2295 parseAtomicOrdering(parser, result, "ordering") || 2296 parser.parseOptionalAttrDict(result.attributes) || 2297 parser.parseColonType(type) || 2298 parser.resolveOperand(ptr, LLVM::LLVMPointerType::get(type), 2299 result.operands) || 2300 parser.resolveOperand(val, type, result.operands)) 2301 return failure(); 2302 2303 result.addTypes(type); 2304 return success(); 2305 } 2306 2307 static LogicalResult verify(AtomicRMWOp op) { 2308 auto ptrType = op.getPtr().getType().cast<LLVM::LLVMPointerType>(); 2309 auto valType = op.getVal().getType(); 2310 if (valType != ptrType.getElementType()) 2311 return op.emitOpError("expected LLVM IR element type for operand #0 to " 2312 "match type for operand #1"); 2313 auto resType = op.getRes().getType(); 2314 if (resType != valType) 2315 return op.emitOpError( 2316 "expected LLVM IR result type to match type for operand #1"); 2317 if (op.getBinOp() == AtomicBinOp::fadd || 2318 op.getBinOp() == AtomicBinOp::fsub) { 2319 if (!mlir::LLVM::isCompatibleFloatingPointType(valType)) 2320 return op.emitOpError("expected LLVM IR floating point type"); 2321 } else if (op.getBinOp() == AtomicBinOp::xchg) { 2322 auto intType = valType.dyn_cast<IntegerType>(); 2323 unsigned intBitWidth = intType ? intType.getWidth() : 0; 2324 if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 && 2325 intBitWidth != 64 && !valType.isa<BFloat16Type>() && 2326 !valType.isa<Float16Type>() && !valType.isa<Float32Type>() && 2327 !valType.isa<Float64Type>()) 2328 return op.emitOpError("unexpected LLVM IR type for 'xchg' bin_op"); 2329 } else { 2330 auto intType = valType.dyn_cast<IntegerType>(); 2331 unsigned intBitWidth = intType ? intType.getWidth() : 0; 2332 if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 && 2333 intBitWidth != 64) 2334 return op.emitOpError("expected LLVM IR integer type"); 2335 } 2336 2337 if (static_cast<unsigned>(op.getOrdering()) < 2338 static_cast<unsigned>(AtomicOrdering::monotonic)) 2339 return op.emitOpError() 2340 << "expected at least '" 2341 << stringifyAtomicOrdering(AtomicOrdering::monotonic) 2342 << "' ordering"; 2343 2344 return success(); 2345 } 2346 2347 //===----------------------------------------------------------------------===// 2348 // Printer, parser and verifier for LLVM::AtomicCmpXchgOp. 2349 //===----------------------------------------------------------------------===// 2350 2351 static void printAtomicCmpXchgOp(OpAsmPrinter &p, AtomicCmpXchgOp &op) { 2352 p << ' ' << op.getPtr() << ", " << op.getCmp() << ", " << op.getVal() << ' ' 2353 << stringifyAtomicOrdering(op.getSuccessOrdering()) << ' ' 2354 << stringifyAtomicOrdering(op.getFailureOrdering()); 2355 p.printOptionalAttrDict(op->getAttrs(), 2356 {"success_ordering", "failure_ordering"}); 2357 p << " : " << op.getVal().getType(); 2358 } 2359 2360 // <operation> ::= `llvm.cmpxchg` ssa-use `,` ssa-use `,` ssa-use 2361 // keyword keyword attribute-dict? `:` type 2362 static ParseResult parseAtomicCmpXchgOp(OpAsmParser &parser, 2363 OperationState &result) { 2364 auto &builder = parser.getBuilder(); 2365 Type type; 2366 OpAsmParser::OperandType ptr, cmp, val; 2367 if (parser.parseOperand(ptr) || parser.parseComma() || 2368 parser.parseOperand(cmp) || parser.parseComma() || 2369 parser.parseOperand(val) || 2370 parseAtomicOrdering(parser, result, "success_ordering") || 2371 parseAtomicOrdering(parser, result, "failure_ordering") || 2372 parser.parseOptionalAttrDict(result.attributes) || 2373 parser.parseColonType(type) || 2374 parser.resolveOperand(ptr, LLVM::LLVMPointerType::get(type), 2375 result.operands) || 2376 parser.resolveOperand(cmp, type, result.operands) || 2377 parser.resolveOperand(val, type, result.operands)) 2378 return failure(); 2379 2380 auto boolType = IntegerType::get(builder.getContext(), 1); 2381 auto resultType = 2382 LLVMStructType::getLiteral(builder.getContext(), {type, boolType}); 2383 result.addTypes(resultType); 2384 2385 return success(); 2386 } 2387 2388 static LogicalResult verify(AtomicCmpXchgOp op) { 2389 auto ptrType = op.getPtr().getType().cast<LLVM::LLVMPointerType>(); 2390 if (!ptrType) 2391 return op.emitOpError("expected LLVM IR pointer type for operand #0"); 2392 auto cmpType = op.getCmp().getType(); 2393 auto valType = op.getVal().getType(); 2394 if (cmpType != ptrType.getElementType() || cmpType != valType) 2395 return op.emitOpError("expected LLVM IR element type for operand #0 to " 2396 "match type for all other operands"); 2397 auto intType = valType.dyn_cast<IntegerType>(); 2398 unsigned intBitWidth = intType ? intType.getWidth() : 0; 2399 if (!valType.isa<LLVMPointerType>() && intBitWidth != 8 && 2400 intBitWidth != 16 && intBitWidth != 32 && intBitWidth != 64 && 2401 !valType.isa<BFloat16Type>() && !valType.isa<Float16Type>() && 2402 !valType.isa<Float32Type>() && !valType.isa<Float64Type>()) 2403 return op.emitOpError("unexpected LLVM IR type"); 2404 if (op.getSuccessOrdering() < AtomicOrdering::monotonic || 2405 op.getFailureOrdering() < AtomicOrdering::monotonic) 2406 return op.emitOpError("ordering must be at least 'monotonic'"); 2407 if (op.getFailureOrdering() == AtomicOrdering::release || 2408 op.getFailureOrdering() == AtomicOrdering::acq_rel) 2409 return op.emitOpError("failure ordering cannot be 'release' or 'acq_rel'"); 2410 return success(); 2411 } 2412 2413 //===----------------------------------------------------------------------===// 2414 // Printer, parser and verifier for LLVM::FenceOp. 2415 //===----------------------------------------------------------------------===// 2416 2417 // <operation> ::= `llvm.fence` (`syncscope(`strAttr`)`)? keyword 2418 // attribute-dict? 2419 static ParseResult parseFenceOp(OpAsmParser &parser, OperationState &result) { 2420 StringAttr sScope; 2421 StringRef syncscopeKeyword = "syncscope"; 2422 if (!failed(parser.parseOptionalKeyword(syncscopeKeyword))) { 2423 if (parser.parseLParen() || 2424 parser.parseAttribute(sScope, syncscopeKeyword, result.attributes) || 2425 parser.parseRParen()) 2426 return failure(); 2427 } else { 2428 result.addAttribute(syncscopeKeyword, 2429 parser.getBuilder().getStringAttr("")); 2430 } 2431 if (parseAtomicOrdering(parser, result, "ordering") || 2432 parser.parseOptionalAttrDict(result.attributes)) 2433 return failure(); 2434 return success(); 2435 } 2436 2437 static void printFenceOp(OpAsmPrinter &p, FenceOp &op) { 2438 StringRef syncscopeKeyword = "syncscope"; 2439 p << ' '; 2440 if (!op->getAttr(syncscopeKeyword).cast<StringAttr>().getValue().empty()) 2441 p << "syncscope(" << op->getAttr(syncscopeKeyword) << ") "; 2442 p << stringifyAtomicOrdering(op.getOrdering()); 2443 } 2444 2445 static LogicalResult verify(FenceOp &op) { 2446 if (op.getOrdering() == AtomicOrdering::not_atomic || 2447 op.getOrdering() == AtomicOrdering::unordered || 2448 op.getOrdering() == AtomicOrdering::monotonic) 2449 return op.emitOpError("can be given only acquire, release, acq_rel, " 2450 "and seq_cst orderings"); 2451 return success(); 2452 } 2453 2454 //===----------------------------------------------------------------------===// 2455 // Folder for LLVM::BitcastOp 2456 //===----------------------------------------------------------------------===// 2457 2458 OpFoldResult LLVM::BitcastOp::fold(ArrayRef<Attribute> operands) { 2459 // bitcast(x : T0, T0) -> x 2460 if (getArg().getType() == getType()) 2461 return getArg(); 2462 // bitcast(bitcast(x : T0, T1), T0) -> x 2463 if (auto prev = getArg().getDefiningOp<BitcastOp>()) 2464 if (prev.getArg().getType() == getType()) 2465 return prev.getArg(); 2466 return {}; 2467 } 2468 2469 //===----------------------------------------------------------------------===// 2470 // Folder for LLVM::AddrSpaceCastOp 2471 //===----------------------------------------------------------------------===// 2472 2473 OpFoldResult LLVM::AddrSpaceCastOp::fold(ArrayRef<Attribute> operands) { 2474 // addrcast(x : T0, T0) -> x 2475 if (getArg().getType() == getType()) 2476 return getArg(); 2477 // addrcast(addrcast(x : T0, T1), T0) -> x 2478 if (auto prev = getArg().getDefiningOp<AddrSpaceCastOp>()) 2479 if (prev.getArg().getType() == getType()) 2480 return prev.getArg(); 2481 return {}; 2482 } 2483 2484 //===----------------------------------------------------------------------===// 2485 // Folder for LLVM::GEPOp 2486 //===----------------------------------------------------------------------===// 2487 2488 OpFoldResult LLVM::GEPOp::fold(ArrayRef<Attribute> operands) { 2489 // gep %x:T, 0 -> %x 2490 if (getBase().getType() == getType() && getIndices().size() == 1 && 2491 matchPattern(getIndices()[0], m_Zero())) 2492 return getBase(); 2493 return {}; 2494 } 2495 2496 //===----------------------------------------------------------------------===// 2497 // LLVMDialect initialization, type parsing, and registration. 2498 //===----------------------------------------------------------------------===// 2499 2500 void LLVMDialect::initialize() { 2501 addAttributes<FMFAttr, LinkageAttr, LoopOptionsAttr>(); 2502 2503 // clang-format off 2504 addTypes<LLVMVoidType, 2505 LLVMPPCFP128Type, 2506 LLVMX86MMXType, 2507 LLVMTokenType, 2508 LLVMLabelType, 2509 LLVMMetadataType, 2510 LLVMFunctionType, 2511 LLVMPointerType, 2512 LLVMFixedVectorType, 2513 LLVMScalableVectorType, 2514 LLVMArrayType, 2515 LLVMStructType>(); 2516 // clang-format on 2517 addOperations< 2518 #define GET_OP_LIST 2519 #include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc" 2520 >(); 2521 2522 // Support unknown operations because not all LLVM operations are registered. 2523 allowUnknownOperations(); 2524 } 2525 2526 #define GET_OP_CLASSES 2527 #include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc" 2528 2529 /// Parse a type registered to this dialect. 2530 Type LLVMDialect::parseType(DialectAsmParser &parser) const { 2531 return detail::parseType(parser); 2532 } 2533 2534 /// Print a type registered to this dialect. 2535 void LLVMDialect::printType(Type type, DialectAsmPrinter &os) const { 2536 return detail::printType(type, os); 2537 } 2538 2539 LogicalResult LLVMDialect::verifyDataLayoutString( 2540 StringRef descr, llvm::function_ref<void(const Twine &)> reportError) { 2541 llvm::Expected<llvm::DataLayout> maybeDataLayout = 2542 llvm::DataLayout::parse(descr); 2543 if (maybeDataLayout) 2544 return success(); 2545 2546 std::string message; 2547 llvm::raw_string_ostream messageStream(message); 2548 llvm::logAllUnhandledErrors(maybeDataLayout.takeError(), messageStream); 2549 reportError("invalid data layout descriptor: " + messageStream.str()); 2550 return failure(); 2551 } 2552 2553 /// Verify LLVM dialect attributes. 2554 LogicalResult LLVMDialect::verifyOperationAttribute(Operation *op, 2555 NamedAttribute attr) { 2556 // If the `llvm.loop` attribute is present, enforce the following structure, 2557 // which the module translation can assume. 2558 if (attr.getName() == LLVMDialect::getLoopAttrName()) { 2559 auto loopAttr = attr.getValue().dyn_cast<DictionaryAttr>(); 2560 if (!loopAttr) 2561 return op->emitOpError() << "expected '" << LLVMDialect::getLoopAttrName() 2562 << "' to be a dictionary attribute"; 2563 Optional<NamedAttribute> parallelAccessGroup = 2564 loopAttr.getNamed(LLVMDialect::getParallelAccessAttrName()); 2565 if (parallelAccessGroup.hasValue()) { 2566 auto accessGroups = parallelAccessGroup->getValue().dyn_cast<ArrayAttr>(); 2567 if (!accessGroups) 2568 return op->emitOpError() 2569 << "expected '" << LLVMDialect::getParallelAccessAttrName() 2570 << "' to be an array attribute"; 2571 for (Attribute attr : accessGroups) { 2572 auto accessGroupRef = attr.dyn_cast<SymbolRefAttr>(); 2573 if (!accessGroupRef) 2574 return op->emitOpError() 2575 << "expected '" << attr << "' to be a symbol reference"; 2576 StringAttr metadataName = accessGroupRef.getRootReference(); 2577 auto metadataOp = 2578 SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>( 2579 op->getParentOp(), metadataName); 2580 if (!metadataOp) 2581 return op->emitOpError() 2582 << "expected '" << attr << "' to reference a metadata op"; 2583 StringAttr accessGroupName = accessGroupRef.getLeafReference(); 2584 Operation *accessGroupOp = 2585 SymbolTable::lookupNearestSymbolFrom(metadataOp, accessGroupName); 2586 if (!accessGroupOp) 2587 return op->emitOpError() 2588 << "expected '" << attr << "' to reference an access_group op"; 2589 } 2590 } 2591 2592 Optional<NamedAttribute> loopOptions = 2593 loopAttr.getNamed(LLVMDialect::getLoopOptionsAttrName()); 2594 if (loopOptions.hasValue() && 2595 !loopOptions->getValue().isa<LoopOptionsAttr>()) 2596 return op->emitOpError() 2597 << "expected '" << LLVMDialect::getLoopOptionsAttrName() 2598 << "' to be a `loopopts` attribute"; 2599 } 2600 2601 // If the data layout attribute is present, it must use the LLVM data layout 2602 // syntax. Try parsing it and report errors in case of failure. Users of this 2603 // attribute may assume it is well-formed and can pass it to the (asserting) 2604 // llvm::DataLayout constructor. 2605 if (attr.getName() != LLVM::LLVMDialect::getDataLayoutAttrName()) 2606 return success(); 2607 if (auto stringAttr = attr.getValue().dyn_cast<StringAttr>()) 2608 return verifyDataLayoutString( 2609 stringAttr.getValue(), 2610 [op](const Twine &message) { op->emitOpError() << message.str(); }); 2611 2612 return op->emitOpError() << "expected '" 2613 << LLVM::LLVMDialect::getDataLayoutAttrName() 2614 << "' to be a string attribute"; 2615 } 2616 2617 /// Verify LLVMIR function argument attributes. 2618 LogicalResult LLVMDialect::verifyRegionArgAttribute(Operation *op, 2619 unsigned regionIdx, 2620 unsigned argIdx, 2621 NamedAttribute argAttr) { 2622 // Check that llvm.noalias is a unit attribute. 2623 if (argAttr.getName() == LLVMDialect::getNoAliasAttrName() && 2624 !argAttr.getValue().isa<UnitAttr>()) 2625 return op->emitError() 2626 << "expected llvm.noalias argument attribute to be a unit attribute"; 2627 // Check that llvm.align is an integer attribute. 2628 if (argAttr.getName() == LLVMDialect::getAlignAttrName() && 2629 !argAttr.getValue().isa<IntegerAttr>()) 2630 return op->emitError() 2631 << "llvm.align argument attribute of non integer type"; 2632 return success(); 2633 } 2634 2635 //===----------------------------------------------------------------------===// 2636 // Utility functions. 2637 //===----------------------------------------------------------------------===// 2638 2639 Value mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder, 2640 StringRef name, StringRef value, 2641 LLVM::Linkage linkage) { 2642 assert(builder.getInsertionBlock() && 2643 builder.getInsertionBlock()->getParentOp() && 2644 "expected builder to point to a block constrained in an op"); 2645 auto module = 2646 builder.getInsertionBlock()->getParentOp()->getParentOfType<ModuleOp>(); 2647 assert(module && "builder points to an op outside of a module"); 2648 2649 // Create the global at the entry of the module. 2650 OpBuilder moduleBuilder(module.getBodyRegion(), builder.getListener()); 2651 MLIRContext *ctx = builder.getContext(); 2652 auto type = LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), value.size()); 2653 auto global = moduleBuilder.create<LLVM::GlobalOp>( 2654 loc, type, /*isConstant=*/true, linkage, name, 2655 builder.getStringAttr(value), /*alignment=*/0); 2656 2657 // Get the pointer to the first character in the global string. 2658 Value globalPtr = builder.create<LLVM::AddressOfOp>(loc, global); 2659 Value cst0 = builder.create<LLVM::ConstantOp>( 2660 loc, IntegerType::get(ctx, 64), 2661 builder.getIntegerAttr(builder.getIndexType(), 0)); 2662 return builder.create<LLVM::GEPOp>( 2663 loc, LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)), globalPtr, 2664 ValueRange{cst0, cst0}); 2665 } 2666 2667 bool mlir::LLVM::satisfiesLLVMModule(Operation *op) { 2668 return op->hasTrait<OpTrait::SymbolTable>() && 2669 op->hasTrait<OpTrait::IsIsolatedFromAbove>(); 2670 } 2671 2672 static constexpr const FastmathFlags fastmathFlagsList[] = { 2673 // clang-format off 2674 FastmathFlags::nnan, 2675 FastmathFlags::ninf, 2676 FastmathFlags::nsz, 2677 FastmathFlags::arcp, 2678 FastmathFlags::contract, 2679 FastmathFlags::afn, 2680 FastmathFlags::reassoc, 2681 FastmathFlags::fast, 2682 // clang-format on 2683 }; 2684 2685 void FMFAttr::print(AsmPrinter &printer) const { 2686 printer << "<"; 2687 auto flags = llvm::make_filter_range(fastmathFlagsList, [&](auto flag) { 2688 return bitEnumContains(this->getFlags(), flag); 2689 }); 2690 llvm::interleaveComma(flags, printer, 2691 [&](auto flag) { printer << stringifyEnum(flag); }); 2692 printer << ">"; 2693 } 2694 2695 Attribute FMFAttr::parse(AsmParser &parser, Type type) { 2696 if (failed(parser.parseLess())) 2697 return {}; 2698 2699 FastmathFlags flags = {}; 2700 if (failed(parser.parseOptionalGreater())) { 2701 do { 2702 StringRef elemName; 2703 if (failed(parser.parseKeyword(&elemName))) 2704 return {}; 2705 2706 auto elem = symbolizeFastmathFlags(elemName); 2707 if (!elem) { 2708 parser.emitError(parser.getNameLoc(), "Unknown fastmath flag: ") 2709 << elemName; 2710 return {}; 2711 } 2712 2713 flags = flags | *elem; 2714 } while (succeeded(parser.parseOptionalComma())); 2715 2716 if (failed(parser.parseGreater())) 2717 return {}; 2718 } 2719 2720 return FMFAttr::get(parser.getContext(), flags); 2721 } 2722 2723 void LinkageAttr::print(AsmPrinter &printer) const { 2724 printer << "<"; 2725 if (static_cast<uint64_t>(getLinkage()) <= getMaxEnumValForLinkage()) 2726 printer << stringifyEnum(getLinkage()); 2727 else 2728 printer << static_cast<uint64_t>(getLinkage()); 2729 printer << ">"; 2730 } 2731 2732 Attribute LinkageAttr::parse(AsmParser &parser, Type type) { 2733 StringRef elemName; 2734 if (parser.parseLess() || parser.parseKeyword(&elemName) || 2735 parser.parseGreater()) 2736 return {}; 2737 auto elem = linkage::symbolizeLinkage(elemName); 2738 if (!elem) { 2739 parser.emitError(parser.getNameLoc(), "Unknown linkage: ") << elemName; 2740 return {}; 2741 } 2742 Linkage linkage = *elem; 2743 return LinkageAttr::get(parser.getContext(), linkage); 2744 } 2745 2746 LoopOptionsAttrBuilder::LoopOptionsAttrBuilder(LoopOptionsAttr attr) 2747 : options(attr.getOptions().begin(), attr.getOptions().end()) {} 2748 2749 template <typename T> 2750 LoopOptionsAttrBuilder &LoopOptionsAttrBuilder::setOption(LoopOptionCase tag, 2751 Optional<T> value) { 2752 auto option = llvm::find_if( 2753 options, [tag](auto option) { return option.first == tag; }); 2754 if (option != options.end()) { 2755 if (value.hasValue()) 2756 option->second = *value; 2757 else 2758 options.erase(option); 2759 } else { 2760 options.push_back(LoopOptionsAttr::OptionValuePair(tag, *value)); 2761 } 2762 return *this; 2763 } 2764 2765 LoopOptionsAttrBuilder & 2766 LoopOptionsAttrBuilder::setDisableLICM(Optional<bool> value) { 2767 return setOption(LoopOptionCase::disable_licm, value); 2768 } 2769 2770 /// Set the `interleave_count` option to the provided value. If no value 2771 /// is provided the option is deleted. 2772 LoopOptionsAttrBuilder & 2773 LoopOptionsAttrBuilder::setInterleaveCount(Optional<uint64_t> count) { 2774 return setOption(LoopOptionCase::interleave_count, count); 2775 } 2776 2777 /// Set the `disable_unroll` option to the provided value. If no value 2778 /// is provided the option is deleted. 2779 LoopOptionsAttrBuilder & 2780 LoopOptionsAttrBuilder::setDisableUnroll(Optional<bool> value) { 2781 return setOption(LoopOptionCase::disable_unroll, value); 2782 } 2783 2784 /// Set the `disable_pipeline` option to the provided value. If no value 2785 /// is provided the option is deleted. 2786 LoopOptionsAttrBuilder & 2787 LoopOptionsAttrBuilder::setDisablePipeline(Optional<bool> value) { 2788 return setOption(LoopOptionCase::disable_pipeline, value); 2789 } 2790 2791 /// Set the `pipeline_initiation_interval` option to the provided value. 2792 /// If no value is provided the option is deleted. 2793 LoopOptionsAttrBuilder &LoopOptionsAttrBuilder::setPipelineInitiationInterval( 2794 Optional<uint64_t> count) { 2795 return setOption(LoopOptionCase::pipeline_initiation_interval, count); 2796 } 2797 2798 template <typename T> 2799 static Optional<T> 2800 getOption(ArrayRef<std::pair<LoopOptionCase, int64_t>> options, 2801 LoopOptionCase option) { 2802 auto it = 2803 lower_bound(options, option, [](auto optionPair, LoopOptionCase option) { 2804 return optionPair.first < option; 2805 }); 2806 if (it == options.end()) 2807 return {}; 2808 return static_cast<T>(it->second); 2809 } 2810 2811 Optional<bool> LoopOptionsAttr::disableUnroll() { 2812 return getOption<bool>(getOptions(), LoopOptionCase::disable_unroll); 2813 } 2814 2815 Optional<bool> LoopOptionsAttr::disableLICM() { 2816 return getOption<bool>(getOptions(), LoopOptionCase::disable_licm); 2817 } 2818 2819 Optional<int64_t> LoopOptionsAttr::interleaveCount() { 2820 return getOption<int64_t>(getOptions(), LoopOptionCase::interleave_count); 2821 } 2822 2823 /// Build the LoopOptions Attribute from a sorted array of individual options. 2824 LoopOptionsAttr LoopOptionsAttr::get( 2825 MLIRContext *context, 2826 ArrayRef<std::pair<LoopOptionCase, int64_t>> sortedOptions) { 2827 assert(llvm::is_sorted(sortedOptions, llvm::less_first()) && 2828 "LoopOptionsAttr ctor expects a sorted options array"); 2829 return Base::get(context, sortedOptions); 2830 } 2831 2832 /// Build the LoopOptions Attribute from a sorted array of individual options. 2833 LoopOptionsAttr LoopOptionsAttr::get(MLIRContext *context, 2834 LoopOptionsAttrBuilder &optionBuilders) { 2835 llvm::sort(optionBuilders.options, llvm::less_first()); 2836 return Base::get(context, optionBuilders.options); 2837 } 2838 2839 void LoopOptionsAttr::print(AsmPrinter &printer) const { 2840 printer << "<"; 2841 llvm::interleaveComma(getOptions(), printer, [&](auto option) { 2842 printer << stringifyEnum(option.first) << " = "; 2843 switch (option.first) { 2844 case LoopOptionCase::disable_licm: 2845 case LoopOptionCase::disable_unroll: 2846 case LoopOptionCase::disable_pipeline: 2847 printer << (option.second ? "true" : "false"); 2848 break; 2849 case LoopOptionCase::interleave_count: 2850 case LoopOptionCase::pipeline_initiation_interval: 2851 printer << option.second; 2852 break; 2853 } 2854 }); 2855 printer << ">"; 2856 } 2857 2858 Attribute LoopOptionsAttr::parse(AsmParser &parser, Type type) { 2859 if (failed(parser.parseLess())) 2860 return {}; 2861 2862 SmallVector<std::pair<LoopOptionCase, int64_t>> options; 2863 llvm::SmallDenseSet<LoopOptionCase> seenOptions; 2864 do { 2865 StringRef optionName; 2866 if (parser.parseKeyword(&optionName)) 2867 return {}; 2868 2869 auto option = symbolizeLoopOptionCase(optionName); 2870 if (!option) { 2871 parser.emitError(parser.getNameLoc(), "unknown loop option: ") 2872 << optionName; 2873 return {}; 2874 } 2875 if (!seenOptions.insert(*option).second) { 2876 parser.emitError(parser.getNameLoc(), "loop option present twice"); 2877 return {}; 2878 } 2879 if (failed(parser.parseEqual())) 2880 return {}; 2881 2882 int64_t value; 2883 switch (*option) { 2884 case LoopOptionCase::disable_licm: 2885 case LoopOptionCase::disable_unroll: 2886 case LoopOptionCase::disable_pipeline: 2887 if (succeeded(parser.parseOptionalKeyword("true"))) 2888 value = 1; 2889 else if (succeeded(parser.parseOptionalKeyword("false"))) 2890 value = 0; 2891 else { 2892 parser.emitError(parser.getNameLoc(), 2893 "expected boolean value 'true' or 'false'"); 2894 return {}; 2895 } 2896 break; 2897 case LoopOptionCase::interleave_count: 2898 case LoopOptionCase::pipeline_initiation_interval: 2899 if (failed(parser.parseInteger(value))) { 2900 parser.emitError(parser.getNameLoc(), "expected integer value"); 2901 return {}; 2902 } 2903 break; 2904 } 2905 options.push_back(std::make_pair(*option, value)); 2906 } while (succeeded(parser.parseOptionalComma())); 2907 if (failed(parser.parseGreater())) 2908 return {}; 2909 2910 llvm::sort(options, llvm::less_first()); 2911 return get(parser.getContext(), options); 2912 } 2913