//===- LLVMDialect.cpp - LLVM IR Ops and Dialect registration -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the types and operation details for the LLVM IR dialect in
// MLIR, and the LLVM IR dialect. It also registers the dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "TypeDetail.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"

#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/Bitcode/BitcodeReader.h"
#include "llvm/Bitcode/BitcodeWriter.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/Mutex.h"
#include "llvm/Support/SourceMgr.h"

#include <numeric>

using namespace mlir;
using namespace mlir::LLVM;
using mlir::LLVM::cconv::getMaxEnumValForCConv;
using mlir::LLVM::linkage::getMaxEnumValForLinkage;

#include "mlir/Dialect/LLVMIR/LLVMOpsDialect.cpp.inc"

/// Name of the unit attribute that marks a memory operation (e.g. llvm.load /
/// llvm.store) as volatile.
static constexpr const char kVolatileAttrName[] = "volatile_";
/// Name of the unit attribute that marks a memory operation (e.g. llvm.load /
/// llvm.store) as non-temporal.
static constexpr const char kNonTemporalAttrName[] = "nontemporal";
/// Name of the type attribute holding the element type of an operation whose
/// pointer is opaque (and therefore does not carry an element type itself).
static constexpr const char kElemTypeAttrName[] = "elem_type";

#include "mlir/Dialect/LLVMIR/LLVMOpsEnums.cpp.inc"
#include "mlir/Dialect/LLVMIR/LLVMOpsInterfaces.cpp.inc"
#define GET_ATTRDEF_CLASSES 52 #include "mlir/Dialect/LLVMIR/LLVMOpsAttrDefs.cpp.inc" 53 54 static auto processFMFAttr(ArrayRef<NamedAttribute> attrs) { 55 SmallVector<NamedAttribute, 8> filteredAttrs( 56 llvm::make_filter_range(attrs, [&](NamedAttribute attr) { 57 if (attr.getName() == "fastmathFlags") { 58 auto defAttr = FMFAttr::get(attr.getValue().getContext(), {}); 59 return defAttr != attr.getValue(); 60 } 61 return true; 62 })); 63 return filteredAttrs; 64 } 65 66 static ParseResult parseLLVMOpAttrs(OpAsmParser &parser, 67 NamedAttrList &result) { 68 return parser.parseOptionalAttrDict(result); 69 } 70 71 static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op, 72 DictionaryAttr attrs) { 73 printer.printOptionalAttrDict(processFMFAttr(attrs.getValue())); 74 } 75 76 /// Verifies `symbol`'s use in `op` to ensure the symbol is a valid and 77 /// fully defined llvm.func. 78 static LogicalResult verifySymbolAttrUse(FlatSymbolRefAttr symbol, 79 Operation *op, 80 SymbolTableCollection &symbolTable) { 81 StringRef name = symbol.getValue(); 82 auto func = 83 symbolTable.lookupNearestSymbolFrom<LLVMFuncOp>(op, symbol.getAttr()); 84 if (!func) 85 return op->emitOpError("'") 86 << name << "' does not reference a valid LLVM function"; 87 if (func.isExternal()) 88 return op->emitOpError("'") << name << "' does not have a definition"; 89 return success(); 90 } 91 92 //===----------------------------------------------------------------------===// 93 // Printing/parsing for LLVM::CmpOp. 
94 //===----------------------------------------------------------------------===// 95 96 void ICmpOp::print(OpAsmPrinter &p) { 97 p << " \"" << stringifyICmpPredicate(getPredicate()) << "\" " << getOperand(0) 98 << ", " << getOperand(1); 99 p.printOptionalAttrDict((*this)->getAttrs(), {"predicate"}); 100 p << " : " << getLhs().getType(); 101 } 102 103 void FCmpOp::print(OpAsmPrinter &p) { 104 p << " \"" << stringifyFCmpPredicate(getPredicate()) << "\" " << getOperand(0) 105 << ", " << getOperand(1); 106 p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()), {"predicate"}); 107 p << " : " << getLhs().getType(); 108 } 109 110 // <operation> ::= `llvm.icmp` string-literal ssa-use `,` ssa-use 111 // attribute-dict? `:` type 112 // <operation> ::= `llvm.fcmp` string-literal ssa-use `,` ssa-use 113 // attribute-dict? `:` type 114 template <typename CmpPredicateType> 115 static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) { 116 Builder &builder = parser.getBuilder(); 117 118 StringAttr predicateAttr; 119 OpAsmParser::UnresolvedOperand lhs, rhs; 120 Type type; 121 SMLoc predicateLoc, trailingTypeLoc; 122 if (parser.getCurrentLocation(&predicateLoc) || 123 parser.parseAttribute(predicateAttr, "predicate", result.attributes) || 124 parser.parseOperand(lhs) || parser.parseComma() || 125 parser.parseOperand(rhs) || 126 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 127 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) || 128 parser.resolveOperand(lhs, type, result.operands) || 129 parser.resolveOperand(rhs, type, result.operands)) 130 return failure(); 131 132 // Replace the string attribute `predicate` with an integer attribute. 
133 int64_t predicateValue = 0; 134 if (std::is_same<CmpPredicateType, ICmpPredicate>()) { 135 Optional<ICmpPredicate> predicate = 136 symbolizeICmpPredicate(predicateAttr.getValue()); 137 if (!predicate) 138 return parser.emitError(predicateLoc) 139 << "'" << predicateAttr.getValue() 140 << "' is an incorrect value of the 'predicate' attribute"; 141 predicateValue = static_cast<int64_t>(predicate.getValue()); 142 } else { 143 Optional<FCmpPredicate> predicate = 144 symbolizeFCmpPredicate(predicateAttr.getValue()); 145 if (!predicate) 146 return parser.emitError(predicateLoc) 147 << "'" << predicateAttr.getValue() 148 << "' is an incorrect value of the 'predicate' attribute"; 149 predicateValue = static_cast<int64_t>(predicate.getValue()); 150 } 151 152 result.attributes.set("predicate", 153 parser.getBuilder().getI64IntegerAttr(predicateValue)); 154 155 // The result type is either i1 or a vector type <? x i1> if the inputs are 156 // vectors. 157 Type resultType = IntegerType::get(builder.getContext(), 1); 158 if (!isCompatibleType(type)) 159 return parser.emitError(trailingTypeLoc, 160 "expected LLVM dialect-compatible type"); 161 if (LLVM::isCompatibleVectorType(type)) { 162 if (LLVM::isScalableVectorType(type)) { 163 resultType = LLVM::getVectorType( 164 resultType, LLVM::getVectorNumElements(type).getKnownMinValue(), 165 /*isScalable=*/true); 166 } else { 167 resultType = LLVM::getVectorType( 168 resultType, LLVM::getVectorNumElements(type).getFixedValue(), 169 /*isScalable=*/false); 170 } 171 } 172 173 result.addTypes({resultType}); 174 return success(); 175 } 176 177 ParseResult ICmpOp::parse(OpAsmParser &parser, OperationState &result) { 178 return parseCmpOp<ICmpPredicate>(parser, result); 179 } 180 181 ParseResult FCmpOp::parse(OpAsmParser &parser, OperationState &result) { 182 return parseCmpOp<FCmpPredicate>(parser, result); 183 } 184 185 //===----------------------------------------------------------------------===// 186 // Printing, parsing and 
verification for LLVM::AllocaOp. 187 //===----------------------------------------------------------------------===// 188 189 void AllocaOp::print(OpAsmPrinter &p) { 190 Type elemTy = getType().cast<LLVM::LLVMPointerType>().getElementType(); 191 if (!elemTy) 192 elemTy = *getElemType(); 193 194 auto funcTy = 195 FunctionType::get(getContext(), {getArraySize().getType()}, {getType()}); 196 197 p << ' ' << getArraySize() << " x " << elemTy; 198 if (getAlignment().hasValue() && *getAlignment() != 0) 199 p.printOptionalAttrDict((*this)->getAttrs(), {kElemTypeAttrName}); 200 else 201 p.printOptionalAttrDict((*this)->getAttrs(), 202 {"alignment", kElemTypeAttrName}); 203 p << " : " << funcTy; 204 } 205 206 // <operation> ::= `llvm.alloca` ssa-use `x` type attribute-dict? 207 // `:` type `,` type 208 ParseResult AllocaOp::parse(OpAsmParser &parser, OperationState &result) { 209 OpAsmParser::UnresolvedOperand arraySize; 210 Type type, elemType; 211 SMLoc trailingTypeLoc; 212 if (parser.parseOperand(arraySize) || parser.parseKeyword("x") || 213 parser.parseType(elemType) || 214 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 215 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type)) 216 return failure(); 217 218 Optional<NamedAttribute> alignmentAttr = 219 result.attributes.getNamed("alignment"); 220 if (alignmentAttr.hasValue()) { 221 auto alignmentInt = 222 alignmentAttr.getValue().getValue().dyn_cast<IntegerAttr>(); 223 if (!alignmentInt) 224 return parser.emitError(parser.getNameLoc(), 225 "expected integer alignment"); 226 if (alignmentInt.getValue().isNullValue()) 227 result.attributes.erase("alignment"); 228 } 229 230 // Extract the result type from the trailing function type. 
231 auto funcType = type.dyn_cast<FunctionType>(); 232 if (!funcType || funcType.getNumInputs() != 1 || 233 funcType.getNumResults() != 1) 234 return parser.emitError( 235 trailingTypeLoc, 236 "expected trailing function type with one argument and one result"); 237 238 if (parser.resolveOperand(arraySize, funcType.getInput(0), result.operands)) 239 return failure(); 240 241 Type resultType = funcType.getResult(0); 242 if (auto ptrResultType = resultType.dyn_cast<LLVMPointerType>()) { 243 if (ptrResultType.isOpaque()) 244 result.addAttribute(kElemTypeAttrName, TypeAttr::get(elemType)); 245 } 246 247 result.addTypes({funcType.getResult(0)}); 248 return success(); 249 } 250 251 /// Checks that the elemental type is present in either the pointer type or 252 /// the attribute, but not both. 253 static LogicalResult verifyOpaquePtr(Operation *op, LLVMPointerType ptrType, 254 Optional<Type> ptrElementType) { 255 if (ptrType.isOpaque() && !ptrElementType.hasValue()) { 256 return op->emitOpError() << "expected '" << kElemTypeAttrName 257 << "' attribute if opaque pointer type is used"; 258 } 259 if (!ptrType.isOpaque() && ptrElementType.hasValue()) { 260 return op->emitOpError() 261 << "unexpected '" << kElemTypeAttrName 262 << "' attribute when non-opaque pointer type is used"; 263 } 264 return success(); 265 } 266 267 LogicalResult AllocaOp::verify() { 268 return verifyOpaquePtr(getOperation(), getType().cast<LLVMPointerType>(), 269 getElemType()); 270 } 271 272 //===----------------------------------------------------------------------===// 273 // LLVM::BrOp 274 //===----------------------------------------------------------------------===// 275 276 SuccessorOperands BrOp::getSuccessorOperands(unsigned index) { 277 assert(index == 0 && "invalid successor index"); 278 return SuccessorOperands(getDestOperandsMutable()); 279 } 280 281 //===----------------------------------------------------------------------===// 282 // LLVM::CondBrOp 283 
//===----------------------------------------------------------------------===// 284 285 SuccessorOperands CondBrOp::getSuccessorOperands(unsigned index) { 286 assert(index < getNumSuccessors() && "invalid successor index"); 287 return SuccessorOperands(index == 0 ? getTrueDestOperandsMutable() 288 : getFalseDestOperandsMutable()); 289 } 290 291 //===----------------------------------------------------------------------===// 292 // LLVM::SwitchOp 293 //===----------------------------------------------------------------------===// 294 295 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value, 296 Block *defaultDestination, ValueRange defaultOperands, 297 ArrayRef<int32_t> caseValues, BlockRange caseDestinations, 298 ArrayRef<ValueRange> caseOperands, 299 ArrayRef<int32_t> branchWeights) { 300 ElementsAttr caseValuesAttr; 301 if (!caseValues.empty()) 302 caseValuesAttr = builder.getI32VectorAttr(caseValues); 303 304 ElementsAttr weightsAttr; 305 if (!branchWeights.empty()) 306 weightsAttr = builder.getI32VectorAttr(llvm::to_vector<4>(branchWeights)); 307 308 build(builder, result, value, defaultOperands, caseOperands, caseValuesAttr, 309 weightsAttr, defaultDestination, caseDestinations); 310 } 311 312 /// <cases> ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)? 313 /// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )? 
314 static ParseResult parseSwitchOpCases( 315 OpAsmParser &parser, Type flagType, ElementsAttr &caseValues, 316 SmallVectorImpl<Block *> &caseDestinations, 317 SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> &caseOperands, 318 SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) { 319 SmallVector<APInt> values; 320 unsigned bitWidth = flagType.getIntOrFloatBitWidth(); 321 do { 322 int64_t value = 0; 323 OptionalParseResult integerParseResult = parser.parseOptionalInteger(value); 324 if (values.empty() && !integerParseResult.hasValue()) 325 return success(); 326 327 if (!integerParseResult.hasValue() || integerParseResult.getValue()) 328 return failure(); 329 values.push_back(APInt(bitWidth, value)); 330 331 Block *destination; 332 SmallVector<OpAsmParser::UnresolvedOperand> operands; 333 SmallVector<Type> operandTypes; 334 if (parser.parseColon() || parser.parseSuccessor(destination)) 335 return failure(); 336 if (!parser.parseOptionalLParen()) { 337 if (parser.parseOperandList(operands, OpAsmParser::Delimiter::None, 338 /*allowResultNumber=*/false) || 339 parser.parseColonTypeList(operandTypes) || parser.parseRParen()) 340 return failure(); 341 } 342 caseDestinations.push_back(destination); 343 caseOperands.emplace_back(operands); 344 caseOperandTypes.emplace_back(operandTypes); 345 } while (!parser.parseOptionalComma()); 346 347 ShapedType caseValueType = 348 VectorType::get(static_cast<int64_t>(values.size()), flagType); 349 caseValues = DenseIntElementsAttr::get(caseValueType, values); 350 return success(); 351 } 352 353 static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type flagType, 354 ElementsAttr caseValues, 355 SuccessorRange caseDestinations, 356 OperandRangeRange caseOperands, 357 const TypeRangeRange &caseOperandTypes) { 358 if (!caseValues) 359 return; 360 361 size_t index = 0; 362 llvm::interleave( 363 llvm::zip(caseValues.cast<DenseIntElementsAttr>(), caseDestinations), 364 [&](auto i) { 365 p << " "; 366 p << 
std::get<0>(i).getLimitedValue(); 367 p << ": "; 368 p.printSuccessorAndUseList(std::get<1>(i), caseOperands[index++]); 369 }, 370 [&] { 371 p << ','; 372 p.printNewline(); 373 }); 374 p.printNewline(); 375 } 376 377 LogicalResult SwitchOp::verify() { 378 if ((!getCaseValues() && !getCaseDestinations().empty()) || 379 (getCaseValues() && 380 getCaseValues()->size() != 381 static_cast<int64_t>(getCaseDestinations().size()))) 382 return emitOpError("expects number of case values to match number of " 383 "case destinations"); 384 if (getBranchWeights() && getBranchWeights()->size() != getNumSuccessors()) 385 return emitError("expects number of branch weights to match number of " 386 "successors: ") 387 << getBranchWeights()->size() << " vs " << getNumSuccessors(); 388 return success(); 389 } 390 391 SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) { 392 assert(index < getNumSuccessors() && "invalid successor index"); 393 return SuccessorOperands(index == 0 ? getDefaultOperandsMutable() 394 : getCaseOperandsMutable(index - 1)); 395 } 396 397 //===----------------------------------------------------------------------===// 398 // Code for LLVM::GEPOp. 399 //===----------------------------------------------------------------------===// 400 401 constexpr int GEPOp::kDynamicIndex; 402 403 namespace { 404 /// Base class for llvm::Error related to GEP index. 405 class GEPIndexError : public llvm::ErrorInfo<GEPIndexError> { 406 protected: 407 unsigned indexPos; 408 409 public: 410 static char ID; 411 412 std::error_code convertToErrorCode() const override { 413 return llvm::inconvertibleErrorCode(); 414 } 415 416 explicit GEPIndexError(unsigned pos) : indexPos(pos) {} 417 }; 418 419 /// llvm::Error for out-of-bound GEP index. 
420 struct GEPIndexOutOfBoundError 421 : public llvm::ErrorInfo<GEPIndexOutOfBoundError, GEPIndexError> { 422 static char ID; 423 424 using ErrorInfo::ErrorInfo; 425 426 void log(llvm::raw_ostream &os) const override { 427 os << "index " << indexPos << " indexing a struct is out of bounds"; 428 } 429 }; 430 431 /// llvm::Error for non-static GEP index indexing a struct. 432 struct GEPStaticIndexError 433 : public llvm::ErrorInfo<GEPStaticIndexError, GEPIndexError> { 434 static char ID; 435 436 using ErrorInfo::ErrorInfo; 437 438 void log(llvm::raw_ostream &os) const override { 439 os << "expected index " << indexPos << " indexing a struct " 440 << "to be constant"; 441 } 442 }; 443 } // end anonymous namespace 444 445 char GEPIndexError::ID = 0; 446 char GEPIndexOutOfBoundError::ID = 0; 447 char GEPStaticIndexError::ID = 0; 448 449 /// For the given `structIndices` and `indices`, check if they're complied 450 /// with `baseGEPType`, especially check against LLVMStructTypes nested within, 451 /// and refine/promote struct index from `indices` to `updatedStructIndices` 452 /// if the latter argument is not null. 453 static llvm::Error 454 recordStructIndices(Type baseGEPType, unsigned indexPos, 455 ArrayRef<int32_t> structIndices, ValueRange indices, 456 SmallVectorImpl<int32_t> *updatedStructIndices, 457 SmallVectorImpl<Value> *remainingIndices) { 458 if (indexPos >= structIndices.size()) 459 // Stop searching 460 return llvm::Error::success(); 461 462 int32_t gepIndex = structIndices[indexPos]; 463 bool isStaticIndex = gepIndex != GEPOp::kDynamicIndex; 464 465 unsigned dynamicIndexPos = indexPos; 466 if (!isStaticIndex) 467 dynamicIndexPos = llvm::count(structIndices.take_front(indexPos + 1), 468 LLVM::GEPOp::kDynamicIndex) - 1; 469 470 return llvm::TypeSwitch<Type, llvm::Error>(baseGEPType) 471 .Case<LLVMStructType>([&](LLVMStructType structType) -> llvm::Error { 472 // We don't always want to refine the index (e.g. 
when performing 473 // verification), so we only refine when updatedStructIndices is not 474 // null. 475 if (!isStaticIndex && updatedStructIndices) { 476 // Try to refine. 477 APInt staticIndexValue; 478 isStaticIndex = matchPattern(indices[dynamicIndexPos], 479 m_ConstantInt(&staticIndexValue)); 480 if (isStaticIndex) { 481 assert(staticIndexValue.getBitWidth() <= 64 && 482 llvm::isInt<32>(staticIndexValue.getLimitedValue()) && 483 "struct index can't fit within int32_t"); 484 gepIndex = static_cast<int32_t>(staticIndexValue.getSExtValue()); 485 } 486 } 487 if (!isStaticIndex) 488 return llvm::make_error<GEPStaticIndexError>(indexPos); 489 490 ArrayRef<Type> elementTypes = structType.getBody(); 491 if (gepIndex < 0 || 492 static_cast<size_t>(gepIndex) >= elementTypes.size()) 493 return llvm::make_error<GEPIndexOutOfBoundError>(indexPos); 494 495 if (updatedStructIndices) 496 (*updatedStructIndices)[indexPos] = gepIndex; 497 498 // Instead of recusively going into every children types, we only 499 // dive into the one indexed by gepIndex. 500 return recordStructIndices(elementTypes[gepIndex], indexPos + 1, 501 structIndices, indices, updatedStructIndices, 502 remainingIndices); 503 }) 504 .Case<VectorType, LLVMScalableVectorType, LLVMFixedVectorType, 505 LLVMArrayType>([&](auto containerType) -> llvm::Error { 506 // Currently we don't refine non-struct index even if it's static. 507 if (remainingIndices) 508 remainingIndices->push_back(indices[dynamicIndexPos]); 509 return recordStructIndices(containerType.getElementType(), indexPos + 1, 510 structIndices, indices, updatedStructIndices, 511 remainingIndices); 512 }) 513 .Default( 514 [](auto otherType) -> llvm::Error { return llvm::Error::success(); }); 515 } 516 517 /// Driver function around `recordStructIndices`. Note that we always check 518 /// from the second GEP index since the first one is always dynamic. 
519 static llvm::Error 520 findStructIndices(Type baseGEPType, ArrayRef<int32_t> structIndices, 521 ValueRange indices, 522 SmallVectorImpl<int32_t> *updatedStructIndices = nullptr, 523 SmallVectorImpl<Value> *remainingIndices = nullptr) { 524 if (remainingIndices) 525 // The first GEP index is always dynamic. 526 remainingIndices->push_back(indices[0]); 527 return recordStructIndices(baseGEPType, /*indexPos=*/1, structIndices, 528 indices, updatedStructIndices, remainingIndices); 529 } 530 531 void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType, 532 Value basePtr, ValueRange operands, 533 ArrayRef<NamedAttribute> attributes) { 534 build(builder, result, resultType, basePtr, operands, 535 SmallVector<int32_t>(operands.size(), kDynamicIndex), attributes); 536 } 537 538 /// Returns the elemental type of any LLVM-compatible vector type or self. 539 static Type extractVectorElementType(Type type) { 540 if (auto vectorType = type.dyn_cast<VectorType>()) 541 return vectorType.getElementType(); 542 if (auto scalableVectorType = type.dyn_cast<LLVMScalableVectorType>()) 543 return scalableVectorType.getElementType(); 544 if (auto fixedVectorType = type.dyn_cast<LLVMFixedVectorType>()) 545 return fixedVectorType.getElementType(); 546 return type; 547 } 548 549 void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType, 550 Value basePtr, ValueRange indices, 551 ArrayRef<int32_t> structIndices, 552 ArrayRef<NamedAttribute> attributes) { 553 auto ptrType = 554 extractVectorElementType(basePtr.getType()).cast<LLVMPointerType>(); 555 assert(!ptrType.isOpaque() && 556 "expected non-opaque pointer, provide elementType explicitly when " 557 "opaque pointers are used"); 558 build(builder, result, resultType, ptrType.getElementType(), basePtr, indices, 559 structIndices, attributes); 560 } 561 562 void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType, 563 Type elementType, Value basePtr, ValueRange indices, 564 
ArrayRef<int32_t> structIndices, 565 ArrayRef<NamedAttribute> attributes) { 566 SmallVector<Value> remainingIndices; 567 SmallVector<int32_t> updatedStructIndices(structIndices.begin(), 568 structIndices.end()); 569 if (llvm::Error err = 570 findStructIndices(elementType, structIndices, indices, 571 &updatedStructIndices, &remainingIndices)) 572 llvm::report_fatal_error(StringRef(llvm::toString(std::move(err)))); 573 574 assert(remainingIndices.size() == static_cast<size_t>(llvm::count( 575 updatedStructIndices, kDynamicIndex)) && 576 "expected as many index operands as dynamic index attr elements"); 577 578 result.addTypes(resultType); 579 result.addAttributes(attributes); 580 result.addAttribute("structIndices", 581 builder.getI32TensorAttr(updatedStructIndices)); 582 if (extractVectorElementType(basePtr.getType()) 583 .cast<LLVMPointerType>() 584 .isOpaque()) 585 result.addAttribute(kElemTypeAttrName, TypeAttr::get(elementType)); 586 result.addOperands(basePtr); 587 result.addOperands(remainingIndices); 588 } 589 590 static ParseResult 591 parseGEPIndices(OpAsmParser &parser, 592 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &indices, 593 DenseIntElementsAttr &structIndices) { 594 SmallVector<int32_t> constantIndices; 595 596 auto idxParser = [&]() -> ParseResult { 597 int32_t constantIndex; 598 OptionalParseResult parsedInteger = 599 parser.parseOptionalInteger(constantIndex); 600 if (parsedInteger.hasValue()) { 601 if (failed(parsedInteger.getValue())) 602 return failure(); 603 constantIndices.push_back(constantIndex); 604 return success(); 605 } 606 607 constantIndices.push_back(LLVM::GEPOp::kDynamicIndex); 608 return parser.parseOperand(indices.emplace_back()); 609 }; 610 if (parser.parseCommaSeparatedList(idxParser)) 611 return failure(); 612 613 structIndices = parser.getBuilder().getI32TensorAttr(constantIndices); 614 return success(); 615 } 616 617 static void printGEPIndices(OpAsmPrinter &printer, LLVM::GEPOp gepOp, 618 OperandRange indices, 619 
DenseIntElementsAttr structIndices) { 620 unsigned operandIdx = 0; 621 llvm::interleaveComma(structIndices.getValues<int32_t>(), printer, 622 [&](int32_t cst) { 623 if (cst == LLVM::GEPOp::kDynamicIndex) 624 printer.printOperand(indices[operandIdx++]); 625 else 626 printer << cst; 627 }); 628 } 629 630 LogicalResult LLVM::GEPOp::verify() { 631 if (failed(verifyOpaquePtr( 632 getOperation(), 633 extractVectorElementType(getType()).cast<LLVMPointerType>(), 634 getElemType()))) 635 return failure(); 636 637 auto structIndexRange = getStructIndices().getValues<int32_t>(); 638 // structIndexRange is a kind of iterator, which cannot be converted 639 // to ArrayRef directly. 640 SmallVector<int32_t> structIndices(structIndexRange.size()); 641 for (unsigned i : llvm::seq<unsigned>(0, structIndexRange.size())) 642 structIndices[i] = structIndexRange[i]; 643 if (llvm::Error err = findStructIndices(getSourceElementType(), structIndices, 644 getIndices())) 645 return emitOpError() << llvm::toString(std::move(err)); 646 647 return success(); 648 } 649 650 Type LLVM::GEPOp::getSourceElementType() { 651 if (Optional<Type> elemType = getElemType()) 652 return *elemType; 653 654 return extractVectorElementType(getBase().getType()) 655 .cast<LLVMPointerType>() 656 .getElementType(); 657 } 658 659 //===----------------------------------------------------------------------===// 660 // Builder, printer and parser for for LLVM::LoadOp. 661 //===----------------------------------------------------------------------===// 662 663 LogicalResult verifySymbolAttribute( 664 Operation *op, StringRef attributeName, 665 llvm::function_ref<LogicalResult(Operation *, SymbolRefAttr)> 666 verifySymbolType) { 667 if (Attribute attribute = op->getAttr(attributeName)) { 668 // The attribute is already verified to be a symbol ref array attribute via 669 // a constraint in the operation definition. 
670 for (SymbolRefAttr symbolRef : 671 attribute.cast<ArrayAttr>().getAsRange<SymbolRefAttr>()) { 672 StringAttr metadataName = symbolRef.getRootReference(); 673 StringAttr symbolName = symbolRef.getLeafReference(); 674 // We want @metadata::@symbol, not just @symbol 675 if (metadataName == symbolName) { 676 return op->emitOpError() << "expected '" << symbolRef 677 << "' to specify a fully qualified reference"; 678 } 679 auto metadataOp = SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>( 680 op->getParentOp(), metadataName); 681 if (!metadataOp) 682 return op->emitOpError() 683 << "expected '" << symbolRef << "' to reference a metadata op"; 684 Operation *symbolOp = 685 SymbolTable::lookupNearestSymbolFrom(metadataOp, symbolName); 686 if (!symbolOp) 687 return op->emitOpError() 688 << "expected '" << symbolRef << "' to be a valid reference"; 689 if (failed(verifySymbolType(symbolOp, symbolRef))) { 690 return failure(); 691 } 692 } 693 } 694 return success(); 695 } 696 697 // Verifies that metadata ops are wired up properly. 
698 template <typename OpTy> 699 static LogicalResult verifyOpMetadata(Operation *op, StringRef attributeName) { 700 auto verifySymbolType = [op](Operation *symbolOp, 701 SymbolRefAttr symbolRef) -> LogicalResult { 702 if (!isa<OpTy>(symbolOp)) { 703 return op->emitOpError() 704 << "expected '" << symbolRef << "' to resolve to a " 705 << OpTy::getOperationName(); 706 } 707 return success(); 708 }; 709 710 return verifySymbolAttribute(op, attributeName, verifySymbolType); 711 } 712 713 static LogicalResult verifyMemoryOpMetadata(Operation *op) { 714 // access_groups 715 if (failed(verifyOpMetadata<LLVM::AccessGroupMetadataOp>( 716 op, LLVMDialect::getAccessGroupsAttrName()))) 717 return failure(); 718 719 // alias_scopes 720 if (failed(verifyOpMetadata<LLVM::AliasScopeMetadataOp>( 721 op, LLVMDialect::getAliasScopesAttrName()))) 722 return failure(); 723 724 // noalias_scopes 725 if (failed(verifyOpMetadata<LLVM::AliasScopeMetadataOp>( 726 op, LLVMDialect::getNoAliasScopesAttrName()))) 727 return failure(); 728 729 return success(); 730 } 731 732 LogicalResult LoadOp::verify() { return verifyMemoryOpMetadata(*this); } 733 734 void LoadOp::build(OpBuilder &builder, OperationState &result, Type t, 735 Value addr, unsigned alignment, bool isVolatile, 736 bool isNonTemporal) { 737 result.addOperands(addr); 738 result.addTypes(t); 739 if (isVolatile) 740 result.addAttribute(kVolatileAttrName, builder.getUnitAttr()); 741 if (isNonTemporal) 742 result.addAttribute(kNonTemporalAttrName, builder.getUnitAttr()); 743 if (alignment != 0) 744 result.addAttribute("alignment", builder.getI64IntegerAttr(alignment)); 745 } 746 747 void LoadOp::print(OpAsmPrinter &p) { 748 p << ' '; 749 if (getVolatile_()) 750 p << "volatile "; 751 p << getAddr(); 752 p.printOptionalAttrDict((*this)->getAttrs(), 753 {kVolatileAttrName, kElemTypeAttrName}); 754 p << " : " << getAddr().getType(); 755 if (getAddr().getType().cast<LLVMPointerType>().isOpaque()) 756 p << " -> " << getType(); 757 } 758 759 
// Extract the pointee type from the LLVM pointer type wrapped in MLIR. Return 760 // the resulting type if any, null type if opaque pointers are used, and None 761 // if the given type is not the pointer type. 762 static Optional<Type> getLoadStoreElementType(OpAsmParser &parser, Type type, 763 SMLoc trailingTypeLoc) { 764 auto llvmTy = type.dyn_cast<LLVM::LLVMPointerType>(); 765 if (!llvmTy) { 766 parser.emitError(trailingTypeLoc, "expected LLVM pointer type"); 767 return llvm::None; 768 } 769 return llvmTy.getElementType(); 770 } 771 772 // <operation> ::= `llvm.load` `volatile` ssa-use attribute-dict? `:` type 773 // (`->` type)? 774 ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) { 775 OpAsmParser::UnresolvedOperand addr; 776 Type type; 777 SMLoc trailingTypeLoc; 778 779 if (succeeded(parser.parseOptionalKeyword("volatile"))) 780 result.addAttribute(kVolatileAttrName, parser.getBuilder().getUnitAttr()); 781 782 if (parser.parseOperand(addr) || 783 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 784 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) || 785 parser.resolveOperand(addr, type, result.operands)) 786 return failure(); 787 788 Optional<Type> elemTy = 789 getLoadStoreElementType(parser, type, trailingTypeLoc); 790 if (!elemTy) 791 return failure(); 792 if (*elemTy) { 793 result.addTypes(*elemTy); 794 return success(); 795 } 796 797 Type trailingType; 798 if (parser.parseArrow() || parser.parseType(trailingType)) 799 return failure(); 800 result.addTypes(trailingType); 801 return success(); 802 } 803 804 //===----------------------------------------------------------------------===// 805 // Builder, printer and parser for LLVM::StoreOp. 
806 //===----------------------------------------------------------------------===// 807 808 LogicalResult StoreOp::verify() { return verifyMemoryOpMetadata(*this); } 809 810 void StoreOp::build(OpBuilder &builder, OperationState &result, Value value, 811 Value addr, unsigned alignment, bool isVolatile, 812 bool isNonTemporal) { 813 result.addOperands({value, addr}); 814 result.addTypes({}); 815 if (isVolatile) 816 result.addAttribute(kVolatileAttrName, builder.getUnitAttr()); 817 if (isNonTemporal) 818 result.addAttribute(kNonTemporalAttrName, builder.getUnitAttr()); 819 if (alignment != 0) 820 result.addAttribute("alignment", builder.getI64IntegerAttr(alignment)); 821 } 822 823 void StoreOp::print(OpAsmPrinter &p) { 824 p << ' '; 825 if (getVolatile_()) 826 p << "volatile "; 827 p << getValue() << ", " << getAddr(); 828 p.printOptionalAttrDict((*this)->getAttrs(), {kVolatileAttrName}); 829 p << " : "; 830 if (getAddr().getType().cast<LLVMPointerType>().isOpaque()) 831 p << getValue().getType() << ", "; 832 p << getAddr().getType(); 833 } 834 835 // <operation> ::= `llvm.store` `volatile` ssa-use `,` ssa-use 836 // attribute-dict? `:` type (`,` type)? 
837 ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) { 838 OpAsmParser::UnresolvedOperand addr, value; 839 Type type; 840 SMLoc trailingTypeLoc; 841 842 if (succeeded(parser.parseOptionalKeyword("volatile"))) 843 result.addAttribute(kVolatileAttrName, parser.getBuilder().getUnitAttr()); 844 845 if (parser.parseOperand(value) || parser.parseComma() || 846 parser.parseOperand(addr) || 847 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 848 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type)) 849 return failure(); 850 851 Type operandType; 852 if (succeeded(parser.parseOptionalComma())) { 853 operandType = type; 854 if (parser.parseType(type)) 855 return failure(); 856 } else { 857 Optional<Type> maybeOperandType = 858 getLoadStoreElementType(parser, type, trailingTypeLoc); 859 if (!maybeOperandType) 860 return failure(); 861 operandType = *maybeOperandType; 862 } 863 864 if (parser.resolveOperand(value, operandType, result.operands) || 865 parser.resolveOperand(addr, type, result.operands)) 866 return failure(); 867 868 return success(); 869 } 870 871 ///===---------------------------------------------------------------------===// 872 /// LLVM::InvokeOp 873 ///===---------------------------------------------------------------------===// 874 875 SuccessorOperands InvokeOp::getSuccessorOperands(unsigned index) { 876 assert(index < getNumSuccessors() && "invalid successor index"); 877 return SuccessorOperands(index == 0 ? 
getNormalDestOperandsMutable() 878 : getUnwindDestOperandsMutable()); 879 } 880 881 LogicalResult InvokeOp::verify() { 882 if (getNumResults() > 1) 883 return emitOpError("must have 0 or 1 result"); 884 885 Block *unwindDest = getUnwindDest(); 886 if (unwindDest->empty()) 887 return emitError("must have at least one operation in unwind destination"); 888 889 // In unwind destination, first operation must be LandingpadOp 890 if (!isa<LandingpadOp>(unwindDest->front())) 891 return emitError("first operation in unwind destination should be a " 892 "llvm.landingpad operation"); 893 894 return success(); 895 } 896 897 void InvokeOp::print(OpAsmPrinter &p) { 898 auto callee = getCallee(); 899 bool isDirect = callee.hasValue(); 900 901 p << ' '; 902 903 // Either function name or pointer 904 if (isDirect) 905 p.printSymbolName(callee.getValue()); 906 else 907 p << getOperand(0); 908 909 p << '(' << getOperands().drop_front(isDirect ? 0 : 1) << ')'; 910 p << " to "; 911 p.printSuccessorAndUseList(getNormalDest(), getNormalDestOperands()); 912 p << " unwind "; 913 p.printSuccessorAndUseList(getUnwindDest(), getUnwindDestOperands()); 914 915 p.printOptionalAttrDict((*this)->getAttrs(), 916 {InvokeOp::getOperandSegmentSizeAttr(), "callee"}); 917 p << " : "; 918 p.printFunctionalType(llvm::drop_begin(getOperandTypes(), isDirect ? 0 : 1), 919 getResultTypes()); 920 } 921 922 /// <operation> ::= `llvm.invoke` (function-id | ssa-use) `(` ssa-use-list `)` 923 /// `to` bb-id (`[` ssa-use-and-type-list `]`)? 924 /// `unwind` bb-id (`[` ssa-use-and-type-list `]`)? 925 /// attribute-dict? 
`:` function-type 926 ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) { 927 SmallVector<OpAsmParser::UnresolvedOperand, 8> operands; 928 FunctionType funcType; 929 SymbolRefAttr funcAttr; 930 SMLoc trailingTypeLoc; 931 Block *normalDest, *unwindDest; 932 SmallVector<Value, 4> normalOperands, unwindOperands; 933 Builder &builder = parser.getBuilder(); 934 935 // Parse an operand list that will, in practice, contain 0 or 1 operand. In 936 // case of an indirect call, there will be 1 operand before `(`. In case of a 937 // direct call, there will be no operands and the parser will stop at the 938 // function identifier without complaining. 939 if (parser.parseOperandList(operands)) 940 return failure(); 941 bool isDirect = operands.empty(); 942 943 // Optionally parse a function identifier. 944 if (isDirect && parser.parseAttribute(funcAttr, "callee", result.attributes)) 945 return failure(); 946 947 if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || 948 parser.parseKeyword("to") || 949 parser.parseSuccessorAndUseList(normalDest, normalOperands) || 950 parser.parseKeyword("unwind") || 951 parser.parseSuccessorAndUseList(unwindDest, unwindOperands) || 952 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 953 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(funcType)) 954 return failure(); 955 956 if (isDirect) { 957 // Make sure types match. 958 if (parser.resolveOperands(operands, funcType.getInputs(), 959 parser.getNameLoc(), result.operands)) 960 return failure(); 961 result.addTypes(funcType.getResults()); 962 } else { 963 // Construct the LLVM IR Dialect function type that the first operand 964 // should match. 
965 if (funcType.getNumResults() > 1) 966 return parser.emitError(trailingTypeLoc, 967 "expected function with 0 or 1 result"); 968 969 Type llvmResultType; 970 if (funcType.getNumResults() == 0) { 971 llvmResultType = LLVM::LLVMVoidType::get(builder.getContext()); 972 } else { 973 llvmResultType = funcType.getResult(0); 974 if (!isCompatibleType(llvmResultType)) 975 return parser.emitError(trailingTypeLoc, 976 "expected result to have LLVM type"); 977 } 978 979 SmallVector<Type, 8> argTypes; 980 argTypes.reserve(funcType.getNumInputs()); 981 for (Type ty : funcType.getInputs()) { 982 if (isCompatibleType(ty)) 983 argTypes.push_back(ty); 984 else 985 return parser.emitError(trailingTypeLoc, 986 "expected LLVM types as inputs"); 987 } 988 989 auto llvmFuncType = LLVM::LLVMFunctionType::get(llvmResultType, argTypes); 990 auto wrappedFuncType = LLVM::LLVMPointerType::get(llvmFuncType); 991 992 auto funcArguments = llvm::makeArrayRef(operands).drop_front(); 993 994 // Make sure that the first operand (indirect callee) matches the wrapped 995 // LLVM IR function type, and that the types of the other call operands 996 // match the types of the function arguments. 
997 if (parser.resolveOperand(operands[0], wrappedFuncType, result.operands) || 998 parser.resolveOperands(funcArguments, funcType.getInputs(), 999 parser.getNameLoc(), result.operands)) 1000 return failure(); 1001 1002 result.addTypes(llvmResultType); 1003 } 1004 result.addSuccessors({normalDest, unwindDest}); 1005 result.addOperands(normalOperands); 1006 result.addOperands(unwindOperands); 1007 1008 result.addAttribute( 1009 InvokeOp::getOperandSegmentSizeAttr(), 1010 builder.getI32VectorAttr({static_cast<int32_t>(operands.size()), 1011 static_cast<int32_t>(normalOperands.size()), 1012 static_cast<int32_t>(unwindOperands.size())})); 1013 return success(); 1014 } 1015 1016 ///===----------------------------------------------------------------------===// 1017 /// Verifying/Printing/Parsing for LLVM::LandingpadOp. 1018 ///===----------------------------------------------------------------------===// 1019 1020 LogicalResult LandingpadOp::verify() { 1021 Value value; 1022 if (LLVMFuncOp func = (*this)->getParentOfType<LLVMFuncOp>()) { 1023 if (!func.getPersonality().hasValue()) 1024 return emitError( 1025 "llvm.landingpad needs to be in a function with a personality"); 1026 } 1027 1028 if (!getCleanup() && getOperands().empty()) 1029 return emitError("landingpad instruction expects at least one clause or " 1030 "cleanup attribute"); 1031 1032 for (unsigned idx = 0, ie = getNumOperands(); idx < ie; idx++) { 1033 value = getOperand(idx); 1034 bool isFilter = value.getType().isa<LLVMArrayType>(); 1035 if (isFilter) { 1036 // FIXME: Verify filter clauses when arrays are appropriately handled 1037 } else { 1038 // catch - global addresses only. 1039 // Bitcast ops should have global addresses as their args. 
1040 if (auto bcOp = value.getDefiningOp<BitcastOp>()) { 1041 if (auto addrOp = bcOp.getArg().getDefiningOp<AddressOfOp>()) 1042 continue; 1043 return emitError("constant clauses expected").attachNote(bcOp.getLoc()) 1044 << "global addresses expected as operand to " 1045 "bitcast used in clauses for landingpad"; 1046 } 1047 // NullOp and AddressOfOp allowed 1048 if (value.getDefiningOp<NullOp>()) 1049 continue; 1050 if (value.getDefiningOp<AddressOfOp>()) 1051 continue; 1052 return emitError("clause #") 1053 << idx << " is not a known constant - null, addressof, bitcast"; 1054 } 1055 } 1056 return success(); 1057 } 1058 1059 void LandingpadOp::print(OpAsmPrinter &p) { 1060 p << (getCleanup() ? " cleanup " : " "); 1061 1062 // Clauses 1063 for (auto value : getOperands()) { 1064 // Similar to llvm - if clause is an array type then it is filter 1065 // clause else catch clause 1066 bool isArrayTy = value.getType().isa<LLVMArrayType>(); 1067 p << '(' << (isArrayTy ? "filter " : "catch ") << value << " : " 1068 << value.getType() << ") "; 1069 } 1070 1071 p.printOptionalAttrDict((*this)->getAttrs(), {"cleanup"}); 1072 1073 p << ": " << getType(); 1074 } 1075 1076 /// <operation> ::= `llvm.landingpad` `cleanup`? 1077 /// ((`catch` | `filter`) operand-type ssa-use)* attribute-dict? 
1078 ParseResult LandingpadOp::parse(OpAsmParser &parser, OperationState &result) { 1079 // Check for cleanup 1080 if (succeeded(parser.parseOptionalKeyword("cleanup"))) 1081 result.addAttribute("cleanup", parser.getBuilder().getUnitAttr()); 1082 1083 // Parse clauses with types 1084 while (succeeded(parser.parseOptionalLParen()) && 1085 (succeeded(parser.parseOptionalKeyword("filter")) || 1086 succeeded(parser.parseOptionalKeyword("catch")))) { 1087 OpAsmParser::UnresolvedOperand operand; 1088 Type ty; 1089 if (parser.parseOperand(operand) || parser.parseColon() || 1090 parser.parseType(ty) || 1091 parser.resolveOperand(operand, ty, result.operands) || 1092 parser.parseRParen()) 1093 return failure(); 1094 } 1095 1096 Type type; 1097 if (parser.parseColon() || parser.parseType(type)) 1098 return failure(); 1099 1100 result.addTypes(type); 1101 return success(); 1102 } 1103 1104 //===----------------------------------------------------------------------===// 1105 // Verifying/Printing/parsing for LLVM::CallOp. 1106 //===----------------------------------------------------------------------===// 1107 1108 LogicalResult CallOp::verify() { 1109 if (getNumResults() > 1) 1110 return emitOpError("must have 0 or 1 result"); 1111 1112 // Type for the callee, we'll get it differently depending if it is a direct 1113 // or indirect call. 1114 Type fnType; 1115 1116 bool isIndirect = false; 1117 1118 // If this is an indirect call, the callee attribute is missing. 
1119 FlatSymbolRefAttr calleeName = getCalleeAttr(); 1120 if (!calleeName) { 1121 isIndirect = true; 1122 if (!getNumOperands()) 1123 return emitOpError( 1124 "must have either a `callee` attribute or at least an operand"); 1125 auto ptrType = getOperand(0).getType().dyn_cast<LLVMPointerType>(); 1126 if (!ptrType) 1127 return emitOpError("indirect call expects a pointer as callee: ") 1128 << ptrType; 1129 fnType = ptrType.getElementType(); 1130 } else { 1131 Operation *callee = 1132 SymbolTable::lookupNearestSymbolFrom(*this, calleeName.getAttr()); 1133 if (!callee) 1134 return emitOpError() 1135 << "'" << calleeName.getValue() 1136 << "' does not reference a symbol in the current scope"; 1137 auto fn = dyn_cast<LLVMFuncOp>(callee); 1138 if (!fn) 1139 return emitOpError() << "'" << calleeName.getValue() 1140 << "' does not reference a valid LLVM function"; 1141 1142 fnType = fn.getFunctionType(); 1143 } 1144 1145 LLVMFunctionType funcType = fnType.dyn_cast<LLVMFunctionType>(); 1146 if (!funcType) 1147 return emitOpError("callee does not have a functional type: ") << fnType; 1148 1149 // Verify that the operand and result types match the callee. 
1150 1151 if (!funcType.isVarArg() && 1152 funcType.getNumParams() != (getNumOperands() - isIndirect)) 1153 return emitOpError() << "incorrect number of operands (" 1154 << (getNumOperands() - isIndirect) 1155 << ") for callee (expecting: " 1156 << funcType.getNumParams() << ")"; 1157 1158 if (funcType.getNumParams() > (getNumOperands() - isIndirect)) 1159 return emitOpError() << "incorrect number of operands (" 1160 << (getNumOperands() - isIndirect) 1161 << ") for varargs callee (expecting at least: " 1162 << funcType.getNumParams() << ")"; 1163 1164 for (unsigned i = 0, e = funcType.getNumParams(); i != e; ++i) 1165 if (getOperand(i + isIndirect).getType() != funcType.getParamType(i)) 1166 return emitOpError() << "operand type mismatch for operand " << i << ": " 1167 << getOperand(i + isIndirect).getType() 1168 << " != " << funcType.getParamType(i); 1169 1170 if (getNumResults() == 0 && 1171 !funcType.getReturnType().isa<LLVM::LLVMVoidType>()) 1172 return emitOpError() << "expected function call to produce a value"; 1173 1174 if (getNumResults() != 0 && 1175 funcType.getReturnType().isa<LLVM::LLVMVoidType>()) 1176 return emitOpError() 1177 << "calling function with void result must not produce values"; 1178 1179 if (getNumResults() > 1) 1180 return emitOpError() 1181 << "expected LLVM function call to produce 0 or 1 result"; 1182 1183 if (getNumResults() && getResult(0).getType() != funcType.getReturnType()) 1184 return emitOpError() << "result type mismatch: " << getResult(0).getType() 1185 << " != " << funcType.getReturnType(); 1186 1187 return success(); 1188 } 1189 1190 void CallOp::print(OpAsmPrinter &p) { 1191 auto callee = getCallee(); 1192 bool isDirect = callee.hasValue(); 1193 1194 // Print the direct callee if present as a function attribute, or an indirect 1195 // callee (first operand) otherwise. 
1196 p << ' '; 1197 if (isDirect) 1198 p.printSymbolName(callee.getValue()); 1199 else 1200 p << getOperand(0); 1201 1202 auto args = getOperands().drop_front(isDirect ? 0 : 1); 1203 p << '(' << args << ')'; 1204 p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()), {"callee"}); 1205 1206 // Reconstruct the function MLIR function type from operand and result types. 1207 p << " : "; 1208 p.printFunctionalType(args.getTypes(), getResultTypes()); 1209 } 1210 1211 // <operation> ::= `llvm.call` (function-id | ssa-use) `(` ssa-use-list `)` 1212 // attribute-dict? `:` function-type 1213 ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) { 1214 SmallVector<OpAsmParser::UnresolvedOperand, 8> operands; 1215 Type type; 1216 SymbolRefAttr funcAttr; 1217 SMLoc trailingTypeLoc; 1218 1219 // Parse an operand list that will, in practice, contain 0 or 1 operand. In 1220 // case of an indirect call, there will be 1 operand before `(`. In case of a 1221 // direct call, there will be no operands and the parser will stop at the 1222 // function identifier without complaining. 1223 if (parser.parseOperandList(operands)) 1224 return failure(); 1225 bool isDirect = operands.empty(); 1226 1227 // Optionally parse a function identifier. 1228 if (isDirect) 1229 if (parser.parseAttribute(funcAttr, "callee", result.attributes)) 1230 return failure(); 1231 1232 if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || 1233 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 1234 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type)) 1235 return failure(); 1236 1237 auto funcType = type.dyn_cast<FunctionType>(); 1238 if (!funcType) 1239 return parser.emitError(trailingTypeLoc, "expected function type"); 1240 if (funcType.getNumResults() > 1) 1241 return parser.emitError(trailingTypeLoc, 1242 "expected function with 0 or 1 result"); 1243 if (isDirect) { 1244 // Make sure types match. 
1245 if (parser.resolveOperands(operands, funcType.getInputs(), 1246 parser.getNameLoc(), result.operands)) 1247 return failure(); 1248 if (funcType.getNumResults() != 0 && 1249 !funcType.getResult(0).isa<LLVM::LLVMVoidType>()) 1250 result.addTypes(funcType.getResults()); 1251 } else { 1252 Builder &builder = parser.getBuilder(); 1253 Type llvmResultType; 1254 if (funcType.getNumResults() == 0) { 1255 llvmResultType = LLVM::LLVMVoidType::get(builder.getContext()); 1256 } else { 1257 llvmResultType = funcType.getResult(0); 1258 if (!isCompatibleType(llvmResultType)) 1259 return parser.emitError(trailingTypeLoc, 1260 "expected result to have LLVM type"); 1261 } 1262 1263 SmallVector<Type, 8> argTypes; 1264 argTypes.reserve(funcType.getNumInputs()); 1265 for (int i = 0, e = funcType.getNumInputs(); i < e; ++i) { 1266 auto argType = funcType.getInput(i); 1267 if (!isCompatibleType(argType)) 1268 return parser.emitError(trailingTypeLoc, 1269 "expected LLVM types as inputs"); 1270 argTypes.push_back(argType); 1271 } 1272 auto llvmFuncType = LLVM::LLVMFunctionType::get(llvmResultType, argTypes); 1273 auto wrappedFuncType = LLVM::LLVMPointerType::get(llvmFuncType); 1274 1275 auto funcArguments = 1276 ArrayRef<OpAsmParser::UnresolvedOperand>(operands).drop_front(); 1277 1278 // Make sure that the first operand (indirect callee) matches the wrapped 1279 // LLVM IR function type, and that the types of the other call operands 1280 // match the types of the function arguments. 1281 if (parser.resolveOperand(operands[0], wrappedFuncType, result.operands) || 1282 parser.resolveOperands(funcArguments, funcType.getInputs(), 1283 parser.getNameLoc(), result.operands)) 1284 return failure(); 1285 1286 if (!llvmResultType.isa<LLVM::LLVMVoidType>()) 1287 result.addTypes(llvmResultType); 1288 } 1289 1290 return success(); 1291 } 1292 1293 //===----------------------------------------------------------------------===// 1294 // Printing/parsing for LLVM::ExtractElementOp. 
1295 //===----------------------------------------------------------------------===// 1296 // Expects vector to be of wrapped LLVM vector type and position to be of 1297 // wrapped LLVM i32 type. 1298 void LLVM::ExtractElementOp::build(OpBuilder &b, OperationState &result, 1299 Value vector, Value position, 1300 ArrayRef<NamedAttribute> attrs) { 1301 auto vectorType = vector.getType(); 1302 auto llvmType = LLVM::getVectorElementType(vectorType); 1303 build(b, result, llvmType, vector, position); 1304 result.addAttributes(attrs); 1305 } 1306 1307 void ExtractElementOp::print(OpAsmPrinter &p) { 1308 p << ' ' << getVector() << "[" << getPosition() << " : " 1309 << getPosition().getType() << "]"; 1310 p.printOptionalAttrDict((*this)->getAttrs()); 1311 p << " : " << getVector().getType(); 1312 } 1313 1314 // <operation> ::= `llvm.extractelement` ssa-use `, ` ssa-use 1315 // attribute-dict? `:` type 1316 ParseResult ExtractElementOp::parse(OpAsmParser &parser, 1317 OperationState &result) { 1318 SMLoc loc; 1319 OpAsmParser::UnresolvedOperand vector, position; 1320 Type type, positionType; 1321 if (parser.getCurrentLocation(&loc) || parser.parseOperand(vector) || 1322 parser.parseLSquare() || parser.parseOperand(position) || 1323 parser.parseColonType(positionType) || parser.parseRSquare() || 1324 parser.parseOptionalAttrDict(result.attributes) || 1325 parser.parseColonType(type) || 1326 parser.resolveOperand(vector, type, result.operands) || 1327 parser.resolveOperand(position, positionType, result.operands)) 1328 return failure(); 1329 if (!LLVM::isCompatibleVectorType(type)) 1330 return parser.emitError( 1331 loc, "expected LLVM dialect-compatible vector type for operand #1"); 1332 result.addTypes(LLVM::getVectorElementType(type)); 1333 return success(); 1334 } 1335 1336 LogicalResult ExtractElementOp::verify() { 1337 Type vectorType = getVector().getType(); 1338 if (!LLVM::isCompatibleVectorType(vectorType)) 1339 return emitOpError("expected LLVM dialect-compatible 
vector type for " 1340 "operand #1, got") 1341 << vectorType; 1342 Type valueType = LLVM::getVectorElementType(vectorType); 1343 if (valueType != getRes().getType()) 1344 return emitOpError() << "Type mismatch: extracting from " << vectorType 1345 << " should produce " << valueType 1346 << " but this op returns " << getRes().getType(); 1347 return success(); 1348 } 1349 1350 //===----------------------------------------------------------------------===// 1351 // Printing/parsing for LLVM::ExtractValueOp. 1352 //===----------------------------------------------------------------------===// 1353 1354 void ExtractValueOp::print(OpAsmPrinter &p) { 1355 p << ' ' << getContainer() << getPosition(); 1356 p.printOptionalAttrDict((*this)->getAttrs(), {"position"}); 1357 p << " : " << getContainer().getType(); 1358 } 1359 1360 // Extract the type at `position` in the wrapped LLVM IR aggregate type 1361 // `containerType`. Position is an integer array attribute where each value 1362 // is a zero-based position of the element in the aggregate type. Return the 1363 // resulting type wrapped in MLIR, or nullptr on error. 1364 static Type getInsertExtractValueElementType(OpAsmParser &parser, 1365 Type containerType, 1366 ArrayAttr positionAttr, 1367 SMLoc attributeLoc, 1368 SMLoc typeLoc) { 1369 Type llvmType = containerType; 1370 if (!isCompatibleType(containerType)) 1371 return parser.emitError(typeLoc, "expected LLVM IR Dialect type"), nullptr; 1372 1373 // Infer the element type from the structure type: iteratively step inside the 1374 // type by taking the element type, indexed by the position attribute for 1375 // structures. Check the position index before accessing, it is supposed to 1376 // be in bounds. 
1377 for (Attribute subAttr : positionAttr) { 1378 auto positionElementAttr = subAttr.dyn_cast<IntegerAttr>(); 1379 if (!positionElementAttr) 1380 return parser.emitError(attributeLoc, 1381 "expected an array of integer literals"), 1382 nullptr; 1383 int position = positionElementAttr.getInt(); 1384 if (auto arrayType = llvmType.dyn_cast<LLVMArrayType>()) { 1385 if (position < 0 || 1386 static_cast<unsigned>(position) >= arrayType.getNumElements()) 1387 return parser.emitError(attributeLoc, "position out of bounds"), 1388 nullptr; 1389 llvmType = arrayType.getElementType(); 1390 } else if (auto structType = llvmType.dyn_cast<LLVMStructType>()) { 1391 if (position < 0 || 1392 static_cast<unsigned>(position) >= structType.getBody().size()) 1393 return parser.emitError(attributeLoc, "position out of bounds"), 1394 nullptr; 1395 llvmType = structType.getBody()[position]; 1396 } else { 1397 return parser.emitError(typeLoc, "expected LLVM IR structure/array type"), 1398 nullptr; 1399 } 1400 } 1401 return llvmType; 1402 } 1403 1404 // Extract the type at `position` in the wrapped LLVM IR aggregate type 1405 // `containerType`. Returns null on failure. 1406 static Type getInsertExtractValueElementType(Type containerType, 1407 ArrayAttr positionAttr, 1408 Operation *op) { 1409 Type llvmType = containerType; 1410 if (!isCompatibleType(containerType)) { 1411 op->emitError("expected LLVM IR Dialect type, got ") << containerType; 1412 return {}; 1413 } 1414 1415 // Infer the element type from the structure type: iteratively step inside the 1416 // type by taking the element type, indexed by the position attribute for 1417 // structures. Check the position index before accessing, it is supposed to 1418 // be in bounds. 
1419 for (Attribute subAttr : positionAttr) { 1420 auto positionElementAttr = subAttr.dyn_cast<IntegerAttr>(); 1421 if (!positionElementAttr) { 1422 op->emitOpError("expected an array of integer literals, got: ") 1423 << subAttr; 1424 return {}; 1425 } 1426 int position = positionElementAttr.getInt(); 1427 if (auto arrayType = llvmType.dyn_cast<LLVMArrayType>()) { 1428 if (position < 0 || 1429 static_cast<unsigned>(position) >= arrayType.getNumElements()) { 1430 op->emitOpError("position out of bounds: ") << position; 1431 return {}; 1432 } 1433 llvmType = arrayType.getElementType(); 1434 } else if (auto structType = llvmType.dyn_cast<LLVMStructType>()) { 1435 if (position < 0 || 1436 static_cast<unsigned>(position) >= structType.getBody().size()) { 1437 op->emitOpError("position out of bounds") << position; 1438 return {}; 1439 } 1440 llvmType = structType.getBody()[position]; 1441 } else { 1442 op->emitOpError("expected LLVM IR structure/array type, got: ") 1443 << llvmType; 1444 return {}; 1445 } 1446 } 1447 return llvmType; 1448 } 1449 1450 // <operation> ::= `llvm.extractvalue` ssa-use 1451 // `[` integer-literal (`,` integer-literal)* `]` 1452 // attribute-dict? 
`:` type 1453 ParseResult ExtractValueOp::parse(OpAsmParser &parser, OperationState &result) { 1454 OpAsmParser::UnresolvedOperand container; 1455 Type containerType; 1456 ArrayAttr positionAttr; 1457 SMLoc attributeLoc, trailingTypeLoc; 1458 1459 if (parser.parseOperand(container) || 1460 parser.getCurrentLocation(&attributeLoc) || 1461 parser.parseAttribute(positionAttr, "position", result.attributes) || 1462 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 1463 parser.getCurrentLocation(&trailingTypeLoc) || 1464 parser.parseType(containerType) || 1465 parser.resolveOperand(container, containerType, result.operands)) 1466 return failure(); 1467 1468 auto elementType = getInsertExtractValueElementType( 1469 parser, containerType, positionAttr, attributeLoc, trailingTypeLoc); 1470 if (!elementType) 1471 return failure(); 1472 1473 result.addTypes(elementType); 1474 return success(); 1475 } 1476 1477 OpFoldResult LLVM::ExtractValueOp::fold(ArrayRef<Attribute> operands) { 1478 auto insertValueOp = getContainer().getDefiningOp<InsertValueOp>(); 1479 OpFoldResult result = {}; 1480 while (insertValueOp) { 1481 if (getPosition() == insertValueOp.getPosition()) 1482 return insertValueOp.getValue(); 1483 unsigned min = 1484 std::min(getPosition().size(), insertValueOp.getPosition().size()); 1485 // If one is fully prefix of the other, stop propagating back as it will 1486 // miss dependencies. 
For instance, %3 should not fold to %f0 in the 1487 // following example: 1488 // ``` 1489 // %1 = llvm.insertvalue %f0, %0[0, 0] : 1490 // !llvm.array<4 x !llvm.array<4xf32>> 1491 // %2 = llvm.insertvalue %arr, %1[0] : 1492 // !llvm.array<4 x !llvm.array<4xf32>> 1493 // %3 = llvm.extractvalue %2[0, 0] : !llvm.array<4 x !llvm.array<4xf32>> 1494 // ``` 1495 if (getPosition().getValue().take_front(min) == 1496 insertValueOp.getPosition().getValue().take_front(min)) 1497 return result; 1498 1499 // If neither a prefix, nor the exact position, we can extract out of the 1500 // value being inserted into. Moreover, we can try again if that operand 1501 // is itself an insertvalue expression. 1502 getContainerMutable().assign(insertValueOp.getContainer()); 1503 result = getResult(); 1504 insertValueOp = insertValueOp.getContainer().getDefiningOp<InsertValueOp>(); 1505 } 1506 return result; 1507 } 1508 1509 LogicalResult ExtractValueOp::verify() { 1510 Type valueType = getInsertExtractValueElementType(getContainer().getType(), 1511 getPositionAttr(), *this); 1512 if (!valueType) 1513 return failure(); 1514 1515 if (getRes().getType() != valueType) 1516 return emitOpError() << "Type mismatch: extracting from " 1517 << getContainer().getType() << " should produce " 1518 << valueType << " but this op returns " 1519 << getRes().getType(); 1520 return success(); 1521 } 1522 1523 //===----------------------------------------------------------------------===// 1524 // Printing/parsing for LLVM::InsertElementOp. 1525 //===----------------------------------------------------------------------===// 1526 1527 void InsertElementOp::print(OpAsmPrinter &p) { 1528 p << ' ' << getValue() << ", " << getVector() << "[" << getPosition() << " : " 1529 << getPosition().getType() << "]"; 1530 p.printOptionalAttrDict((*this)->getAttrs()); 1531 p << " : " << getVector().getType(); 1532 } 1533 1534 // <operation> ::= `llvm.insertelement` ssa-use `,` ssa-use `,` ssa-use 1535 // attribute-dict? 
`:` type 1536 ParseResult InsertElementOp::parse(OpAsmParser &parser, 1537 OperationState &result) { 1538 SMLoc loc; 1539 OpAsmParser::UnresolvedOperand vector, value, position; 1540 Type vectorType, positionType; 1541 if (parser.getCurrentLocation(&loc) || parser.parseOperand(value) || 1542 parser.parseComma() || parser.parseOperand(vector) || 1543 parser.parseLSquare() || parser.parseOperand(position) || 1544 parser.parseColonType(positionType) || parser.parseRSquare() || 1545 parser.parseOptionalAttrDict(result.attributes) || 1546 parser.parseColonType(vectorType)) 1547 return failure(); 1548 1549 if (!LLVM::isCompatibleVectorType(vectorType)) 1550 return parser.emitError( 1551 loc, "expected LLVM dialect-compatible vector type for operand #1"); 1552 Type valueType = LLVM::getVectorElementType(vectorType); 1553 if (!valueType) 1554 return failure(); 1555 1556 if (parser.resolveOperand(vector, vectorType, result.operands) || 1557 parser.resolveOperand(value, valueType, result.operands) || 1558 parser.resolveOperand(position, positionType, result.operands)) 1559 return failure(); 1560 1561 result.addTypes(vectorType); 1562 return success(); 1563 } 1564 1565 LogicalResult InsertElementOp::verify() { 1566 Type valueType = LLVM::getVectorElementType(getVector().getType()); 1567 if (valueType != getValue().getType()) 1568 return emitOpError() << "Type mismatch: cannot insert " 1569 << getValue().getType() << " into " 1570 << getVector().getType(); 1571 return success(); 1572 } 1573 1574 //===----------------------------------------------------------------------===// 1575 // Printing/parsing for LLVM::InsertValueOp. 
1576 //===----------------------------------------------------------------------===// 1577 1578 void InsertValueOp::print(OpAsmPrinter &p) { 1579 p << ' ' << getValue() << ", " << getContainer() << getPosition(); 1580 p.printOptionalAttrDict((*this)->getAttrs(), {"position"}); 1581 p << " : " << getContainer().getType(); 1582 } 1583 1584 // <operation> ::= `llvm.insertvaluevalue` ssa-use `,` ssa-use 1585 // `[` integer-literal (`,` integer-literal)* `]` 1586 // attribute-dict? `:` type 1587 ParseResult InsertValueOp::parse(OpAsmParser &parser, OperationState &result) { 1588 OpAsmParser::UnresolvedOperand container, value; 1589 Type containerType; 1590 ArrayAttr positionAttr; 1591 SMLoc attributeLoc, trailingTypeLoc; 1592 1593 if (parser.parseOperand(value) || parser.parseComma() || 1594 parser.parseOperand(container) || 1595 parser.getCurrentLocation(&attributeLoc) || 1596 parser.parseAttribute(positionAttr, "position", result.attributes) || 1597 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 1598 parser.getCurrentLocation(&trailingTypeLoc) || 1599 parser.parseType(containerType)) 1600 return failure(); 1601 1602 auto valueType = getInsertExtractValueElementType( 1603 parser, containerType, positionAttr, attributeLoc, trailingTypeLoc); 1604 if (!valueType) 1605 return failure(); 1606 1607 if (parser.resolveOperand(container, containerType, result.operands) || 1608 parser.resolveOperand(value, valueType, result.operands)) 1609 return failure(); 1610 1611 result.addTypes(containerType); 1612 return success(); 1613 } 1614 1615 LogicalResult InsertValueOp::verify() { 1616 Type valueType = getInsertExtractValueElementType(getContainer().getType(), 1617 getPositionAttr(), *this); 1618 if (!valueType) 1619 return failure(); 1620 1621 if (getValue().getType() != valueType) 1622 return emitOpError() << "Type mismatch: cannot insert " 1623 << getValue().getType() << " into " 1624 << getContainer().getType(); 1625 1626 return success(); 1627 } 1628 
1629 //===----------------------------------------------------------------------===// 1630 // Printing, parsing and verification for LLVM::ReturnOp. 1631 //===----------------------------------------------------------------------===// 1632 1633 LogicalResult ReturnOp::verify() { 1634 if (getNumOperands() > 1) 1635 return emitOpError("expected at most 1 operand"); 1636 1637 if (auto parent = (*this)->getParentOfType<LLVMFuncOp>()) { 1638 Type expectedType = parent.getFunctionType().getReturnType(); 1639 if (expectedType.isa<LLVMVoidType>()) { 1640 if (getNumOperands() == 0) 1641 return success(); 1642 InFlightDiagnostic diag = emitOpError("expected no operands"); 1643 diag.attachNote(parent->getLoc()) << "when returning from function"; 1644 return diag; 1645 } 1646 if (getNumOperands() == 0) { 1647 if (expectedType.isa<LLVMVoidType>()) 1648 return success(); 1649 InFlightDiagnostic diag = emitOpError("expected 1 operand"); 1650 diag.attachNote(parent->getLoc()) << "when returning from function"; 1651 return diag; 1652 } 1653 if (expectedType != getOperand(0).getType()) { 1654 InFlightDiagnostic diag = emitOpError("mismatching result types"); 1655 diag.attachNote(parent->getLoc()) << "when returning from function"; 1656 return diag; 1657 } 1658 } 1659 return success(); 1660 } 1661 1662 //===----------------------------------------------------------------------===// 1663 // ResumeOp 1664 //===----------------------------------------------------------------------===// 1665 1666 LogicalResult ResumeOp::verify() { 1667 if (!getValue().getDefiningOp<LandingpadOp>()) 1668 return emitOpError("expects landingpad value as operand"); 1669 // No check for personality of function - landingpad op verifies it. 1670 return success(); 1671 } 1672 1673 //===----------------------------------------------------------------------===// 1674 // Verifier for LLVM::AddressOfOp. 
1675 //===----------------------------------------------------------------------===// 1676 1677 template <typename OpTy> 1678 static OpTy lookupSymbolInModule(Operation *parent, StringRef name) { 1679 Operation *module = parent; 1680 while (module && !satisfiesLLVMModule(module)) 1681 module = module->getParentOp(); 1682 assert(module && "unexpected operation outside of a module"); 1683 return dyn_cast_or_null<OpTy>( 1684 mlir::SymbolTable::lookupSymbolIn(module, name)); 1685 } 1686 1687 GlobalOp AddressOfOp::getGlobal() { 1688 return lookupSymbolInModule<LLVM::GlobalOp>((*this)->getParentOp(), 1689 getGlobalName()); 1690 } 1691 1692 LLVMFuncOp AddressOfOp::getFunction() { 1693 return lookupSymbolInModule<LLVM::LLVMFuncOp>((*this)->getParentOp(), 1694 getGlobalName()); 1695 } 1696 1697 LogicalResult AddressOfOp::verify() { 1698 auto global = getGlobal(); 1699 auto function = getFunction(); 1700 if (!global && !function) 1701 return emitOpError( 1702 "must reference a global defined by 'llvm.mlir.global' or 'llvm.func'"); 1703 1704 LLVMPointerType type = getType(); 1705 if (global && global.getAddrSpace() != type.getAddressSpace()) 1706 return emitOpError("pointer address space must match address space of the " 1707 "referenced global"); 1708 1709 if (type.isOpaque()) 1710 return success(); 1711 1712 if (global && type.getElementType() != global.getType()) 1713 return emitOpError( 1714 "the type must be a pointer to the type of the referenced global"); 1715 1716 if (function && type.getElementType() != function.getFunctionType()) 1717 return emitOpError( 1718 "the type must be a pointer to the type of the referenced function"); 1719 1720 return success(); 1721 } 1722 1723 //===----------------------------------------------------------------------===// 1724 // Builder, printer and verifier for LLVM::GlobalOp. 
1725 //===----------------------------------------------------------------------===// 1726 1727 void GlobalOp::build(OpBuilder &builder, OperationState &result, Type type, 1728 bool isConstant, Linkage linkage, StringRef name, 1729 Attribute value, uint64_t alignment, unsigned addrSpace, 1730 bool dsoLocal, bool threadLocal, 1731 ArrayRef<NamedAttribute> attrs) { 1732 result.addAttribute(getSymNameAttrName(result.name), 1733 builder.getStringAttr(name)); 1734 result.addAttribute(getGlobalTypeAttrName(result.name), TypeAttr::get(type)); 1735 if (isConstant) 1736 result.addAttribute(getConstantAttrName(result.name), 1737 builder.getUnitAttr()); 1738 if (value) 1739 result.addAttribute(getValueAttrName(result.name), value); 1740 if (dsoLocal) 1741 result.addAttribute(getDsoLocalAttrName(result.name), 1742 builder.getUnitAttr()); 1743 if (threadLocal) 1744 result.addAttribute(getThreadLocal_AttrName(result.name), 1745 builder.getUnitAttr()); 1746 1747 // Only add an alignment attribute if the "alignment" input 1748 // is different from 0. The value must also be a power of two, but 1749 // this is tested in GlobalOp::verify, not here. 
1750 if (alignment != 0) 1751 result.addAttribute(getAlignmentAttrName(result.name), 1752 builder.getI64IntegerAttr(alignment)); 1753 1754 result.addAttribute(getLinkageAttrName(result.name), 1755 LinkageAttr::get(builder.getContext(), linkage)); 1756 if (addrSpace != 0) 1757 result.addAttribute(getAddrSpaceAttrName(result.name), 1758 builder.getI32IntegerAttr(addrSpace)); 1759 result.attributes.append(attrs.begin(), attrs.end()); 1760 result.addRegion(); 1761 } 1762 1763 void GlobalOp::print(OpAsmPrinter &p) { 1764 p << ' ' << stringifyLinkage(getLinkage()) << ' '; 1765 if (auto unnamedAddr = getUnnamedAddr()) { 1766 StringRef str = stringifyUnnamedAddr(*unnamedAddr); 1767 if (!str.empty()) 1768 p << str << ' '; 1769 } 1770 if (getThreadLocal_()) 1771 p << "thread_local "; 1772 if (getConstant()) 1773 p << "constant "; 1774 p.printSymbolName(getSymName()); 1775 p << '('; 1776 if (auto value = getValueOrNull()) 1777 p.printAttribute(value); 1778 p << ')'; 1779 // Note that the alignment attribute is printed using the 1780 // default syntax here, even though it is an inherent attribute 1781 // (as defined in https://mlir.llvm.org/docs/LangRef/#attributes) 1782 p.printOptionalAttrDict( 1783 (*this)->getAttrs(), 1784 {SymbolTable::getSymbolAttrName(), getGlobalTypeAttrName(), 1785 getConstantAttrName(), getValueAttrName(), getLinkageAttrName(), 1786 getUnnamedAddrAttrName(), getThreadLocal_AttrName()}); 1787 1788 // Print the trailing type unless it's a string global. 1789 if (getValueOrNull().dyn_cast_or_null<StringAttr>()) 1790 return; 1791 p << " : " << getType(); 1792 1793 Region &initializer = getInitializerRegion(); 1794 if (!initializer.empty()) { 1795 p << ' '; 1796 p.printRegion(initializer, /*printEntryBlockArgs=*/false); 1797 } 1798 } 1799 1800 // Parses one of the keywords provided in the list `keywords` and returns the 1801 // position of the parsed keyword in the list. If none of the keywords from the 1802 // list is parsed, returns -1. 
1803 static int parseOptionalKeywordAlternative(OpAsmParser &parser, 1804 ArrayRef<StringRef> keywords) { 1805 for (const auto &en : llvm::enumerate(keywords)) { 1806 if (succeeded(parser.parseOptionalKeyword(en.value()))) 1807 return en.index(); 1808 } 1809 return -1; 1810 } 1811 1812 namespace { 1813 template <typename Ty> 1814 struct EnumTraits {}; 1815 1816 #define REGISTER_ENUM_TYPE(Ty) \ 1817 template <> \ 1818 struct EnumTraits<Ty> { \ 1819 static StringRef stringify(Ty value) { return stringify##Ty(value); } \ 1820 static unsigned getMaxEnumVal() { return getMaxEnumValFor##Ty(); } \ 1821 } 1822 1823 REGISTER_ENUM_TYPE(Linkage); 1824 REGISTER_ENUM_TYPE(UnnamedAddr); 1825 REGISTER_ENUM_TYPE(CConv); 1826 } // namespace 1827 1828 /// Parse an enum from the keyword, or default to the provided default value. 1829 /// The return type is the enum type by default, unless overriden with the 1830 /// second template argument. 1831 template <typename EnumTy, typename RetTy = EnumTy> 1832 static RetTy parseOptionalLLVMKeyword(OpAsmParser &parser, 1833 OperationState &result, 1834 EnumTy defaultValue) { 1835 SmallVector<StringRef, 10> names; 1836 for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i) 1837 names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i))); 1838 1839 int index = parseOptionalKeywordAlternative(parser, names); 1840 if (index == -1) 1841 return static_cast<RetTy>(defaultValue); 1842 return static_cast<RetTy>(index); 1843 } 1844 1845 // operation ::= `llvm.mlir.global` linkage? `constant`? `@` identifier 1846 // `(` attribute? `)` align? attribute-list? (`:` type)? region? 1847 // align ::= `align` `=` UINT64 1848 // 1849 // The type can be omitted for string attributes, in which case it will be 1850 // inferred from the value of the string as [strlen(value) x i8]. 
1851 ParseResult GlobalOp::parse(OpAsmParser &parser, OperationState &result) { 1852 MLIRContext *ctx = parser.getContext(); 1853 // Parse optional linkage, default to External. 1854 result.addAttribute(getLinkageAttrName(result.name), 1855 LLVM::LinkageAttr::get( 1856 ctx, parseOptionalLLVMKeyword<Linkage>( 1857 parser, result, LLVM::Linkage::External))); 1858 1859 if (succeeded(parser.parseOptionalKeyword("thread_local"))) 1860 result.addAttribute(getThreadLocal_AttrName(result.name), 1861 parser.getBuilder().getUnitAttr()); 1862 1863 // Parse optional UnnamedAddr, default to None. 1864 result.addAttribute(getUnnamedAddrAttrName(result.name), 1865 parser.getBuilder().getI64IntegerAttr( 1866 parseOptionalLLVMKeyword<UnnamedAddr, int64_t>( 1867 parser, result, LLVM::UnnamedAddr::None))); 1868 1869 if (succeeded(parser.parseOptionalKeyword("constant"))) 1870 result.addAttribute(getConstantAttrName(result.name), 1871 parser.getBuilder().getUnitAttr()); 1872 1873 StringAttr name; 1874 if (parser.parseSymbolName(name, getSymNameAttrName(result.name), 1875 result.attributes) || 1876 parser.parseLParen()) 1877 return failure(); 1878 1879 Attribute value; 1880 if (parser.parseOptionalRParen()) { 1881 if (parser.parseAttribute(value, getValueAttrName(result.name), 1882 result.attributes) || 1883 parser.parseRParen()) 1884 return failure(); 1885 } 1886 1887 SmallVector<Type, 1> types; 1888 if (parser.parseOptionalAttrDict(result.attributes) || 1889 parser.parseOptionalColonTypeList(types)) 1890 return failure(); 1891 1892 if (types.size() > 1) 1893 return parser.emitError(parser.getNameLoc(), "expected zero or one type"); 1894 1895 Region &initRegion = *result.addRegion(); 1896 if (types.empty()) { 1897 if (auto strAttr = value.dyn_cast_or_null<StringAttr>()) { 1898 MLIRContext *context = parser.getContext(); 1899 auto arrayType = LLVM::LLVMArrayType::get(IntegerType::get(context, 8), 1900 strAttr.getValue().size()); 1901 types.push_back(arrayType); 1902 } else { 1903 
return parser.emitError(parser.getNameLoc(), 1904 "type can only be omitted for string globals"); 1905 } 1906 } else { 1907 OptionalParseResult parseResult = 1908 parser.parseOptionalRegion(initRegion, /*arguments=*/{}, 1909 /*argTypes=*/{}); 1910 if (parseResult.hasValue() && failed(*parseResult)) 1911 return failure(); 1912 } 1913 1914 result.addAttribute(getGlobalTypeAttrName(result.name), 1915 TypeAttr::get(types[0])); 1916 return success(); 1917 } 1918 1919 static bool isZeroAttribute(Attribute value) { 1920 if (auto intValue = value.dyn_cast<IntegerAttr>()) 1921 return intValue.getValue().isNullValue(); 1922 if (auto fpValue = value.dyn_cast<FloatAttr>()) 1923 return fpValue.getValue().isZero(); 1924 if (auto splatValue = value.dyn_cast<SplatElementsAttr>()) 1925 return isZeroAttribute(splatValue.getSplatValue<Attribute>()); 1926 if (auto elementsValue = value.dyn_cast<ElementsAttr>()) 1927 return llvm::all_of(elementsValue.getValues<Attribute>(), isZeroAttribute); 1928 if (auto arrayValue = value.dyn_cast<ArrayAttr>()) 1929 return llvm::all_of(arrayValue.getValue(), isZeroAttribute); 1930 return false; 1931 } 1932 1933 LogicalResult GlobalOp::verify() { 1934 if (!LLVMPointerType::isValidElementType(getType())) 1935 return emitOpError( 1936 "expects type to be a valid element type for an LLVM pointer"); 1937 if ((*this)->getParentOp() && !satisfiesLLVMModule((*this)->getParentOp())) 1938 return emitOpError("must appear at the module level"); 1939 1940 if (auto strAttr = getValueOrNull().dyn_cast_or_null<StringAttr>()) { 1941 auto type = getType().dyn_cast<LLVMArrayType>(); 1942 IntegerType elementType = 1943 type ? 
type.getElementType().dyn_cast<IntegerType>() : nullptr; 1944 if (!elementType || elementType.getWidth() != 8 || 1945 type.getNumElements() != strAttr.getValue().size()) 1946 return emitOpError( 1947 "requires an i8 array type of the length equal to that of the string " 1948 "attribute"); 1949 } 1950 1951 if (getLinkage() == Linkage::Common) { 1952 if (Attribute value = getValueOrNull()) { 1953 if (!isZeroAttribute(value)) { 1954 return emitOpError() 1955 << "expected zero value for '" 1956 << stringifyLinkage(Linkage::Common) << "' linkage"; 1957 } 1958 } 1959 } 1960 1961 if (getLinkage() == Linkage::Appending) { 1962 if (!getType().isa<LLVMArrayType>()) { 1963 return emitOpError() << "expected array type for '" 1964 << stringifyLinkage(Linkage::Appending) 1965 << "' linkage"; 1966 } 1967 } 1968 1969 Optional<uint64_t> alignAttr = getAlignment(); 1970 if (alignAttr.hasValue()) { 1971 uint64_t value = alignAttr.getValue(); 1972 if (!llvm::isPowerOf2_64(value)) 1973 return emitError() << "alignment attribute is not a power of 2"; 1974 } 1975 1976 return success(); 1977 } 1978 1979 LogicalResult GlobalOp::verifyRegions() { 1980 if (Block *b = getInitializerBlock()) { 1981 ReturnOp ret = cast<ReturnOp>(b->getTerminator()); 1982 if (ret.operand_type_begin() == ret.operand_type_end()) 1983 return emitOpError("initializer region cannot return void"); 1984 if (*ret.operand_type_begin() != getType()) 1985 return emitOpError("initializer region type ") 1986 << *ret.operand_type_begin() << " does not match global type " 1987 << getType(); 1988 1989 for (Operation &op : *b) { 1990 auto iface = dyn_cast<MemoryEffectOpInterface>(op); 1991 if (!iface || !iface.hasNoEffect()) 1992 return op.emitError() 1993 << "ops with side effects not allowed in global initializers"; 1994 } 1995 1996 if (getValueOrNull()) 1997 return emitOpError("cannot have both initializer value and region"); 1998 } 1999 2000 return success(); 2001 } 2002 2003 
//===----------------------------------------------------------------------===// 2004 // LLVM::GlobalCtorsOp 2005 //===----------------------------------------------------------------------===// 2006 2007 LogicalResult 2008 GlobalCtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 2009 for (Attribute ctor : getCtors()) { 2010 if (failed(verifySymbolAttrUse(ctor.cast<FlatSymbolRefAttr>(), *this, 2011 symbolTable))) 2012 return failure(); 2013 } 2014 return success(); 2015 } 2016 2017 LogicalResult GlobalCtorsOp::verify() { 2018 if (getCtors().size() != getPriorities().size()) 2019 return emitError( 2020 "mismatch between the number of ctors and the number of priorities"); 2021 return success(); 2022 } 2023 2024 //===----------------------------------------------------------------------===// 2025 // LLVM::GlobalDtorsOp 2026 //===----------------------------------------------------------------------===// 2027 2028 LogicalResult 2029 GlobalDtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 2030 for (Attribute dtor : getDtors()) { 2031 if (failed(verifySymbolAttrUse(dtor.cast<FlatSymbolRefAttr>(), *this, 2032 symbolTable))) 2033 return failure(); 2034 } 2035 return success(); 2036 } 2037 2038 LogicalResult GlobalDtorsOp::verify() { 2039 if (getDtors().size() != getPriorities().size()) 2040 return emitError( 2041 "mismatch between the number of dtors and the number of priorities"); 2042 return success(); 2043 } 2044 2045 //===----------------------------------------------------------------------===// 2046 // Printing/parsing for LLVM::ShuffleVectorOp. 2047 //===----------------------------------------------------------------------===// 2048 // Expects vector to be of wrapped LLVM vector type and position to be of 2049 // wrapped LLVM i32 type. 
2050 void LLVM::ShuffleVectorOp::build(OpBuilder &b, OperationState &result, 2051 Value v1, Value v2, ArrayAttr mask, 2052 ArrayRef<NamedAttribute> attrs) { 2053 auto containerType = v1.getType(); 2054 auto vType = LLVM::getVectorType(LLVM::getVectorElementType(containerType), 2055 mask.size(), 2056 LLVM::isScalableVectorType(containerType)); 2057 build(b, result, vType, v1, v2, mask); 2058 result.addAttributes(attrs); 2059 } 2060 2061 void ShuffleVectorOp::print(OpAsmPrinter &p) { 2062 p << ' ' << getV1() << ", " << getV2() << " " << getMask(); 2063 p.printOptionalAttrDict((*this)->getAttrs(), {"mask"}); 2064 p << " : " << getV1().getType() << ", " << getV2().getType(); 2065 } 2066 2067 // <operation> ::= `llvm.shufflevector` ssa-use `, ` ssa-use 2068 // `[` integer-literal (`,` integer-literal)* `]` 2069 // attribute-dict? `:` type 2070 ParseResult ShuffleVectorOp::parse(OpAsmParser &parser, 2071 OperationState &result) { 2072 SMLoc loc; 2073 OpAsmParser::UnresolvedOperand v1, v2; 2074 ArrayAttr maskAttr; 2075 Type typeV1, typeV2; 2076 if (parser.getCurrentLocation(&loc) || parser.parseOperand(v1) || 2077 parser.parseComma() || parser.parseOperand(v2) || 2078 parser.parseAttribute(maskAttr, "mask", result.attributes) || 2079 parser.parseOptionalAttrDict(result.attributes) || 2080 parser.parseColonType(typeV1) || parser.parseComma() || 2081 parser.parseType(typeV2) || 2082 parser.resolveOperand(v1, typeV1, result.operands) || 2083 parser.resolveOperand(v2, typeV2, result.operands)) 2084 return failure(); 2085 if (!LLVM::isCompatibleVectorType(typeV1)) 2086 return parser.emitError( 2087 loc, "expected LLVM IR dialect vector type for operand #1"); 2088 auto vType = 2089 LLVM::getVectorType(LLVM::getVectorElementType(typeV1), maskAttr.size(), 2090 typeV1.cast<VectorType>().isScalable()); 2091 result.addTypes(vType); 2092 return success(); 2093 } 2094 2095 LogicalResult ShuffleVectorOp::verify() { 2096 Type type1 = getV1().getType(); 2097 Type type2 = 
getV2().getType(); 2098 if (LLVM::getVectorElementType(type1) != LLVM::getVectorElementType(type2)) 2099 return emitOpError("expected matching LLVM IR Dialect element types"); 2100 if (LLVM::isScalableVectorType(type1)) 2101 if (llvm::any_of(getMask(), [](Attribute attr) { 2102 return attr.cast<IntegerAttr>().getInt() != 0; 2103 })) 2104 return emitOpError("expected a splat operation for scalable vectors"); 2105 return success(); 2106 } 2107 2108 //===----------------------------------------------------------------------===// 2109 // Implementations for LLVM::LLVMFuncOp. 2110 //===----------------------------------------------------------------------===// 2111 2112 // Add the entry block to the function. 2113 Block *LLVMFuncOp::addEntryBlock() { 2114 assert(empty() && "function already has an entry block"); 2115 assert(!isVarArg() && "unimplemented: non-external variadic functions"); 2116 2117 auto *entry = new Block; 2118 push_back(entry); 2119 2120 // FIXME: Allow passing in proper locations for the entry arguments. 
2121 LLVMFunctionType type = getFunctionType(); 2122 for (unsigned i = 0, e = type.getNumParams(); i < e; ++i) 2123 entry->addArgument(type.getParamType(i), getLoc()); 2124 return entry; 2125 } 2126 2127 void LLVMFuncOp::build(OpBuilder &builder, OperationState &result, 2128 StringRef name, Type type, LLVM::Linkage linkage, 2129 bool dsoLocal, CConv cconv, 2130 ArrayRef<NamedAttribute> attrs, 2131 ArrayRef<DictionaryAttr> argAttrs) { 2132 result.addRegion(); 2133 result.addAttribute(SymbolTable::getSymbolAttrName(), 2134 builder.getStringAttr(name)); 2135 result.addAttribute(getFunctionTypeAttrName(result.name), 2136 TypeAttr::get(type)); 2137 result.addAttribute(getLinkageAttrName(result.name), 2138 LinkageAttr::get(builder.getContext(), linkage)); 2139 result.addAttribute(getCConvAttrName(result.name), 2140 CConvAttr::get(builder.getContext(), cconv)); 2141 result.attributes.append(attrs.begin(), attrs.end()); 2142 if (dsoLocal) 2143 result.addAttribute("dso_local", builder.getUnitAttr()); 2144 if (argAttrs.empty()) 2145 return; 2146 2147 assert(type.cast<LLVMFunctionType>().getNumParams() == argAttrs.size() && 2148 "expected as many argument attribute lists as arguments"); 2149 function_interface_impl::addArgAndResultAttrs(builder, result, argAttrs, 2150 /*resultAttrs=*/llvm::None); 2151 } 2152 2153 // Builds an LLVM function type from the given lists of input and output types. 2154 // Returns a null type if any of the types provided are non-LLVM types, or if 2155 // there is more than one output type. 2156 static Type 2157 buildLLVMFunctionType(OpAsmParser &parser, SMLoc loc, ArrayRef<Type> inputs, 2158 ArrayRef<Type> outputs, 2159 function_interface_impl::VariadicFlag variadicFlag) { 2160 Builder &b = parser.getBuilder(); 2161 if (outputs.size() > 1) { 2162 parser.emitError(loc, "failed to construct function type: expected zero or " 2163 "one function result"); 2164 return {}; 2165 } 2166 2167 // Convert inputs to LLVM types, exit early on error. 
2168 SmallVector<Type, 4> llvmInputs; 2169 for (auto t : inputs) { 2170 if (!isCompatibleType(t)) { 2171 parser.emitError(loc, "failed to construct function type: expected LLVM " 2172 "type for function arguments"); 2173 return {}; 2174 } 2175 llvmInputs.push_back(t); 2176 } 2177 2178 // No output is denoted as "void" in LLVM type system. 2179 Type llvmOutput = 2180 outputs.empty() ? LLVMVoidType::get(b.getContext()) : outputs.front(); 2181 if (!isCompatibleType(llvmOutput)) { 2182 parser.emitError(loc, "failed to construct function type: expected LLVM " 2183 "type for function results") 2184 << llvmOutput; 2185 return {}; 2186 } 2187 return LLVMFunctionType::get(llvmOutput, llvmInputs, 2188 variadicFlag.isVariadic()); 2189 } 2190 2191 // Parses an LLVM function. 2192 // 2193 // operation ::= `llvm.func` linkage? cconv? function-signature 2194 // function-attributes? 2195 // function-body 2196 // 2197 ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) { 2198 // Default to external linkage if no keyword is provided. 2199 result.addAttribute( 2200 getLinkageAttrName(result.name), 2201 LinkageAttr::get(parser.getContext(), 2202 parseOptionalLLVMKeyword<Linkage>( 2203 parser, result, LLVM::Linkage::External))); 2204 2205 // Default to C Calling Convention if no keyword is provided. 
2206 result.addAttribute( 2207 getCConvAttrName(result.name), 2208 CConvAttr::get(parser.getContext(), parseOptionalLLVMKeyword<CConv>( 2209 parser, result, LLVM::CConv::C))); 2210 2211 StringAttr nameAttr; 2212 SmallVector<OpAsmParser::Argument> entryArgs; 2213 SmallVector<DictionaryAttr> resultAttrs; 2214 SmallVector<Type> resultTypes; 2215 bool isVariadic; 2216 2217 auto signatureLocation = parser.getCurrentLocation(); 2218 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), 2219 result.attributes) || 2220 function_interface_impl::parseFunctionSignature( 2221 parser, /*allowVariadic=*/true, entryArgs, isVariadic, resultTypes, 2222 resultAttrs)) 2223 return failure(); 2224 2225 SmallVector<Type> argTypes; 2226 for (auto &arg : entryArgs) 2227 argTypes.push_back(arg.type); 2228 auto type = 2229 buildLLVMFunctionType(parser, signatureLocation, argTypes, resultTypes, 2230 function_interface_impl::VariadicFlag(isVariadic)); 2231 if (!type) 2232 return failure(); 2233 result.addAttribute(FunctionOpInterface::getTypeAttrName(), 2234 TypeAttr::get(type)); 2235 2236 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) 2237 return failure(); 2238 function_interface_impl::addArgAndResultAttrs(parser.getBuilder(), result, 2239 entryArgs, resultAttrs); 2240 2241 auto *body = result.addRegion(); 2242 OptionalParseResult parseResult = 2243 parser.parseOptionalRegion(*body, entryArgs); 2244 return failure(parseResult.hasValue() && failed(*parseResult)); 2245 } 2246 2247 // Print the LLVMFuncOp. Collects argument and result types and passes them to 2248 // helper functions. Drops "void" result since it cannot be parsed back. Skips 2249 // the external linkage since it is the default value. 
2250 void LLVMFuncOp::print(OpAsmPrinter &p) { 2251 p << ' '; 2252 if (getLinkage() != LLVM::Linkage::External) 2253 p << stringifyLinkage(getLinkage()) << ' '; 2254 if (getCConv() != LLVM::CConv::C) 2255 p << stringifyCConv(getCConv()) << ' '; 2256 2257 p.printSymbolName(getName()); 2258 2259 LLVMFunctionType fnType = getFunctionType(); 2260 SmallVector<Type, 8> argTypes; 2261 SmallVector<Type, 1> resTypes; 2262 argTypes.reserve(fnType.getNumParams()); 2263 for (unsigned i = 0, e = fnType.getNumParams(); i < e; ++i) 2264 argTypes.push_back(fnType.getParamType(i)); 2265 2266 Type returnType = fnType.getReturnType(); 2267 if (!returnType.isa<LLVMVoidType>()) 2268 resTypes.push_back(returnType); 2269 2270 function_interface_impl::printFunctionSignature(p, *this, argTypes, 2271 isVarArg(), resTypes); 2272 function_interface_impl::printFunctionAttributes( 2273 p, *this, argTypes.size(), resTypes.size(), 2274 {getLinkageAttrName(), getCConvAttrName()}); 2275 2276 // Print the body if this is not an external function. 2277 Region &body = getBody(); 2278 if (!body.empty()) { 2279 p << ' '; 2280 p.printRegion(body, /*printEntryBlockArgs=*/false, 2281 /*printBlockTerminators=*/true); 2282 } 2283 } 2284 2285 // Verifies LLVM- and implementation-specific properties of the LLVM func Op: 2286 // - functions don't have 'common' linkage 2287 // - external functions have 'external' or 'extern_weak' linkage; 2288 // - vararg is (currently) only supported for external functions; 2289 LogicalResult LLVMFuncOp::verify() { 2290 if (getLinkage() == LLVM::Linkage::Common) 2291 return emitOpError() << "functions cannot have '" 2292 << stringifyLinkage(LLVM::Linkage::Common) 2293 << "' linkage"; 2294 2295 // Check to see if this function has a void return with a result attribute to 2296 // it. It isn't clear what semantics we would assign to that. 
2297 if (getFunctionType().getReturnType().isa<LLVMVoidType>() && 2298 !getResultAttrs(0).empty()) { 2299 return emitOpError() 2300 << "cannot attach result attributes to functions with a void return"; 2301 } 2302 2303 if (isExternal()) { 2304 if (getLinkage() != LLVM::Linkage::External && 2305 getLinkage() != LLVM::Linkage::ExternWeak) 2306 return emitOpError() << "external functions must have '" 2307 << stringifyLinkage(LLVM::Linkage::External) 2308 << "' or '" 2309 << stringifyLinkage(LLVM::Linkage::ExternWeak) 2310 << "' linkage"; 2311 return success(); 2312 } 2313 2314 if (isVarArg()) 2315 return emitOpError("only external functions can be variadic"); 2316 2317 return success(); 2318 } 2319 2320 /// Verifies LLVM- and implementation-specific properties of the LLVM func Op: 2321 /// - entry block arguments are of LLVM types. 2322 LogicalResult LLVMFuncOp::verifyRegions() { 2323 if (isExternal()) 2324 return success(); 2325 2326 unsigned numArguments = getFunctionType().getNumParams(); 2327 Block &entryBlock = front(); 2328 for (unsigned i = 0; i < numArguments; ++i) { 2329 Type argType = entryBlock.getArgument(i).getType(); 2330 if (!isCompatibleType(argType)) 2331 return emitOpError("entry block argument #") 2332 << i << " is not of LLVM type"; 2333 } 2334 2335 return success(); 2336 } 2337 2338 //===----------------------------------------------------------------------===// 2339 // Verification for LLVM::ConstantOp. 
2340 //===----------------------------------------------------------------------===// 2341 2342 LogicalResult LLVM::ConstantOp::verify() { 2343 if (StringAttr sAttr = getValue().dyn_cast<StringAttr>()) { 2344 auto arrayType = getType().dyn_cast<LLVMArrayType>(); 2345 if (!arrayType || arrayType.getNumElements() != sAttr.getValue().size() || 2346 !arrayType.getElementType().isInteger(8)) { 2347 return emitOpError() << "expected array type of " 2348 << sAttr.getValue().size() 2349 << " i8 elements for the string constant"; 2350 } 2351 return success(); 2352 } 2353 if (auto structType = getType().dyn_cast<LLVMStructType>()) { 2354 if (structType.getBody().size() != 2 || 2355 structType.getBody()[0] != structType.getBody()[1]) { 2356 return emitError() << "expected struct type with two elements of the " 2357 "same type, the type of a complex constant"; 2358 } 2359 2360 auto arrayAttr = getValue().dyn_cast<ArrayAttr>(); 2361 if (!arrayAttr || arrayAttr.size() != 2 || 2362 arrayAttr[0].getType() != arrayAttr[1].getType()) { 2363 return emitOpError() << "expected array attribute with two elements, " 2364 "representing a complex constant"; 2365 } 2366 2367 Type elementType = structType.getBody()[0]; 2368 if (!elementType 2369 .isa<IntegerType, Float16Type, Float32Type, Float64Type>()) { 2370 return emitError() 2371 << "expected struct element types to be floating point type or " 2372 "integer type"; 2373 } 2374 return success(); 2375 } 2376 if (!getValue().isa<IntegerAttr, ArrayAttr, FloatAttr, ElementsAttr>()) 2377 return emitOpError() 2378 << "only supports integer, float, string or elements attributes"; 2379 return success(); 2380 } 2381 2382 // Constant op constant-folds to its value. 
2383 OpFoldResult LLVM::ConstantOp::fold(ArrayRef<Attribute>) { return getValue(); } 2384 2385 //===----------------------------------------------------------------------===// 2386 // Utility functions for parsing atomic ops 2387 //===----------------------------------------------------------------------===// 2388 2389 // Helper function to parse a keyword into the specified attribute named by 2390 // `attrName`. The keyword must match one of the string values defined by the 2391 // AtomicBinOp enum. The resulting I64 attribute is added to the `result` 2392 // state. 2393 static ParseResult parseAtomicBinOp(OpAsmParser &parser, OperationState &result, 2394 StringRef attrName) { 2395 SMLoc loc; 2396 StringRef keyword; 2397 if (parser.getCurrentLocation(&loc) || parser.parseKeyword(&keyword)) 2398 return failure(); 2399 2400 // Replace the keyword `keyword` with an integer attribute. 2401 auto kind = symbolizeAtomicBinOp(keyword); 2402 if (!kind) { 2403 return parser.emitError(loc) 2404 << "'" << keyword << "' is an incorrect value of the '" << attrName 2405 << "' attribute"; 2406 } 2407 2408 auto value = static_cast<int64_t>(kind.getValue()); 2409 auto attr = parser.getBuilder().getI64IntegerAttr(value); 2410 result.addAttribute(attrName, attr); 2411 2412 return success(); 2413 } 2414 2415 // Helper function to parse a keyword into the specified attribute named by 2416 // `attrName`. The keyword must match one of the string values defined by the 2417 // AtomicOrdering enum. The resulting I64 attribute is added to the `result` 2418 // state. 2419 static ParseResult parseAtomicOrdering(OpAsmParser &parser, 2420 OperationState &result, 2421 StringRef attrName) { 2422 SMLoc loc; 2423 StringRef ordering; 2424 if (parser.getCurrentLocation(&loc) || parser.parseKeyword(&ordering)) 2425 return failure(); 2426 2427 // Replace the keyword `ordering` with an integer attribute. 
2428 auto kind = symbolizeAtomicOrdering(ordering); 2429 if (!kind) { 2430 return parser.emitError(loc) 2431 << "'" << ordering << "' is an incorrect value of the '" << attrName 2432 << "' attribute"; 2433 } 2434 2435 auto value = static_cast<int64_t>(kind.getValue()); 2436 auto attr = parser.getBuilder().getI64IntegerAttr(value); 2437 result.addAttribute(attrName, attr); 2438 2439 return success(); 2440 } 2441 2442 //===----------------------------------------------------------------------===// 2443 // Printer, parser and verifier for LLVM::AtomicRMWOp. 2444 //===----------------------------------------------------------------------===// 2445 2446 void AtomicRMWOp::print(OpAsmPrinter &p) { 2447 p << ' ' << stringifyAtomicBinOp(getBinOp()) << ' ' << getPtr() << ", " 2448 << getVal() << ' ' << stringifyAtomicOrdering(getOrdering()) << ' '; 2449 p.printOptionalAttrDict((*this)->getAttrs(), {"bin_op", "ordering"}); 2450 p << " : " << getRes().getType(); 2451 } 2452 2453 // <operation> ::= `llvm.atomicrmw` keyword ssa-use `,` ssa-use keyword 2454 // attribute-dict? 
`:` type 2455 ParseResult AtomicRMWOp::parse(OpAsmParser &parser, OperationState &result) { 2456 Type type; 2457 OpAsmParser::UnresolvedOperand ptr, val; 2458 if (parseAtomicBinOp(parser, result, "bin_op") || parser.parseOperand(ptr) || 2459 parser.parseComma() || parser.parseOperand(val) || 2460 parseAtomicOrdering(parser, result, "ordering") || 2461 parser.parseOptionalAttrDict(result.attributes) || 2462 parser.parseColonType(type) || 2463 parser.resolveOperand(ptr, LLVM::LLVMPointerType::get(type), 2464 result.operands) || 2465 parser.resolveOperand(val, type, result.operands)) 2466 return failure(); 2467 2468 result.addTypes(type); 2469 return success(); 2470 } 2471 2472 LogicalResult AtomicRMWOp::verify() { 2473 auto ptrType = getPtr().getType().cast<LLVM::LLVMPointerType>(); 2474 auto valType = getVal().getType(); 2475 if (valType != ptrType.getElementType()) 2476 return emitOpError("expected LLVM IR element type for operand #0 to " 2477 "match type for operand #1"); 2478 auto resType = getRes().getType(); 2479 if (resType != valType) 2480 return emitOpError( 2481 "expected LLVM IR result type to match type for operand #1"); 2482 if (getBinOp() == AtomicBinOp::fadd || getBinOp() == AtomicBinOp::fsub) { 2483 if (!mlir::LLVM::isCompatibleFloatingPointType(valType)) 2484 return emitOpError("expected LLVM IR floating point type"); 2485 } else if (getBinOp() == AtomicBinOp::xchg) { 2486 auto intType = valType.dyn_cast<IntegerType>(); 2487 unsigned intBitWidth = intType ? intType.getWidth() : 0; 2488 if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 && 2489 intBitWidth != 64 && !valType.isa<BFloat16Type>() && 2490 !valType.isa<Float16Type>() && !valType.isa<Float32Type>() && 2491 !valType.isa<Float64Type>()) 2492 return emitOpError("unexpected LLVM IR type for 'xchg' bin_op"); 2493 } else { 2494 auto intType = valType.dyn_cast<IntegerType>(); 2495 unsigned intBitWidth = intType ? 
intType.getWidth() : 0; 2496 if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 && 2497 intBitWidth != 64) 2498 return emitOpError("expected LLVM IR integer type"); 2499 } 2500 2501 if (static_cast<unsigned>(getOrdering()) < 2502 static_cast<unsigned>(AtomicOrdering::monotonic)) 2503 return emitOpError() << "expected at least '" 2504 << stringifyAtomicOrdering(AtomicOrdering::monotonic) 2505 << "' ordering"; 2506 2507 return success(); 2508 } 2509 2510 //===----------------------------------------------------------------------===// 2511 // Printer, parser and verifier for LLVM::AtomicCmpXchgOp. 2512 //===----------------------------------------------------------------------===// 2513 2514 void AtomicCmpXchgOp::print(OpAsmPrinter &p) { 2515 p << ' ' << getPtr() << ", " << getCmp() << ", " << getVal() << ' ' 2516 << stringifyAtomicOrdering(getSuccessOrdering()) << ' ' 2517 << stringifyAtomicOrdering(getFailureOrdering()); 2518 p.printOptionalAttrDict((*this)->getAttrs(), 2519 {"success_ordering", "failure_ordering"}); 2520 p << " : " << getVal().getType(); 2521 } 2522 2523 // <operation> ::= `llvm.cmpxchg` ssa-use `,` ssa-use `,` ssa-use 2524 // keyword keyword attribute-dict? 
`:` type 2525 ParseResult AtomicCmpXchgOp::parse(OpAsmParser &parser, 2526 OperationState &result) { 2527 auto &builder = parser.getBuilder(); 2528 Type type; 2529 OpAsmParser::UnresolvedOperand ptr, cmp, val; 2530 if (parser.parseOperand(ptr) || parser.parseComma() || 2531 parser.parseOperand(cmp) || parser.parseComma() || 2532 parser.parseOperand(val) || 2533 parseAtomicOrdering(parser, result, "success_ordering") || 2534 parseAtomicOrdering(parser, result, "failure_ordering") || 2535 parser.parseOptionalAttrDict(result.attributes) || 2536 parser.parseColonType(type) || 2537 parser.resolveOperand(ptr, LLVM::LLVMPointerType::get(type), 2538 result.operands) || 2539 parser.resolveOperand(cmp, type, result.operands) || 2540 parser.resolveOperand(val, type, result.operands)) 2541 return failure(); 2542 2543 auto boolType = IntegerType::get(builder.getContext(), 1); 2544 auto resultType = 2545 LLVMStructType::getLiteral(builder.getContext(), {type, boolType}); 2546 result.addTypes(resultType); 2547 2548 return success(); 2549 } 2550 2551 LogicalResult AtomicCmpXchgOp::verify() { 2552 auto ptrType = getPtr().getType().cast<LLVM::LLVMPointerType>(); 2553 if (!ptrType) 2554 return emitOpError("expected LLVM IR pointer type for operand #0"); 2555 auto cmpType = getCmp().getType(); 2556 auto valType = getVal().getType(); 2557 if (cmpType != ptrType.getElementType() || cmpType != valType) 2558 return emitOpError("expected LLVM IR element type for operand #0 to " 2559 "match type for all other operands"); 2560 auto intType = valType.dyn_cast<IntegerType>(); 2561 unsigned intBitWidth = intType ? 
intType.getWidth() : 0; 2562 if (!valType.isa<LLVMPointerType>() && intBitWidth != 8 && 2563 intBitWidth != 16 && intBitWidth != 32 && intBitWidth != 64 && 2564 !valType.isa<BFloat16Type>() && !valType.isa<Float16Type>() && 2565 !valType.isa<Float32Type>() && !valType.isa<Float64Type>()) 2566 return emitOpError("unexpected LLVM IR type"); 2567 if (getSuccessOrdering() < AtomicOrdering::monotonic || 2568 getFailureOrdering() < AtomicOrdering::monotonic) 2569 return emitOpError("ordering must be at least 'monotonic'"); 2570 if (getFailureOrdering() == AtomicOrdering::release || 2571 getFailureOrdering() == AtomicOrdering::acq_rel) 2572 return emitOpError("failure ordering cannot be 'release' or 'acq_rel'"); 2573 return success(); 2574 } 2575 2576 //===----------------------------------------------------------------------===// 2577 // Printer, parser and verifier for LLVM::FenceOp. 2578 //===----------------------------------------------------------------------===// 2579 2580 // <operation> ::= `llvm.fence` (`syncscope(`strAttr`)`)? keyword 2581 // attribute-dict? 
2582 ParseResult FenceOp::parse(OpAsmParser &parser, OperationState &result) { 2583 StringAttr sScope; 2584 StringRef syncscopeKeyword = "syncscope"; 2585 if (!failed(parser.parseOptionalKeyword(syncscopeKeyword))) { 2586 if (parser.parseLParen() || 2587 parser.parseAttribute(sScope, syncscopeKeyword, result.attributes) || 2588 parser.parseRParen()) 2589 return failure(); 2590 } else { 2591 result.addAttribute(syncscopeKeyword, 2592 parser.getBuilder().getStringAttr("")); 2593 } 2594 if (parseAtomicOrdering(parser, result, "ordering") || 2595 parser.parseOptionalAttrDict(result.attributes)) 2596 return failure(); 2597 return success(); 2598 } 2599 2600 void FenceOp::print(OpAsmPrinter &p) { 2601 StringRef syncscopeKeyword = "syncscope"; 2602 p << ' '; 2603 if (!(*this)->getAttr(syncscopeKeyword).cast<StringAttr>().getValue().empty()) 2604 p << "syncscope(" << (*this)->getAttr(syncscopeKeyword) << ") "; 2605 p << stringifyAtomicOrdering(getOrdering()); 2606 } 2607 2608 LogicalResult FenceOp::verify() { 2609 if (getOrdering() == AtomicOrdering::not_atomic || 2610 getOrdering() == AtomicOrdering::unordered || 2611 getOrdering() == AtomicOrdering::monotonic) 2612 return emitOpError("can be given only acquire, release, acq_rel, " 2613 "and seq_cst orderings"); 2614 return success(); 2615 } 2616 2617 //===----------------------------------------------------------------------===// 2618 // Folder for LLVM::BitcastOp 2619 //===----------------------------------------------------------------------===// 2620 2621 OpFoldResult LLVM::BitcastOp::fold(ArrayRef<Attribute> operands) { 2622 // bitcast(x : T0, T0) -> x 2623 if (getArg().getType() == getType()) 2624 return getArg(); 2625 // bitcast(bitcast(x : T0, T1), T0) -> x 2626 if (auto prev = getArg().getDefiningOp<BitcastOp>()) 2627 if (prev.getArg().getType() == getType()) 2628 return prev.getArg(); 2629 return {}; 2630 } 2631 2632 //===----------------------------------------------------------------------===// 2633 // Folder 
for LLVM::AddrSpaceCastOp 2634 //===----------------------------------------------------------------------===// 2635 2636 OpFoldResult LLVM::AddrSpaceCastOp::fold(ArrayRef<Attribute> operands) { 2637 // addrcast(x : T0, T0) -> x 2638 if (getArg().getType() == getType()) 2639 return getArg(); 2640 // addrcast(addrcast(x : T0, T1), T0) -> x 2641 if (auto prev = getArg().getDefiningOp<AddrSpaceCastOp>()) 2642 if (prev.getArg().getType() == getType()) 2643 return prev.getArg(); 2644 return {}; 2645 } 2646 2647 //===----------------------------------------------------------------------===// 2648 // Folder for LLVM::GEPOp 2649 //===----------------------------------------------------------------------===// 2650 2651 OpFoldResult LLVM::GEPOp::fold(ArrayRef<Attribute> operands) { 2652 // gep %x:T, 0 -> %x 2653 if (getBase().getType() == getType() && getIndices().size() == 1 && 2654 matchPattern(getIndices()[0], m_Zero())) 2655 return getBase(); 2656 return {}; 2657 } 2658 2659 //===----------------------------------------------------------------------===// 2660 // LLVMDialect initialization, type parsing, and registration. 2661 //===----------------------------------------------------------------------===// 2662 2663 void LLVMDialect::initialize() { 2664 addAttributes<FMFAttr, LinkageAttr, CConvAttr, LoopOptionsAttr>(); 2665 2666 // clang-format off 2667 addTypes<LLVMVoidType, 2668 LLVMPPCFP128Type, 2669 LLVMX86MMXType, 2670 LLVMTokenType, 2671 LLVMLabelType, 2672 LLVMMetadataType, 2673 LLVMFunctionType, 2674 LLVMPointerType, 2675 LLVMFixedVectorType, 2676 LLVMScalableVectorType, 2677 LLVMArrayType, 2678 LLVMStructType>(); 2679 // clang-format on 2680 addOperations< 2681 #define GET_OP_LIST 2682 #include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc" 2683 , 2684 #define GET_OP_LIST 2685 #include "mlir/Dialect/LLVMIR/LLVMIntrinsicOps.cpp.inc" 2686 >(); 2687 2688 // Support unknown operations because not all LLVM operations are registered. 
2689 allowUnknownOperations(); 2690 } 2691 2692 #define GET_OP_CLASSES 2693 #include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc" 2694 2695 /// Parse a type registered to this dialect. 2696 Type LLVMDialect::parseType(DialectAsmParser &parser) const { 2697 return detail::parseType(parser); 2698 } 2699 2700 /// Print a type registered to this dialect. 2701 void LLVMDialect::printType(Type type, DialectAsmPrinter &os) const { 2702 return detail::printType(type, os); 2703 } 2704 2705 LogicalResult LLVMDialect::verifyDataLayoutString( 2706 StringRef descr, llvm::function_ref<void(const Twine &)> reportError) { 2707 llvm::Expected<llvm::DataLayout> maybeDataLayout = 2708 llvm::DataLayout::parse(descr); 2709 if (maybeDataLayout) 2710 return success(); 2711 2712 std::string message; 2713 llvm::raw_string_ostream messageStream(message); 2714 llvm::logAllUnhandledErrors(maybeDataLayout.takeError(), messageStream); 2715 reportError("invalid data layout descriptor: " + messageStream.str()); 2716 return failure(); 2717 } 2718 2719 /// Verify LLVM dialect attributes. 2720 LogicalResult LLVMDialect::verifyOperationAttribute(Operation *op, 2721 NamedAttribute attr) { 2722 // If the `llvm.loop` attribute is present, enforce the following structure, 2723 // which the module translation can assume. 
2724 if (attr.getName() == LLVMDialect::getLoopAttrName()) { 2725 auto loopAttr = attr.getValue().dyn_cast<DictionaryAttr>(); 2726 if (!loopAttr) 2727 return op->emitOpError() << "expected '" << LLVMDialect::getLoopAttrName() 2728 << "' to be a dictionary attribute"; 2729 Optional<NamedAttribute> parallelAccessGroup = 2730 loopAttr.getNamed(LLVMDialect::getParallelAccessAttrName()); 2731 if (parallelAccessGroup.hasValue()) { 2732 auto accessGroups = parallelAccessGroup->getValue().dyn_cast<ArrayAttr>(); 2733 if (!accessGroups) 2734 return op->emitOpError() 2735 << "expected '" << LLVMDialect::getParallelAccessAttrName() 2736 << "' to be an array attribute"; 2737 for (Attribute attr : accessGroups) { 2738 auto accessGroupRef = attr.dyn_cast<SymbolRefAttr>(); 2739 if (!accessGroupRef) 2740 return op->emitOpError() 2741 << "expected '" << attr << "' to be a symbol reference"; 2742 StringAttr metadataName = accessGroupRef.getRootReference(); 2743 auto metadataOp = 2744 SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>( 2745 op->getParentOp(), metadataName); 2746 if (!metadataOp) 2747 return op->emitOpError() 2748 << "expected '" << attr << "' to reference a metadata op"; 2749 StringAttr accessGroupName = accessGroupRef.getLeafReference(); 2750 Operation *accessGroupOp = 2751 SymbolTable::lookupNearestSymbolFrom(metadataOp, accessGroupName); 2752 if (!accessGroupOp) 2753 return op->emitOpError() 2754 << "expected '" << attr << "' to reference an access_group op"; 2755 } 2756 } 2757 2758 Optional<NamedAttribute> loopOptions = 2759 loopAttr.getNamed(LLVMDialect::getLoopOptionsAttrName()); 2760 if (loopOptions.hasValue() && 2761 !loopOptions->getValue().isa<LoopOptionsAttr>()) 2762 return op->emitOpError() 2763 << "expected '" << LLVMDialect::getLoopOptionsAttrName() 2764 << "' to be a `loopopts` attribute"; 2765 } 2766 2767 if (attr.getName() == LLVMDialect::getStructAttrsAttrName()) { 2768 return op->emitOpError() 2769 << "'" << 
LLVM::LLVMDialect::getStructAttrsAttrName() 2770 << "' is permitted only in argument or result attributes"; 2771 } 2772 2773 // If the data layout attribute is present, it must use the LLVM data layout 2774 // syntax. Try parsing it and report errors in case of failure. Users of this 2775 // attribute may assume it is well-formed and can pass it to the (asserting) 2776 // llvm::DataLayout constructor. 2777 if (attr.getName() != LLVM::LLVMDialect::getDataLayoutAttrName()) 2778 return success(); 2779 if (auto stringAttr = attr.getValue().dyn_cast<StringAttr>()) 2780 return verifyDataLayoutString( 2781 stringAttr.getValue(), 2782 [op](const Twine &message) { op->emitOpError() << message.str(); }); 2783 2784 return op->emitOpError() << "expected '" 2785 << LLVM::LLVMDialect::getDataLayoutAttrName() 2786 << "' to be a string attributes"; 2787 } 2788 2789 LogicalResult LLVMDialect::verifyStructAttr(Operation *op, Attribute attr, 2790 Type annotatedType) { 2791 auto structType = annotatedType.dyn_cast<LLVMStructType>(); 2792 if (!structType) { 2793 const auto emitIncorrectAnnotatedType = [&op]() { 2794 return op->emitError() 2795 << "expected '" << LLVMDialect::getStructAttrsAttrName() 2796 << "' to annotate '!llvm.struct' or '!llvm.ptr<struct<...>>'"; 2797 }; 2798 const auto ptrType = annotatedType.dyn_cast<LLVMPointerType>(); 2799 if (!ptrType) 2800 return emitIncorrectAnnotatedType(); 2801 structType = ptrType.getElementType().dyn_cast<LLVMStructType>(); 2802 if (!structType) 2803 return emitIncorrectAnnotatedType(); 2804 } 2805 2806 const auto arrAttrs = attr.dyn_cast<ArrayAttr>(); 2807 if (!arrAttrs) 2808 return op->emitError() << "expected '" 2809 << LLVMDialect::getStructAttrsAttrName() 2810 << "' to be an array attribute"; 2811 2812 if (structType.getBody().size() != arrAttrs.size()) 2813 return op->emitError() 2814 << "size of '" << LLVMDialect::getStructAttrsAttrName() 2815 << "' must match the size of the annotated '!llvm.struct'"; 2816 return success(); 2817 } 
2818 2819 static LogicalResult verifyFuncOpInterfaceStructAttr( 2820 Operation *op, Attribute attr, 2821 const std::function<Type(FunctionOpInterface)> &getAnnotatedType) { 2822 if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) 2823 return LLVMDialect::verifyStructAttr(op, attr, getAnnotatedType(funcOp)); 2824 return op->emitError() << "expected '" 2825 << LLVMDialect::getStructAttrsAttrName() 2826 << "' to be used on function-like operations"; 2827 } 2828 2829 /// Verify LLVMIR function argument attributes. 2830 LogicalResult LLVMDialect::verifyRegionArgAttribute(Operation *op, 2831 unsigned regionIdx, 2832 unsigned argIdx, 2833 NamedAttribute argAttr) { 2834 // Check that llvm.noalias is a unit attribute. 2835 if (argAttr.getName() == LLVMDialect::getNoAliasAttrName() && 2836 !argAttr.getValue().isa<UnitAttr>()) 2837 return op->emitError() 2838 << "expected llvm.noalias argument attribute to be a unit attribute"; 2839 // Check that llvm.align is an integer attribute. 2840 if (argAttr.getName() == LLVMDialect::getAlignAttrName() && 2841 !argAttr.getValue().isa<IntegerAttr>()) 2842 return op->emitError() 2843 << "llvm.align argument attribute of non integer type"; 2844 if (argAttr.getName() == LLVMDialect::getStructAttrsAttrName()) { 2845 return verifyFuncOpInterfaceStructAttr( 2846 op, argAttr.getValue(), [argIdx](FunctionOpInterface funcOp) { 2847 return funcOp.getArgumentTypes()[argIdx]; 2848 }); 2849 } 2850 return success(); 2851 } 2852 2853 LogicalResult LLVMDialect::verifyRegionResultAttribute(Operation *op, 2854 unsigned regionIdx, 2855 unsigned resIdx, 2856 NamedAttribute resAttr) { 2857 if (resAttr.getName() == LLVMDialect::getStructAttrsAttrName()) { 2858 return verifyFuncOpInterfaceStructAttr( 2859 op, resAttr.getValue(), [resIdx](FunctionOpInterface funcOp) { 2860 return funcOp.getResultTypes()[resIdx]; 2861 }); 2862 } 2863 return success(); 2864 } 2865 2866 //===----------------------------------------------------------------------===// 2867 // 
Utility functions. 2868 //===----------------------------------------------------------------------===// 2869 2870 Value mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder, 2871 StringRef name, StringRef value, 2872 LLVM::Linkage linkage) { 2873 assert(builder.getInsertionBlock() && 2874 builder.getInsertionBlock()->getParentOp() && 2875 "expected builder to point to a block constrained in an op"); 2876 auto module = 2877 builder.getInsertionBlock()->getParentOp()->getParentOfType<ModuleOp>(); 2878 assert(module && "builder points to an op outside of a module"); 2879 2880 // Create the global at the entry of the module. 2881 OpBuilder moduleBuilder(module.getBodyRegion(), builder.getListener()); 2882 MLIRContext *ctx = builder.getContext(); 2883 auto type = LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), value.size()); 2884 auto global = moduleBuilder.create<LLVM::GlobalOp>( 2885 loc, type, /*isConstant=*/true, linkage, name, 2886 builder.getStringAttr(value), /*alignment=*/0); 2887 2888 // Get the pointer to the first character in the global string. 
2889 Value globalPtr = builder.create<LLVM::AddressOfOp>(loc, global); 2890 Value cst0 = builder.create<LLVM::ConstantOp>( 2891 loc, IntegerType::get(ctx, 64), 2892 builder.getIntegerAttr(builder.getIndexType(), 0)); 2893 return builder.create<LLVM::GEPOp>( 2894 loc, LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)), globalPtr, 2895 ValueRange{cst0, cst0}); 2896 } 2897 2898 bool mlir::LLVM::satisfiesLLVMModule(Operation *op) { 2899 return op->hasTrait<OpTrait::SymbolTable>() && 2900 op->hasTrait<OpTrait::IsIsolatedFromAbove>(); 2901 } 2902 2903 void FMFAttr::print(AsmPrinter &printer) const { 2904 printer << "<"; 2905 printer << stringifyFastmathFlags(this->getFlags()); 2906 printer << ">"; 2907 } 2908 2909 Attribute FMFAttr::parse(AsmParser &parser, Type type) { 2910 if (failed(parser.parseLess())) 2911 return {}; 2912 2913 FastmathFlags flags = {}; 2914 if (failed(parser.parseOptionalGreater())) { 2915 auto parseFlags = [&]() -> ParseResult { 2916 StringRef elemName; 2917 if (failed(parser.parseKeyword(&elemName))) 2918 return failure(); 2919 2920 auto elem = symbolizeFastmathFlags(elemName); 2921 if (!elem) 2922 return parser.emitError(parser.getNameLoc(), "Unknown fastmath flag: ") 2923 << elemName; 2924 2925 flags = flags | *elem; 2926 return success(); 2927 }; 2928 if (failed(parser.parseCommaSeparatedList(parseFlags)) || 2929 parser.parseGreater()) 2930 return {}; 2931 } 2932 2933 return FMFAttr::get(parser.getContext(), flags); 2934 } 2935 2936 void LinkageAttr::print(AsmPrinter &printer) const { 2937 printer << "<"; 2938 if (static_cast<uint64_t>(getLinkage()) <= getMaxEnumValForLinkage()) 2939 printer << stringifyEnum(getLinkage()); 2940 else 2941 printer << static_cast<uint64_t>(getLinkage()); 2942 printer << ">"; 2943 } 2944 2945 Attribute LinkageAttr::parse(AsmParser &parser, Type type) { 2946 StringRef elemName; 2947 if (parser.parseLess() || parser.parseKeyword(&elemName) || 2948 parser.parseGreater()) 2949 return {}; 2950 auto elem = 
linkage::symbolizeLinkage(elemName); 2951 if (!elem) { 2952 parser.emitError(parser.getNameLoc(), "Unknown linkage: ") << elemName; 2953 return {}; 2954 } 2955 Linkage linkage = *elem; 2956 return LinkageAttr::get(parser.getContext(), linkage); 2957 } 2958 2959 void CConvAttr::print(AsmPrinter &printer) const { 2960 printer << "<"; 2961 if (static_cast<uint64_t>(getCallingConv()) <= cconv::getMaxEnumValForCConv()) 2962 printer << stringifyEnum(getCallingConv()); 2963 else 2964 printer << "INVALID_cc_" << static_cast<uint64_t>(getCallingConv()); 2965 printer << ">"; 2966 } 2967 2968 Attribute CConvAttr::parse(AsmParser &parser, Type type) { 2969 StringRef convName; 2970 2971 if (parser.parseLess() || parser.parseKeyword(&convName) || 2972 parser.parseGreater()) 2973 return {}; 2974 auto cconv = cconv::symbolizeCConv(convName); 2975 if (!cconv) { 2976 parser.emitError(parser.getNameLoc(), "unknown calling convention: ") 2977 << convName; 2978 return {}; 2979 } 2980 CConv cconvVal = *cconv; 2981 return CConvAttr::get(parser.getContext(), cconvVal); 2982 } 2983 2984 LoopOptionsAttrBuilder::LoopOptionsAttrBuilder(LoopOptionsAttr attr) 2985 : options(attr.getOptions().begin(), attr.getOptions().end()) {} 2986 2987 template <typename T> 2988 LoopOptionsAttrBuilder &LoopOptionsAttrBuilder::setOption(LoopOptionCase tag, 2989 Optional<T> value) { 2990 auto option = llvm::find_if( 2991 options, [tag](auto option) { return option.first == tag; }); 2992 if (option != options.end()) { 2993 if (value.hasValue()) 2994 option->second = *value; 2995 else 2996 options.erase(option); 2997 } else { 2998 options.push_back(LoopOptionsAttr::OptionValuePair(tag, *value)); 2999 } 3000 return *this; 3001 } 3002 3003 LoopOptionsAttrBuilder & 3004 LoopOptionsAttrBuilder::setDisableLICM(Optional<bool> value) { 3005 return setOption(LoopOptionCase::disable_licm, value); 3006 } 3007 3008 /// Set the `interleave_count` option to the provided value. 
If no value 3009 /// is provided the option is deleted. 3010 LoopOptionsAttrBuilder & 3011 LoopOptionsAttrBuilder::setInterleaveCount(Optional<uint64_t> count) { 3012 return setOption(LoopOptionCase::interleave_count, count); 3013 } 3014 3015 /// Set the `disable_unroll` option to the provided value. If no value 3016 /// is provided the option is deleted. 3017 LoopOptionsAttrBuilder & 3018 LoopOptionsAttrBuilder::setDisableUnroll(Optional<bool> value) { 3019 return setOption(LoopOptionCase::disable_unroll, value); 3020 } 3021 3022 /// Set the `disable_pipeline` option to the provided value. If no value 3023 /// is provided the option is deleted. 3024 LoopOptionsAttrBuilder & 3025 LoopOptionsAttrBuilder::setDisablePipeline(Optional<bool> value) { 3026 return setOption(LoopOptionCase::disable_pipeline, value); 3027 } 3028 3029 /// Set the `pipeline_initiation_interval` option to the provided value. 3030 /// If no value is provided the option is deleted. 3031 LoopOptionsAttrBuilder &LoopOptionsAttrBuilder::setPipelineInitiationInterval( 3032 Optional<uint64_t> count) { 3033 return setOption(LoopOptionCase::pipeline_initiation_interval, count); 3034 } 3035 3036 template <typename T> 3037 static Optional<T> 3038 getOption(ArrayRef<std::pair<LoopOptionCase, int64_t>> options, 3039 LoopOptionCase option) { 3040 auto it = 3041 lower_bound(options, option, [](auto optionPair, LoopOptionCase option) { 3042 return optionPair.first < option; 3043 }); 3044 if (it == options.end()) 3045 return {}; 3046 return static_cast<T>(it->second); 3047 } 3048 3049 Optional<bool> LoopOptionsAttr::disableUnroll() { 3050 return getOption<bool>(getOptions(), LoopOptionCase::disable_unroll); 3051 } 3052 3053 Optional<bool> LoopOptionsAttr::disableLICM() { 3054 return getOption<bool>(getOptions(), LoopOptionCase::disable_licm); 3055 } 3056 3057 Optional<int64_t> LoopOptionsAttr::interleaveCount() { 3058 return getOption<int64_t>(getOptions(), LoopOptionCase::interleave_count); 3059 } 3060 3061 
/// Build the LoopOptions Attribute from a sorted array of individual options. 3062 LoopOptionsAttr LoopOptionsAttr::get( 3063 MLIRContext *context, 3064 ArrayRef<std::pair<LoopOptionCase, int64_t>> sortedOptions) { 3065 assert(llvm::is_sorted(sortedOptions, llvm::less_first()) && 3066 "LoopOptionsAttr ctor expects a sorted options array"); 3067 return Base::get(context, sortedOptions); 3068 } 3069 3070 /// Build the LoopOptions Attribute from a sorted array of individual options. 3071 LoopOptionsAttr LoopOptionsAttr::get(MLIRContext *context, 3072 LoopOptionsAttrBuilder &optionBuilders) { 3073 llvm::sort(optionBuilders.options, llvm::less_first()); 3074 return Base::get(context, optionBuilders.options); 3075 } 3076 3077 void LoopOptionsAttr::print(AsmPrinter &printer) const { 3078 printer << "<"; 3079 llvm::interleaveComma(getOptions(), printer, [&](auto option) { 3080 printer << stringifyEnum(option.first) << " = "; 3081 switch (option.first) { 3082 case LoopOptionCase::disable_licm: 3083 case LoopOptionCase::disable_unroll: 3084 case LoopOptionCase::disable_pipeline: 3085 printer << (option.second ? 
"true" : "false"); 3086 break; 3087 case LoopOptionCase::interleave_count: 3088 case LoopOptionCase::pipeline_initiation_interval: 3089 printer << option.second; 3090 break; 3091 } 3092 }); 3093 printer << ">"; 3094 } 3095 3096 Attribute LoopOptionsAttr::parse(AsmParser &parser, Type type) { 3097 if (failed(parser.parseLess())) 3098 return {}; 3099 3100 SmallVector<std::pair<LoopOptionCase, int64_t>> options; 3101 llvm::SmallDenseSet<LoopOptionCase> seenOptions; 3102 auto parseLoopOptions = [&]() -> ParseResult { 3103 StringRef optionName; 3104 if (parser.parseKeyword(&optionName)) 3105 return failure(); 3106 3107 auto option = symbolizeLoopOptionCase(optionName); 3108 if (!option) 3109 return parser.emitError(parser.getNameLoc(), "unknown loop option: ") 3110 << optionName; 3111 if (!seenOptions.insert(*option).second) 3112 return parser.emitError(parser.getNameLoc(), "loop option present twice"); 3113 if (failed(parser.parseEqual())) 3114 return failure(); 3115 3116 int64_t value; 3117 switch (*option) { 3118 case LoopOptionCase::disable_licm: 3119 case LoopOptionCase::disable_unroll: 3120 case LoopOptionCase::disable_pipeline: 3121 if (succeeded(parser.parseOptionalKeyword("true"))) 3122 value = 1; 3123 else if (succeeded(parser.parseOptionalKeyword("false"))) 3124 value = 0; 3125 else { 3126 return parser.emitError(parser.getNameLoc(), 3127 "expected boolean value 'true' or 'false'"); 3128 } 3129 break; 3130 case LoopOptionCase::interleave_count: 3131 case LoopOptionCase::pipeline_initiation_interval: 3132 if (failed(parser.parseInteger(value))) 3133 return parser.emitError(parser.getNameLoc(), "expected integer value"); 3134 break; 3135 } 3136 options.push_back(std::make_pair(*option, value)); 3137 return success(); 3138 }; 3139 if (parser.parseCommaSeparatedList(parseLoopOptions) || parser.parseGreater()) 3140 return {}; 3141 3142 llvm::sort(options, llvm::less_first()); 3143 return get(parser.getContext(), options); 3144 } 3145