1 //===- LLVMDialect.cpp - LLVM IR Ops and Dialect registration -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the types and operation details for the LLVM IR dialect in 10 // MLIR, and the LLVM IR dialect. It also registers the dialect. 11 // 12 //===----------------------------------------------------------------------===// 13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 14 #include "TypeDetail.h" 15 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 16 #include "mlir/IR/Builders.h" 17 #include "mlir/IR/BuiltinOps.h" 18 #include "mlir/IR/BuiltinTypes.h" 19 #include "mlir/IR/DialectImplementation.h" 20 #include "mlir/IR/FunctionImplementation.h" 21 #include "mlir/IR/MLIRContext.h" 22 #include "mlir/IR/Matchers.h" 23 24 #include "llvm/ADT/StringSwitch.h" 25 #include "llvm/ADT/TypeSwitch.h" 26 #include "llvm/AsmParser/Parser.h" 27 #include "llvm/Bitcode/BitcodeReader.h" 28 #include "llvm/Bitcode/BitcodeWriter.h" 29 #include "llvm/IR/Attributes.h" 30 #include "llvm/IR/Function.h" 31 #include "llvm/IR/Type.h" 32 #include "llvm/Support/Error.h" 33 #include "llvm/Support/Mutex.h" 34 #include "llvm/Support/SourceMgr.h" 35 36 #include <numeric> 37 38 using namespace mlir; 39 using namespace mlir::LLVM; 40 using mlir::LLVM::cconv::getMaxEnumValForCConv; 41 using mlir::LLVM::linkage::getMaxEnumValForLinkage; 42 43 #include "mlir/Dialect/LLVMIR/LLVMOpsDialect.cpp.inc" 44 45 static constexpr const char kVolatileAttrName[] = "volatile_"; 46 static constexpr const char kNonTemporalAttrName[] = "nontemporal"; 47 static constexpr const char kElemTypeAttrName[] = "elem_type"; 48 49 #include "mlir/Dialect/LLVMIR/LLVMOpsEnums.cpp.inc" 50 #include "mlir/Dialect/LLVMIR/LLVMOpsInterfaces.cpp.inc" 51 #define GET_ATTRDEF_CLASSES 52 #include "mlir/Dialect/LLVMIR/LLVMOpsAttrDefs.cpp.inc" 53 54 static auto processFMFAttr(ArrayRef<NamedAttribute> attrs) { 55 SmallVector<NamedAttribute, 8> filteredAttrs( 56 llvm::make_filter_range(attrs, [&](NamedAttribute attr) { 57 if (attr.getName() == "fastmathFlags") { 58 auto defAttr = FMFAttr::get(attr.getValue().getContext(), {}); 59 return defAttr != attr.getValue(); 60 } 61 return true; 62 })); 63 return filteredAttrs; 64 } 65 66 static ParseResult parseLLVMOpAttrs(OpAsmParser &parser, 67 NamedAttrList &result) { 68 return parser.parseOptionalAttrDict(result); 69 } 70 71 static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op, 72 DictionaryAttr attrs) { 73 printer.printOptionalAttrDict(processFMFAttr(attrs.getValue())); 74 } 75 76 /// Verifies `symbol`'s use in `op` to ensure the symbol is a valid and 77 /// fully defined llvm.func. 78 static LogicalResult verifySymbolAttrUse(FlatSymbolRefAttr symbol, 79 Operation *op, 80 SymbolTableCollection &symbolTable) { 81 StringRef name = symbol.getValue(); 82 auto func = 83 symbolTable.lookupNearestSymbolFrom<LLVMFuncOp>(op, symbol.getAttr()); 84 if (!func) 85 return op->emitOpError("'") 86 << name << "' does not reference a valid LLVM function"; 87 if (func.isExternal()) 88 return op->emitOpError("'") << name << "' does not have a definition"; 89 return success(); 90 } 91 92 //===----------------------------------------------------------------------===// 93 // Printing, parsing and builder for LLVM::CmpOp. 94 //===----------------------------------------------------------------------===// 95 96 void ICmpOp::build(OpBuilder &builder, OperationState &result, 97 ICmpPredicate predicate, Value lhs, Value rhs) { 98 auto boolType = IntegerType::get(lhs.getType().getContext(), 1); 99 if (LLVM::isCompatibleVectorType(lhs.getType()) || 100 LLVM::isCompatibleVectorType(rhs.getType())) { 101 int64_t numLHSElements = 1, numRHSElements = 1; 102 if (LLVM::isCompatibleVectorType(lhs.getType())) 103 numLHSElements = 104 LLVM::getVectorNumElements(lhs.getType()).getFixedValue(); 105 if (LLVM::isCompatibleVectorType(rhs.getType())) 106 numRHSElements = 107 LLVM::getVectorNumElements(rhs.getType()).getFixedValue(); 108 build(builder, result, 109 VectorType::get({std::max(numLHSElements, numRHSElements)}, boolType), 110 predicate, lhs, rhs); 111 } else { 112 build(builder, result, boolType, predicate, lhs, rhs); 113 } 114 } 115 116 void ICmpOp::print(OpAsmPrinter &p) { 117 p << " \"" << stringifyICmpPredicate(getPredicate()) << "\" " << getOperand(0) 118 << ", " << getOperand(1); 119 p.printOptionalAttrDict((*this)->getAttrs(), {"predicate"}); 120 p << " : " << getLhs().getType(); 121 } 122 123 void FCmpOp::print(OpAsmPrinter &p) { 124 p << " \"" << stringifyFCmpPredicate(getPredicate()) << "\" " << getOperand(0) 125 << ", " << getOperand(1); 126 p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()), {"predicate"}); 127 p << " : " << getLhs().getType(); 128 } 129 130 // <operation> ::= `llvm.icmp` string-literal ssa-use `,` ssa-use 131 // attribute-dict? `:` type 132 // <operation> ::= `llvm.fcmp` string-literal ssa-use `,` ssa-use 133 // attribute-dict? `:` type 134 template <typename CmpPredicateType> 135 static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) { 136 Builder &builder = parser.getBuilder(); 137 138 StringAttr predicateAttr; 139 OpAsmParser::UnresolvedOperand lhs, rhs; 140 Type type; 141 SMLoc predicateLoc, trailingTypeLoc; 142 if (parser.getCurrentLocation(&predicateLoc) || 143 parser.parseAttribute(predicateAttr, "predicate", result.attributes) || 144 parser.parseOperand(lhs) || parser.parseComma() || 145 parser.parseOperand(rhs) || 146 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 147 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) || 148 parser.resolveOperand(lhs, type, result.operands) || 149 parser.resolveOperand(rhs, type, result.operands)) 150 return failure(); 151 152 // Replace the string attribute `predicate` with an integer attribute. 153 int64_t predicateValue = 0; 154 if (std::is_same<CmpPredicateType, ICmpPredicate>()) { 155 Optional<ICmpPredicate> predicate = 156 symbolizeICmpPredicate(predicateAttr.getValue()); 157 if (!predicate) 158 return parser.emitError(predicateLoc) 159 << "'" << predicateAttr.getValue() 160 << "' is an incorrect value of the 'predicate' attribute"; 161 predicateValue = static_cast<int64_t>(*predicate); 162 } else { 163 Optional<FCmpPredicate> predicate = 164 symbolizeFCmpPredicate(predicateAttr.getValue()); 165 if (!predicate) 166 return parser.emitError(predicateLoc) 167 << "'" << predicateAttr.getValue() 168 << "' is an incorrect value of the 'predicate' attribute"; 169 predicateValue = static_cast<int64_t>(*predicate); 170 } 171 172 result.attributes.set("predicate", 173 parser.getBuilder().getI64IntegerAttr(predicateValue)); 174 175 // The result type is either i1 or a vector type <? x i1> if the inputs are 176 // vectors. 177 Type resultType = IntegerType::get(builder.getContext(), 1); 178 if (!isCompatibleType(type)) 179 return parser.emitError(trailingTypeLoc, 180 "expected LLVM dialect-compatible type"); 181 if (LLVM::isCompatibleVectorType(type)) { 182 if (LLVM::isScalableVectorType(type)) { 183 resultType = LLVM::getVectorType( 184 resultType, LLVM::getVectorNumElements(type).getKnownMinValue(), 185 /*isScalable=*/true); 186 } else { 187 resultType = LLVM::getVectorType( 188 resultType, LLVM::getVectorNumElements(type).getFixedValue(), 189 /*isScalable=*/false); 190 } 191 } 192 193 result.addTypes({resultType}); 194 return success(); 195 } 196 197 ParseResult ICmpOp::parse(OpAsmParser &parser, OperationState &result) { 198 return parseCmpOp<ICmpPredicate>(parser, result); 199 } 200 201 ParseResult FCmpOp::parse(OpAsmParser &parser, OperationState &result) { 202 return parseCmpOp<FCmpPredicate>(parser, result); 203 } 204 205 //===----------------------------------------------------------------------===// 206 // Printing, parsing and verification for LLVM::AllocaOp. 207 //===----------------------------------------------------------------------===// 208 209 void AllocaOp::print(OpAsmPrinter &p) { 210 Type elemTy = getType().cast<LLVM::LLVMPointerType>().getElementType(); 211 if (!elemTy) 212 elemTy = *getElemType(); 213 214 auto funcTy = 215 FunctionType::get(getContext(), {getArraySize().getType()}, {getType()}); 216 217 p << ' ' << getArraySize() << " x " << elemTy; 218 if (getAlignment() && *getAlignment() != 0) 219 p.printOptionalAttrDict((*this)->getAttrs(), {kElemTypeAttrName}); 220 else 221 p.printOptionalAttrDict((*this)->getAttrs(), 222 {"alignment", kElemTypeAttrName}); 223 p << " : " << funcTy; 224 } 225 226 // <operation> ::= `llvm.alloca` ssa-use `x` type attribute-dict? 227 // `:` type `,` type 228 ParseResult AllocaOp::parse(OpAsmParser &parser, OperationState &result) { 229 OpAsmParser::UnresolvedOperand arraySize; 230 Type type, elemType; 231 SMLoc trailingTypeLoc; 232 if (parser.parseOperand(arraySize) || parser.parseKeyword("x") || 233 parser.parseType(elemType) || 234 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 235 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type)) 236 return failure(); 237 238 Optional<NamedAttribute> alignmentAttr = 239 result.attributes.getNamed("alignment"); 240 if (alignmentAttr.has_value()) { 241 auto alignmentInt = 242 alignmentAttr.value().getValue().dyn_cast<IntegerAttr>(); 243 if (!alignmentInt) 244 return parser.emitError(parser.getNameLoc(), 245 "expected integer alignment"); 246 if (alignmentInt.getValue().isNullValue()) 247 result.attributes.erase("alignment"); 248 } 249 250 // Extract the result type from the trailing function type. 251 auto funcType = type.dyn_cast<FunctionType>(); 252 if (!funcType || funcType.getNumInputs() != 1 || 253 funcType.getNumResults() != 1) 254 return parser.emitError( 255 trailingTypeLoc, 256 "expected trailing function type with one argument and one result"); 257 258 if (parser.resolveOperand(arraySize, funcType.getInput(0), result.operands)) 259 return failure(); 260 261 Type resultType = funcType.getResult(0); 262 if (auto ptrResultType = resultType.dyn_cast<LLVMPointerType>()) { 263 if (ptrResultType.isOpaque()) 264 result.addAttribute(kElemTypeAttrName, TypeAttr::get(elemType)); 265 } 266 267 result.addTypes({funcType.getResult(0)}); 268 return success(); 269 } 270 271 /// Checks that the elemental type is present in either the pointer type or 272 /// the attribute, but not both. 273 static LogicalResult verifyOpaquePtr(Operation *op, LLVMPointerType ptrType, 274 Optional<Type> ptrElementType) { 275 if (ptrType.isOpaque() && !ptrElementType.has_value()) { 276 return op->emitOpError() << "expected '" << kElemTypeAttrName 277 << "' attribute if opaque pointer type is used"; 278 } 279 if (!ptrType.isOpaque() && ptrElementType.has_value()) { 280 return op->emitOpError() 281 << "unexpected '" << kElemTypeAttrName 282 << "' attribute when non-opaque pointer type is used"; 283 } 284 return success(); 285 } 286 287 LogicalResult AllocaOp::verify() { 288 return verifyOpaquePtr(getOperation(), getType().cast<LLVMPointerType>(), 289 getElemType()); 290 } 291 292 //===----------------------------------------------------------------------===// 293 // LLVM::BrOp 294 //===----------------------------------------------------------------------===// 295 296 SuccessorOperands BrOp::getSuccessorOperands(unsigned index) { 297 assert(index == 0 && "invalid successor index"); 298 return SuccessorOperands(getDestOperandsMutable()); 299 } 300 301 //===----------------------------------------------------------------------===// 302 // LLVM::CondBrOp 303 //===----------------------------------------------------------------------===// 304 305 SuccessorOperands CondBrOp::getSuccessorOperands(unsigned index) { 306 assert(index < getNumSuccessors() && "invalid successor index"); 307 return SuccessorOperands(index == 0 ? getTrueDestOperandsMutable() 308 : getFalseDestOperandsMutable()); 309 } 310 311 //===----------------------------------------------------------------------===// 312 // LLVM::SwitchOp 313 //===----------------------------------------------------------------------===// 314 315 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value, 316 Block *defaultDestination, ValueRange defaultOperands, 317 ArrayRef<int32_t> caseValues, BlockRange caseDestinations, 318 ArrayRef<ValueRange> caseOperands, 319 ArrayRef<int32_t> branchWeights) { 320 ElementsAttr caseValuesAttr; 321 if (!caseValues.empty()) 322 caseValuesAttr = builder.getI32VectorAttr(caseValues); 323 324 ElementsAttr weightsAttr; 325 if (!branchWeights.empty()) 326 weightsAttr = builder.getI32VectorAttr(llvm::to_vector<4>(branchWeights)); 327 328 build(builder, result, value, defaultOperands, caseOperands, caseValuesAttr, 329 weightsAttr, defaultDestination, caseDestinations); 330 } 331 332 /// <cases> ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)? 333 /// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )? 334 static ParseResult parseSwitchOpCases( 335 OpAsmParser &parser, Type flagType, ElementsAttr &caseValues, 336 SmallVectorImpl<Block *> &caseDestinations, 337 SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> &caseOperands, 338 SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) { 339 SmallVector<APInt> values; 340 unsigned bitWidth = flagType.getIntOrFloatBitWidth(); 341 do { 342 int64_t value = 0; 343 OptionalParseResult integerParseResult = parser.parseOptionalInteger(value); 344 if (values.empty() && !integerParseResult.hasValue()) 345 return success(); 346 347 if (!integerParseResult.hasValue() || integerParseResult.getValue()) 348 return failure(); 349 values.push_back(APInt(bitWidth, value)); 350 351 Block *destination; 352 SmallVector<OpAsmParser::UnresolvedOperand> operands; 353 SmallVector<Type> operandTypes; 354 if (parser.parseColon() || parser.parseSuccessor(destination)) 355 return failure(); 356 if (!parser.parseOptionalLParen()) { 357 if (parser.parseOperandList(operands, OpAsmParser::Delimiter::None, 358 /*allowResultNumber=*/false) || 359 parser.parseColonTypeList(operandTypes) || parser.parseRParen()) 360 return failure(); 361 } 362 caseDestinations.push_back(destination); 363 caseOperands.emplace_back(operands); 364 caseOperandTypes.emplace_back(operandTypes); 365 } while (!parser.parseOptionalComma()); 366 367 ShapedType caseValueType = 368 VectorType::get(static_cast<int64_t>(values.size()), flagType); 369 caseValues = DenseIntElementsAttr::get(caseValueType, values); 370 return success(); 371 } 372 373 static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type flagType, 374 ElementsAttr caseValues, 375 SuccessorRange caseDestinations, 376 OperandRangeRange caseOperands, 377 const TypeRangeRange &caseOperandTypes) { 378 if (!caseValues) 379 return; 380 381 size_t index = 0; 382 llvm::interleave( 383 llvm::zip(caseValues.cast<DenseIntElementsAttr>(), caseDestinations), 384 [&](auto i) { 385 p << " "; 386 p << std::get<0>(i).getLimitedValue(); 387 p << ": "; 388 p.printSuccessorAndUseList(std::get<1>(i), caseOperands[index++]); 389 }, 390 [&] { 391 p << ','; 392 p.printNewline(); 393 }); 394 p.printNewline(); 395 } 396 397 LogicalResult SwitchOp::verify() { 398 if ((!getCaseValues() && !getCaseDestinations().empty()) || 399 (getCaseValues() && 400 getCaseValues()->size() != 401 static_cast<int64_t>(getCaseDestinations().size()))) 402 return emitOpError("expects number of case values to match number of " 403 "case destinations"); 404 if (getBranchWeights() && getBranchWeights()->size() != getNumSuccessors()) 405 return emitError("expects number of branch weights to match number of " 406 "successors: ") 407 << getBranchWeights()->size() << " vs " << getNumSuccessors(); 408 return success(); 409 } 410 411 SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) { 412 assert(index < getNumSuccessors() && "invalid successor index"); 413 return SuccessorOperands(index == 0 ? getDefaultOperandsMutable() 414 : getCaseOperandsMutable(index - 1)); 415 } 416 417 //===----------------------------------------------------------------------===// 418 // Code for LLVM::GEPOp. 419 //===----------------------------------------------------------------------===// 420 421 constexpr int GEPOp::kDynamicIndex; 422 423 namespace { 424 /// Base class for llvm::Error related to GEP index. 425 class GEPIndexError : public llvm::ErrorInfo<GEPIndexError> { 426 protected: 427 unsigned indexPos; 428 429 public: 430 static char ID; 431 432 std::error_code convertToErrorCode() const override { 433 return llvm::inconvertibleErrorCode(); 434 } 435 436 explicit GEPIndexError(unsigned pos) : indexPos(pos) {} 437 }; 438 439 /// llvm::Error for out-of-bound GEP index. 440 struct GEPIndexOutOfBoundError 441 : public llvm::ErrorInfo<GEPIndexOutOfBoundError, GEPIndexError> { 442 static char ID; 443 444 using ErrorInfo::ErrorInfo; 445 446 void log(llvm::raw_ostream &os) const override { 447 os << "index " << indexPos << " indexing a struct is out of bounds"; 448 } 449 }; 450 451 /// llvm::Error for non-static GEP index indexing a struct. 452 struct GEPStaticIndexError 453 : public llvm::ErrorInfo<GEPStaticIndexError, GEPIndexError> { 454 static char ID; 455 456 using ErrorInfo::ErrorInfo; 457 458 void log(llvm::raw_ostream &os) const override { 459 os << "expected index " << indexPos << " indexing a struct " 460 << "to be constant"; 461 } 462 }; 463 } // end anonymous namespace 464 465 char GEPIndexError::ID = 0; 466 char GEPIndexOutOfBoundError::ID = 0; 467 char GEPStaticIndexError::ID = 0; 468 469 /// For the given `structIndices` and `indices`, check if they're complied 470 /// with `baseGEPType`, especially check against LLVMStructTypes nested within, 471 /// and refine/promote struct index from `indices` to `updatedStructIndices` 472 /// if the latter argument is not null. 473 static llvm::Error 474 recordStructIndices(Type baseGEPType, unsigned indexPos, 475 ArrayRef<int32_t> structIndices, ValueRange indices, 476 SmallVectorImpl<int32_t> *updatedStructIndices, 477 SmallVectorImpl<Value> *remainingIndices) { 478 if (indexPos >= structIndices.size()) 479 // Stop searching 480 return llvm::Error::success(); 481 482 int32_t gepIndex = structIndices[indexPos]; 483 bool isStaticIndex = gepIndex != GEPOp::kDynamicIndex; 484 485 unsigned dynamicIndexPos = indexPos; 486 if (!isStaticIndex) 487 dynamicIndexPos = llvm::count(structIndices.take_front(indexPos + 1), 488 LLVM::GEPOp::kDynamicIndex) - 489 1; 490 491 return llvm::TypeSwitch<Type, llvm::Error>(baseGEPType) 492 .Case<LLVMStructType>([&](LLVMStructType structType) -> llvm::Error { 493 // We don't always want to refine the index (e.g. when performing 494 // verification), so we only refine when updatedStructIndices is not 495 // null. 496 if (!isStaticIndex && updatedStructIndices) { 497 // Try to refine. 498 APInt staticIndexValue; 499 isStaticIndex = matchPattern(indices[dynamicIndexPos], 500 m_ConstantInt(&staticIndexValue)); 501 if (isStaticIndex) { 502 assert(staticIndexValue.getBitWidth() <= 64 && 503 llvm::isInt<32>(staticIndexValue.getLimitedValue()) && 504 "struct index can't fit within int32_t"); 505 gepIndex = static_cast<int32_t>(staticIndexValue.getSExtValue()); 506 } 507 } 508 if (!isStaticIndex) 509 return llvm::make_error<GEPStaticIndexError>(indexPos); 510 511 ArrayRef<Type> elementTypes = structType.getBody(); 512 if (gepIndex < 0 || 513 static_cast<size_t>(gepIndex) >= elementTypes.size()) 514 return llvm::make_error<GEPIndexOutOfBoundError>(indexPos); 515 516 if (updatedStructIndices) 517 (*updatedStructIndices)[indexPos] = gepIndex; 518 519 // Instead of recusively going into every children types, we only 520 // dive into the one indexed by gepIndex. 521 return recordStructIndices(elementTypes[gepIndex], indexPos + 1, 522 structIndices, indices, updatedStructIndices, 523 remainingIndices); 524 }) 525 .Case<VectorType, LLVMScalableVectorType, LLVMFixedVectorType, 526 LLVMArrayType>([&](auto containerType) -> llvm::Error { 527 // Currently we don't refine non-struct index even if it's static. 528 if (remainingIndices) 529 remainingIndices->push_back(indices[dynamicIndexPos]); 530 return recordStructIndices(containerType.getElementType(), indexPos + 1, 531 structIndices, indices, updatedStructIndices, 532 remainingIndices); 533 }) 534 .Default( 535 [](auto otherType) -> llvm::Error { return llvm::Error::success(); }); 536 } 537 538 /// Driver function around `recordStructIndices`. Note that we always check 539 /// from the second GEP index since the first one is always dynamic. 540 static llvm::Error 541 findStructIndices(Type baseGEPType, ArrayRef<int32_t> structIndices, 542 ValueRange indices, 543 SmallVectorImpl<int32_t> *updatedStructIndices = nullptr, 544 SmallVectorImpl<Value> *remainingIndices = nullptr) { 545 if (remainingIndices) 546 // The first GEP index is always dynamic. 547 remainingIndices->push_back(indices[0]); 548 return recordStructIndices(baseGEPType, /*indexPos=*/1, structIndices, 549 indices, updatedStructIndices, remainingIndices); 550 } 551 552 void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType, 553 Value basePtr, ValueRange operands, 554 ArrayRef<NamedAttribute> attributes) { 555 build(builder, result, resultType, basePtr, operands, 556 SmallVector<int32_t>(operands.size(), kDynamicIndex), attributes); 557 } 558 559 /// Returns the elemental type of any LLVM-compatible vector type or self. 560 static Type extractVectorElementType(Type type) { 561 if (auto vectorType = type.dyn_cast<VectorType>()) 562 return vectorType.getElementType(); 563 if (auto scalableVectorType = type.dyn_cast<LLVMScalableVectorType>()) 564 return scalableVectorType.getElementType(); 565 if (auto fixedVectorType = type.dyn_cast<LLVMFixedVectorType>()) 566 return fixedVectorType.getElementType(); 567 return type; 568 } 569 570 void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType, 571 Value basePtr, ValueRange indices, 572 ArrayRef<int32_t> structIndices, 573 ArrayRef<NamedAttribute> attributes) { 574 auto ptrType = 575 extractVectorElementType(basePtr.getType()).cast<LLVMPointerType>(); 576 assert(!ptrType.isOpaque() && 577 "expected non-opaque pointer, provide elementType explicitly when " 578 "opaque pointers are used"); 579 build(builder, result, resultType, ptrType.getElementType(), basePtr, indices, 580 structIndices, attributes); 581 } 582 583 void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType, 584 Type elementType, Value basePtr, ValueRange indices, 585 ArrayRef<int32_t> structIndices, 586 ArrayRef<NamedAttribute> attributes) { 587 SmallVector<Value> remainingIndices; 588 SmallVector<int32_t> updatedStructIndices(structIndices.begin(), 589 structIndices.end()); 590 if (llvm::Error err = 591 findStructIndices(elementType, structIndices, indices, 592 &updatedStructIndices, &remainingIndices)) 593 llvm::report_fatal_error(StringRef(llvm::toString(std::move(err)))); 594 595 assert(remainingIndices.size() == static_cast<size_t>(llvm::count( 596 updatedStructIndices, kDynamicIndex)) && 597 "expected as many index operands as dynamic index attr elements"); 598 599 result.addTypes(resultType); 600 result.addAttributes(attributes); 601 result.addAttribute("structIndices", 602 builder.getI32TensorAttr(updatedStructIndices)); 603 if (extractVectorElementType(basePtr.getType()) 604 .cast<LLVMPointerType>() 605 .isOpaque()) 606 result.addAttribute(kElemTypeAttrName, TypeAttr::get(elementType)); 607 result.addOperands(basePtr); 608 result.addOperands(remainingIndices); 609 } 610 611 static ParseResult 612 parseGEPIndices(OpAsmParser &parser, 613 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &indices, 614 DenseIntElementsAttr &structIndices) { 615 SmallVector<int32_t> constantIndices; 616 617 auto idxParser = [&]() -> ParseResult { 618 int32_t constantIndex; 619 OptionalParseResult parsedInteger = 620 parser.parseOptionalInteger(constantIndex); 621 if (parsedInteger.hasValue()) { 622 if (failed(parsedInteger.getValue())) 623 return failure(); 624 constantIndices.push_back(constantIndex); 625 return success(); 626 } 627 628 constantIndices.push_back(LLVM::GEPOp::kDynamicIndex); 629 return parser.parseOperand(indices.emplace_back()); 630 }; 631 if (parser.parseCommaSeparatedList(idxParser)) 632 return failure(); 633 634 structIndices = parser.getBuilder().getI32TensorAttr(constantIndices); 635 return success(); 636 } 637 638 static void printGEPIndices(OpAsmPrinter &printer, LLVM::GEPOp gepOp, 639 OperandRange indices, 640 DenseIntElementsAttr structIndices) { 641 unsigned operandIdx = 0; 642 llvm::interleaveComma(structIndices.getValues<int32_t>(), printer, 643 [&](int32_t cst) { 644 if (cst == LLVM::GEPOp::kDynamicIndex) 645 printer.printOperand(indices[operandIdx++]); 646 else 647 printer << cst; 648 }); 649 } 650 651 LogicalResult LLVM::GEPOp::verify() { 652 if (failed(verifyOpaquePtr( 653 getOperation(), 654 extractVectorElementType(getType()).cast<LLVMPointerType>(), 655 getElemType()))) 656 return failure(); 657 658 auto structIndexRange = getStructIndices().getValues<int32_t>(); 659 // structIndexRange is a kind of iterator, which cannot be converted 660 // to ArrayRef directly. 661 SmallVector<int32_t> structIndices(structIndexRange.size()); 662 for (unsigned i : llvm::seq<unsigned>(0, structIndexRange.size())) 663 structIndices[i] = structIndexRange[i]; 664 if (llvm::Error err = findStructIndices(getSourceElementType(), structIndices, 665 getIndices())) 666 return emitOpError() << llvm::toString(std::move(err)); 667 668 return success(); 669 } 670 671 Type LLVM::GEPOp::getSourceElementType() { 672 if (Optional<Type> elemType = getElemType()) 673 return *elemType; 674 675 return extractVectorElementType(getBase().getType()) 676 .cast<LLVMPointerType>() 677 .getElementType(); 678 } 679 680 //===----------------------------------------------------------------------===// 681 // Builder, printer and parser for for LLVM::LoadOp. 682 //===----------------------------------------------------------------------===// 683 684 LogicalResult verifySymbolAttribute( 685 Operation *op, StringRef attributeName, 686 llvm::function_ref<LogicalResult(Operation *, SymbolRefAttr)> 687 verifySymbolType) { 688 if (Attribute attribute = op->getAttr(attributeName)) { 689 // The attribute is already verified to be a symbol ref array attribute via 690 // a constraint in the operation definition. 691 for (SymbolRefAttr symbolRef : 692 attribute.cast<ArrayAttr>().getAsRange<SymbolRefAttr>()) { 693 StringAttr metadataName = symbolRef.getRootReference(); 694 StringAttr symbolName = symbolRef.getLeafReference(); 695 // We want @metadata::@symbol, not just @symbol 696 if (metadataName == symbolName) { 697 return op->emitOpError() << "expected '" << symbolRef 698 << "' to specify a fully qualified reference"; 699 } 700 auto metadataOp = SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>( 701 op->getParentOp(), metadataName); 702 if (!metadataOp) 703 return op->emitOpError() 704 << "expected '" << symbolRef << "' to reference a metadata op"; 705 Operation *symbolOp = 706 SymbolTable::lookupNearestSymbolFrom(metadataOp, symbolName); 707 if (!symbolOp) 708 return op->emitOpError() 709 << "expected '" << symbolRef << "' to be a valid reference"; 710 if (failed(verifySymbolType(symbolOp, symbolRef))) { 711 return failure(); 712 } 713 } 714 } 715 return success(); 716 } 717 718 // Verifies that metadata ops are wired up properly. 719 template <typename OpTy> 720 static LogicalResult verifyOpMetadata(Operation *op, StringRef attributeName) { 721 auto verifySymbolType = [op](Operation *symbolOp, 722 SymbolRefAttr symbolRef) -> LogicalResult { 723 if (!isa<OpTy>(symbolOp)) { 724 return op->emitOpError() 725 << "expected '" << symbolRef << "' to resolve to a " 726 << OpTy::getOperationName(); 727 } 728 return success(); 729 }; 730 731 return verifySymbolAttribute(op, attributeName, verifySymbolType); 732 } 733 734 static LogicalResult verifyMemoryOpMetadata(Operation *op) { 735 // access_groups 736 if (failed(verifyOpMetadata<LLVM::AccessGroupMetadataOp>( 737 op, LLVMDialect::getAccessGroupsAttrName()))) 738 return failure(); 739 740 // alias_scopes 741 if (failed(verifyOpMetadata<LLVM::AliasScopeMetadataOp>( 742 op, LLVMDialect::getAliasScopesAttrName()))) 743 return failure(); 744 745 // noalias_scopes 746 if (failed(verifyOpMetadata<LLVM::AliasScopeMetadataOp>( 747 op, LLVMDialect::getNoAliasScopesAttrName()))) 748 return failure(); 749 750 return success(); 751 } 752 753 LogicalResult LoadOp::verify() { return verifyMemoryOpMetadata(*this); } 754 755 void LoadOp::build(OpBuilder &builder, OperationState &result, Type t, 756 Value addr, unsigned alignment, bool isVolatile, 757 bool isNonTemporal) { 758 result.addOperands(addr); 759 result.addTypes(t); 760 if (isVolatile) 761 result.addAttribute(kVolatileAttrName, builder.getUnitAttr()); 762 if (isNonTemporal) 763 result.addAttribute(kNonTemporalAttrName, builder.getUnitAttr()); 764 if (alignment != 0) 765 result.addAttribute("alignment", builder.getI64IntegerAttr(alignment)); 766 } 767 768 void LoadOp::print(OpAsmPrinter &p) { 769 p << ' '; 770 if (getVolatile_()) 771 p << "volatile "; 772 p << getAddr(); 773 p.printOptionalAttrDict((*this)->getAttrs(), 774 {kVolatileAttrName, kElemTypeAttrName}); 775 p << " : " << getAddr().getType(); 776 if (getAddr().getType().cast<LLVMPointerType>().isOpaque()) 777 p << " -> " << getType(); 778 } 779 780 // Extract the pointee type from the LLVM pointer type wrapped in MLIR. Return 781 // the resulting type if any, null type if opaque pointers are used, and None 782 // if the given type is not the pointer type. 783 static Optional<Type> getLoadStoreElementType(OpAsmParser &parser, Type type, 784 SMLoc trailingTypeLoc) { 785 auto llvmTy = type.dyn_cast<LLVM::LLVMPointerType>(); 786 if (!llvmTy) { 787 parser.emitError(trailingTypeLoc, "expected LLVM pointer type"); 788 return llvm::None; 789 } 790 return llvmTy.getElementType(); 791 } 792 793 // <operation> ::= `llvm.load` `volatile` ssa-use attribute-dict? `:` type 794 // (`->` type)? 795 ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) { 796 OpAsmParser::UnresolvedOperand addr; 797 Type type; 798 SMLoc trailingTypeLoc; 799 800 if (succeeded(parser.parseOptionalKeyword("volatile"))) 801 result.addAttribute(kVolatileAttrName, parser.getBuilder().getUnitAttr()); 802 803 if (parser.parseOperand(addr) || 804 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 805 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) || 806 parser.resolveOperand(addr, type, result.operands)) 807 return failure(); 808 809 Optional<Type> elemTy = 810 getLoadStoreElementType(parser, type, trailingTypeLoc); 811 if (!elemTy) 812 return failure(); 813 if (*elemTy) { 814 result.addTypes(*elemTy); 815 return success(); 816 } 817 818 Type trailingType; 819 if (parser.parseArrow() || parser.parseType(trailingType)) 820 return failure(); 821 result.addTypes(trailingType); 822 return success(); 823 } 824 825 //===----------------------------------------------------------------------===// 826 // Builder, printer and parser for LLVM::StoreOp. 827 //===----------------------------------------------------------------------===// 828 829 LogicalResult StoreOp::verify() { return verifyMemoryOpMetadata(*this); } 830 831 void StoreOp::build(OpBuilder &builder, OperationState &result, Value value, 832 Value addr, unsigned alignment, bool isVolatile, 833 bool isNonTemporal) { 834 result.addOperands({value, addr}); 835 result.addTypes({}); 836 if (isVolatile) 837 result.addAttribute(kVolatileAttrName, builder.getUnitAttr()); 838 if (isNonTemporal) 839 result.addAttribute(kNonTemporalAttrName, builder.getUnitAttr()); 840 if (alignment != 0) 841 result.addAttribute("alignment", builder.getI64IntegerAttr(alignment)); 842 } 843 844 void StoreOp::print(OpAsmPrinter &p) { 845 p << ' '; 846 if (getVolatile_()) 847 p << "volatile "; 848 p << getValue() << ", " << getAddr(); 849 p.printOptionalAttrDict((*this)->getAttrs(), {kVolatileAttrName}); 850 p << " : "; 851 if (getAddr().getType().cast<LLVMPointerType>().isOpaque()) 852 p << getValue().getType() << ", "; 853 p << getAddr().getType(); 854 } 855 856 // <operation> ::= `llvm.store` `volatile` ssa-use `,` ssa-use 857 // attribute-dict? `:` type (`,` type)? 858 ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) { 859 OpAsmParser::UnresolvedOperand addr, value; 860 Type type; 861 SMLoc trailingTypeLoc; 862 863 if (succeeded(parser.parseOptionalKeyword("volatile"))) 864 result.addAttribute(kVolatileAttrName, parser.getBuilder().getUnitAttr()); 865 866 if (parser.parseOperand(value) || parser.parseComma() || 867 parser.parseOperand(addr) || 868 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 869 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type)) 870 return failure(); 871 872 Type operandType; 873 if (succeeded(parser.parseOptionalComma())) { 874 operandType = type; 875 if (parser.parseType(type)) 876 return failure(); 877 } else { 878 Optional<Type> maybeOperandType = 879 getLoadStoreElementType(parser, type, trailingTypeLoc); 880 if (!maybeOperandType) 881 return failure(); 882 operandType = *maybeOperandType; 883 } 884 885 if (parser.resolveOperand(value, operandType, result.operands) || 886 parser.resolveOperand(addr, type, result.operands)) 887 return failure(); 888 889 return success(); 890 } 891 892 ///===---------------------------------------------------------------------===// 893 /// LLVM::InvokeOp 894 ///===---------------------------------------------------------------------===// 895 896 SuccessorOperands InvokeOp::getSuccessorOperands(unsigned index) { 897 assert(index < getNumSuccessors() && "invalid successor index"); 898 return SuccessorOperands(index == 0 ? getNormalDestOperandsMutable() 899 : getUnwindDestOperandsMutable()); 900 } 901 902 CallInterfaceCallable InvokeOp::getCallableForCallee() { 903 // Direct call. 904 if (FlatSymbolRefAttr calleeAttr = getCalleeAttr()) 905 return calleeAttr; 906 // Indirect call, callee Value is the first operand. 907 return getOperand(0); 908 } 909 910 Operation::operand_range InvokeOp::getArgOperands() { 911 return getOperands().drop_front(getCallee().has_value() ? 0 : 1); 912 } 913 914 LogicalResult InvokeOp::verify() { 915 if (getNumResults() > 1) 916 return emitOpError("must have 0 or 1 result"); 917 918 Block *unwindDest = getUnwindDest(); 919 if (unwindDest->empty()) 920 return emitError("must have at least one operation in unwind destination"); 921 922 // In unwind destination, first operation must be LandingpadOp 923 if (!isa<LandingpadOp>(unwindDest->front())) 924 return emitError("first operation in unwind destination should be a " 925 "llvm.landingpad operation"); 926 927 return success(); 928 } 929 930 void InvokeOp::print(OpAsmPrinter &p) { 931 auto callee = getCallee(); 932 bool isDirect = callee.has_value(); 933 934 p << ' '; 935 936 // Either function name or pointer 937 if (isDirect) 938 p.printSymbolName(callee.value()); 939 else 940 p << getOperand(0); 941 942 p << '(' << getOperands().drop_front(isDirect ? 0 : 1) << ')'; 943 p << " to "; 944 p.printSuccessorAndUseList(getNormalDest(), getNormalDestOperands()); 945 p << " unwind "; 946 p.printSuccessorAndUseList(getUnwindDest(), getUnwindDestOperands()); 947 948 p.printOptionalAttrDict((*this)->getAttrs(), 949 {InvokeOp::getOperandSegmentSizeAttr(), "callee"}); 950 p << " : "; 951 p.printFunctionalType(llvm::drop_begin(getOperandTypes(), isDirect ? 0 : 1), 952 getResultTypes()); 953 } 954 955 /// <operation> ::= `llvm.invoke` (function-id | ssa-use) `(` ssa-use-list `)` 956 /// `to` bb-id (`[` ssa-use-and-type-list `]`)? 957 /// `unwind` bb-id (`[` ssa-use-and-type-list `]`)? 958 /// attribute-dict? `:` function-type 959 ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) { 960 SmallVector<OpAsmParser::UnresolvedOperand, 8> operands; 961 FunctionType funcType; 962 SymbolRefAttr funcAttr; 963 SMLoc trailingTypeLoc; 964 Block *normalDest, *unwindDest; 965 SmallVector<Value, 4> normalOperands, unwindOperands; 966 Builder &builder = parser.getBuilder(); 967 968 // Parse an operand list that will, in practice, contain 0 or 1 operand. In 969 // case of an indirect call, there will be 1 operand before `(`. In case of a 970 // direct call, there will be no operands and the parser will stop at the 971 // function identifier without complaining. 972 if (parser.parseOperandList(operands)) 973 return failure(); 974 bool isDirect = operands.empty(); 975 976 // Optionally parse a function identifier. 977 if (isDirect && parser.parseAttribute(funcAttr, "callee", result.attributes)) 978 return failure(); 979 980 if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || 981 parser.parseKeyword("to") || 982 parser.parseSuccessorAndUseList(normalDest, normalOperands) || 983 parser.parseKeyword("unwind") || 984 parser.parseSuccessorAndUseList(unwindDest, unwindOperands) || 985 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 986 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(funcType)) 987 return failure(); 988 989 if (isDirect) { 990 // Make sure types match. 991 if (parser.resolveOperands(operands, funcType.getInputs(), 992 parser.getNameLoc(), result.operands)) 993 return failure(); 994 result.addTypes(funcType.getResults()); 995 } else { 996 // Construct the LLVM IR Dialect function type that the first operand 997 // should match. 998 if (funcType.getNumResults() > 1) 999 return parser.emitError(trailingTypeLoc, 1000 "expected function with 0 or 1 result"); 1001 1002 Type llvmResultType; 1003 if (funcType.getNumResults() == 0) { 1004 llvmResultType = LLVM::LLVMVoidType::get(builder.getContext()); 1005 } else { 1006 llvmResultType = funcType.getResult(0); 1007 if (!isCompatibleType(llvmResultType)) 1008 return parser.emitError(trailingTypeLoc, 1009 "expected result to have LLVM type"); 1010 } 1011 1012 SmallVector<Type, 8> argTypes; 1013 argTypes.reserve(funcType.getNumInputs()); 1014 for (Type ty : funcType.getInputs()) { 1015 if (isCompatibleType(ty)) 1016 argTypes.push_back(ty); 1017 else 1018 return parser.emitError(trailingTypeLoc, 1019 "expected LLVM types as inputs"); 1020 } 1021 1022 auto llvmFuncType = LLVM::LLVMFunctionType::get(llvmResultType, argTypes); 1023 auto wrappedFuncType = LLVM::LLVMPointerType::get(llvmFuncType); 1024 1025 auto funcArguments = llvm::makeArrayRef(operands).drop_front(); 1026 1027 // Make sure that the first operand (indirect callee) matches the wrapped 1028 // LLVM IR function type, and that the types of the other call operands 1029 // match the types of the function arguments. 1030 if (parser.resolveOperand(operands[0], wrappedFuncType, result.operands) || 1031 parser.resolveOperands(funcArguments, funcType.getInputs(), 1032 parser.getNameLoc(), result.operands)) 1033 return failure(); 1034 1035 result.addTypes(llvmResultType); 1036 } 1037 result.addSuccessors({normalDest, unwindDest}); 1038 result.addOperands(normalOperands); 1039 result.addOperands(unwindOperands); 1040 1041 result.addAttribute( 1042 InvokeOp::getOperandSegmentSizeAttr(), 1043 builder.getI32VectorAttr({static_cast<int32_t>(operands.size()), 1044 static_cast<int32_t>(normalOperands.size()), 1045 static_cast<int32_t>(unwindOperands.size())})); 1046 return success(); 1047 } 1048 1049 ///===----------------------------------------------------------------------===// 1050 /// Verifying/Printing/Parsing for LLVM::LandingpadOp. 1051 ///===----------------------------------------------------------------------===// 1052 1053 LogicalResult LandingpadOp::verify() { 1054 Value value; 1055 if (LLVMFuncOp func = (*this)->getParentOfType<LLVMFuncOp>()) { 1056 if (!func.getPersonality()) 1057 return emitError( 1058 "llvm.landingpad needs to be in a function with a personality"); 1059 } 1060 1061 if (!getCleanup() && getOperands().empty()) 1062 return emitError("landingpad instruction expects at least one clause or " 1063 "cleanup attribute"); 1064 1065 for (unsigned idx = 0, ie = getNumOperands(); idx < ie; idx++) { 1066 value = getOperand(idx); 1067 bool isFilter = value.getType().isa<LLVMArrayType>(); 1068 if (isFilter) { 1069 // FIXME: Verify filter clauses when arrays are appropriately handled 1070 } else { 1071 // catch - global addresses only. 1072 // Bitcast ops should have global addresses as their args. 1073 if (auto bcOp = value.getDefiningOp<BitcastOp>()) { 1074 if (auto addrOp = bcOp.getArg().getDefiningOp<AddressOfOp>()) 1075 continue; 1076 return emitError("constant clauses expected").attachNote(bcOp.getLoc()) 1077 << "global addresses expected as operand to " 1078 "bitcast used in clauses for landingpad"; 1079 } 1080 // NullOp and AddressOfOp allowed 1081 if (value.getDefiningOp<NullOp>()) 1082 continue; 1083 if (value.getDefiningOp<AddressOfOp>()) 1084 continue; 1085 return emitError("clause #") 1086 << idx << " is not a known constant - null, addressof, bitcast"; 1087 } 1088 } 1089 return success(); 1090 } 1091 1092 void LandingpadOp::print(OpAsmPrinter &p) { 1093 p << (getCleanup() ? " cleanup " : " "); 1094 1095 // Clauses 1096 for (auto value : getOperands()) { 1097 // Similar to llvm - if clause is an array type then it is filter 1098 // clause else catch clause 1099 bool isArrayTy = value.getType().isa<LLVMArrayType>(); 1100 p << '(' << (isArrayTy ? "filter " : "catch ") << value << " : " 1101 << value.getType() << ") "; 1102 } 1103 1104 p.printOptionalAttrDict((*this)->getAttrs(), {"cleanup"}); 1105 1106 p << ": " << getType(); 1107 } 1108 1109 /// <operation> ::= `llvm.landingpad` `cleanup`? 1110 /// ((`catch` | `filter`) operand-type ssa-use)* attribute-dict? 1111 ParseResult LandingpadOp::parse(OpAsmParser &parser, OperationState &result) { 1112 // Check for cleanup 1113 if (succeeded(parser.parseOptionalKeyword("cleanup"))) 1114 result.addAttribute("cleanup", parser.getBuilder().getUnitAttr()); 1115 1116 // Parse clauses with types 1117 while (succeeded(parser.parseOptionalLParen()) && 1118 (succeeded(parser.parseOptionalKeyword("filter")) || 1119 succeeded(parser.parseOptionalKeyword("catch")))) { 1120 OpAsmParser::UnresolvedOperand operand; 1121 Type ty; 1122 if (parser.parseOperand(operand) || parser.parseColon() || 1123 parser.parseType(ty) || 1124 parser.resolveOperand(operand, ty, result.operands) || 1125 parser.parseRParen()) 1126 return failure(); 1127 } 1128 1129 Type type; 1130 if (parser.parseColon() || parser.parseType(type)) 1131 return failure(); 1132 1133 result.addTypes(type); 1134 return success(); 1135 } 1136 1137 //===----------------------------------------------------------------------===// 1138 // Verifying/Printing/parsing for LLVM::CallOp. 1139 //===----------------------------------------------------------------------===// 1140 1141 CallInterfaceCallable CallOp::getCallableForCallee() { 1142 // Direct call. 1143 if (FlatSymbolRefAttr calleeAttr = getCalleeAttr()) 1144 return calleeAttr; 1145 // Indirect call, callee Value is the first operand. 1146 return getOperand(0); 1147 } 1148 1149 Operation::operand_range CallOp::getArgOperands() { 1150 return getOperands().drop_front(getCallee().has_value() ? 0 : 1); 1151 } 1152 1153 LogicalResult CallOp::verify() { 1154 if (getNumResults() > 1) 1155 return emitOpError("must have 0 or 1 result"); 1156 1157 // Type for the callee, we'll get it differently depending if it is a direct 1158 // or indirect call. 1159 Type fnType; 1160 1161 bool isIndirect = false; 1162 1163 // If this is an indirect call, the callee attribute is missing. 1164 FlatSymbolRefAttr calleeName = getCalleeAttr(); 1165 if (!calleeName) { 1166 isIndirect = true; 1167 if (!getNumOperands()) 1168 return emitOpError( 1169 "must have either a `callee` attribute or at least an operand"); 1170 auto ptrType = getOperand(0).getType().dyn_cast<LLVMPointerType>(); 1171 if (!ptrType) 1172 return emitOpError("indirect call expects a pointer as callee: ") 1173 << ptrType; 1174 fnType = ptrType.getElementType(); 1175 } else { 1176 Operation *callee = 1177 SymbolTable::lookupNearestSymbolFrom(*this, calleeName.getAttr()); 1178 if (!callee) 1179 return emitOpError() 1180 << "'" << calleeName.getValue() 1181 << "' does not reference a symbol in the current scope"; 1182 auto fn = dyn_cast<LLVMFuncOp>(callee); 1183 if (!fn) 1184 return emitOpError() << "'" << calleeName.getValue() 1185 << "' does not reference a valid LLVM function"; 1186 1187 fnType = fn.getFunctionType(); 1188 } 1189 1190 LLVMFunctionType funcType = fnType.dyn_cast<LLVMFunctionType>(); 1191 if (!funcType) 1192 return emitOpError("callee does not have a functional type: ") << fnType; 1193 1194 // Verify that the operand and result types match the callee. 1195 1196 if (!funcType.isVarArg() && 1197 funcType.getNumParams() != (getNumOperands() - isIndirect)) 1198 return emitOpError() << "incorrect number of operands (" 1199 << (getNumOperands() - isIndirect) 1200 << ") for callee (expecting: " 1201 << funcType.getNumParams() << ")"; 1202 1203 if (funcType.getNumParams() > (getNumOperands() - isIndirect)) 1204 return emitOpError() << "incorrect number of operands (" 1205 << (getNumOperands() - isIndirect) 1206 << ") for varargs callee (expecting at least: " 1207 << funcType.getNumParams() << ")"; 1208 1209 for (unsigned i = 0, e = funcType.getNumParams(); i != e; ++i) 1210 if (getOperand(i + isIndirect).getType() != funcType.getParamType(i)) 1211 return emitOpError() << "operand type mismatch for operand " << i << ": " 1212 << getOperand(i + isIndirect).getType() 1213 << " != " << funcType.getParamType(i); 1214 1215 if (getNumResults() == 0 && 1216 !funcType.getReturnType().isa<LLVM::LLVMVoidType>()) 1217 return emitOpError() << "expected function call to produce a value"; 1218 1219 if (getNumResults() != 0 && 1220 funcType.getReturnType().isa<LLVM::LLVMVoidType>()) 1221 return emitOpError() 1222 << "calling function with void result must not produce values"; 1223 1224 if (getNumResults() > 1) 1225 return emitOpError() 1226 << "expected LLVM function call to produce 0 or 1 result"; 1227 1228 if (getNumResults() && getResult(0).getType() != funcType.getReturnType()) 1229 return emitOpError() << "result type mismatch: " << getResult(0).getType() 1230 << " != " << funcType.getReturnType(); 1231 1232 return success(); 1233 } 1234 1235 void CallOp::print(OpAsmPrinter &p) { 1236 auto callee = getCallee(); 1237 bool isDirect = callee.has_value(); 1238 1239 // Print the direct callee if present as a function attribute, or an indirect 1240 // callee (first operand) otherwise. 1241 p << ' '; 1242 if (isDirect) 1243 p.printSymbolName(callee.value()); 1244 else 1245 p << getOperand(0); 1246 1247 auto args = getOperands().drop_front(isDirect ? 0 : 1); 1248 p << '(' << args << ')'; 1249 p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()), {"callee"}); 1250 1251 // Reconstruct the function MLIR function type from operand and result types. 1252 p << " : "; 1253 p.printFunctionalType(args.getTypes(), getResultTypes()); 1254 } 1255 1256 // <operation> ::= `llvm.call` (function-id | ssa-use) `(` ssa-use-list `)` 1257 // attribute-dict? `:` function-type 1258 ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) { 1259 SmallVector<OpAsmParser::UnresolvedOperand, 8> operands; 1260 Type type; 1261 SymbolRefAttr funcAttr; 1262 SMLoc trailingTypeLoc; 1263 1264 // Parse an operand list that will, in practice, contain 0 or 1 operand. In 1265 // case of an indirect call, there will be 1 operand before `(`. In case of a 1266 // direct call, there will be no operands and the parser will stop at the 1267 // function identifier without complaining. 1268 if (parser.parseOperandList(operands)) 1269 return failure(); 1270 bool isDirect = operands.empty(); 1271 1272 // Optionally parse a function identifier. 1273 if (isDirect) 1274 if (parser.parseAttribute(funcAttr, "callee", result.attributes)) 1275 return failure(); 1276 1277 if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || 1278 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 1279 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type)) 1280 return failure(); 1281 1282 auto funcType = type.dyn_cast<FunctionType>(); 1283 if (!funcType) 1284 return parser.emitError(trailingTypeLoc, "expected function type"); 1285 if (funcType.getNumResults() > 1) 1286 return parser.emitError(trailingTypeLoc, 1287 "expected function with 0 or 1 result"); 1288 if (isDirect) { 1289 // Make sure types match. 1290 if (parser.resolveOperands(operands, funcType.getInputs(), 1291 parser.getNameLoc(), result.operands)) 1292 return failure(); 1293 if (funcType.getNumResults() != 0 && 1294 !funcType.getResult(0).isa<LLVM::LLVMVoidType>()) 1295 result.addTypes(funcType.getResults()); 1296 } else { 1297 Builder &builder = parser.getBuilder(); 1298 Type llvmResultType; 1299 if (funcType.getNumResults() == 0) { 1300 llvmResultType = LLVM::LLVMVoidType::get(builder.getContext()); 1301 } else { 1302 llvmResultType = funcType.getResult(0); 1303 if (!isCompatibleType(llvmResultType)) 1304 return parser.emitError(trailingTypeLoc, 1305 "expected result to have LLVM type"); 1306 } 1307 1308 SmallVector<Type, 8> argTypes; 1309 argTypes.reserve(funcType.getNumInputs()); 1310 for (int i = 0, e = funcType.getNumInputs(); i < e; ++i) { 1311 auto argType = funcType.getInput(i); 1312 if (!isCompatibleType(argType)) 1313 return parser.emitError(trailingTypeLoc, 1314 "expected LLVM types as inputs"); 1315 argTypes.push_back(argType); 1316 } 1317 auto llvmFuncType = LLVM::LLVMFunctionType::get(llvmResultType, argTypes); 1318 auto wrappedFuncType = LLVM::LLVMPointerType::get(llvmFuncType); 1319 1320 auto funcArguments = 1321 ArrayRef<OpAsmParser::UnresolvedOperand>(operands).drop_front(); 1322 1323 // Make sure that the first operand (indirect callee) matches the wrapped 1324 // LLVM IR function type, and that the types of the other call operands 1325 // match the types of the function arguments. 1326 if (parser.resolveOperand(operands[0], wrappedFuncType, result.operands) || 1327 parser.resolveOperands(funcArguments, funcType.getInputs(), 1328 parser.getNameLoc(), result.operands)) 1329 return failure(); 1330 1331 if (!llvmResultType.isa<LLVM::LLVMVoidType>()) 1332 result.addTypes(llvmResultType); 1333 } 1334 1335 return success(); 1336 } 1337 1338 //===----------------------------------------------------------------------===// 1339 // Printing/parsing for LLVM::ExtractElementOp. 1340 //===----------------------------------------------------------------------===// 1341 // Expects vector to be of wrapped LLVM vector type and position to be of 1342 // wrapped LLVM i32 type. 1343 void LLVM::ExtractElementOp::build(OpBuilder &b, OperationState &result, 1344 Value vector, Value position, 1345 ArrayRef<NamedAttribute> attrs) { 1346 auto vectorType = vector.getType(); 1347 auto llvmType = LLVM::getVectorElementType(vectorType); 1348 build(b, result, llvmType, vector, position); 1349 result.addAttributes(attrs); 1350 } 1351 1352 void ExtractElementOp::print(OpAsmPrinter &p) { 1353 p << ' ' << getVector() << "[" << getPosition() << " : " 1354 << getPosition().getType() << "]"; 1355 p.printOptionalAttrDict((*this)->getAttrs()); 1356 p << " : " << getVector().getType(); 1357 } 1358 1359 // <operation> ::= `llvm.extractelement` ssa-use `, ` ssa-use 1360 // attribute-dict? `:` type 1361 ParseResult ExtractElementOp::parse(OpAsmParser &parser, 1362 OperationState &result) { 1363 SMLoc loc; 1364 OpAsmParser::UnresolvedOperand vector, position; 1365 Type type, positionType; 1366 if (parser.getCurrentLocation(&loc) || parser.parseOperand(vector) || 1367 parser.parseLSquare() || parser.parseOperand(position) || 1368 parser.parseColonType(positionType) || parser.parseRSquare() || 1369 parser.parseOptionalAttrDict(result.attributes) || 1370 parser.parseColonType(type) || 1371 parser.resolveOperand(vector, type, result.operands) || 1372 parser.resolveOperand(position, positionType, result.operands)) 1373 return failure(); 1374 if (!LLVM::isCompatibleVectorType(type)) 1375 return parser.emitError( 1376 loc, "expected LLVM dialect-compatible vector type for operand #1"); 1377 result.addTypes(LLVM::getVectorElementType(type)); 1378 return success(); 1379 } 1380 1381 LogicalResult ExtractElementOp::verify() { 1382 Type vectorType = getVector().getType(); 1383 if (!LLVM::isCompatibleVectorType(vectorType)) 1384 return emitOpError("expected LLVM dialect-compatible vector type for " 1385 "operand #1, got") 1386 << vectorType; 1387 Type valueType = LLVM::getVectorElementType(vectorType); 1388 if (valueType != getRes().getType()) 1389 return emitOpError() << "Type mismatch: extracting from " << vectorType 1390 << " should produce " << valueType 1391 << " but this op returns " << getRes().getType(); 1392 return success(); 1393 } 1394 1395 //===----------------------------------------------------------------------===// 1396 // Printing/parsing for LLVM::ExtractValueOp. 1397 //===----------------------------------------------------------------------===// 1398 1399 void ExtractValueOp::print(OpAsmPrinter &p) { 1400 p << ' ' << getContainer() << getPosition(); 1401 p.printOptionalAttrDict((*this)->getAttrs(), {"position"}); 1402 p << " : " << getContainer().getType(); 1403 } 1404 1405 // Extract the type at `position` in the wrapped LLVM IR aggregate type 1406 // `containerType`. Position is an integer array attribute where each value 1407 // is a zero-based position of the element in the aggregate type. Return the 1408 // resulting type wrapped in MLIR, or nullptr on error. 1409 static Type getInsertExtractValueElementType(OpAsmParser &parser, 1410 Type containerType, 1411 ArrayAttr positionAttr, 1412 SMLoc attributeLoc, 1413 SMLoc typeLoc) { 1414 Type llvmType = containerType; 1415 if (!isCompatibleType(containerType)) 1416 return parser.emitError(typeLoc, "expected LLVM IR Dialect type"), nullptr; 1417 1418 // Infer the element type from the structure type: iteratively step inside the 1419 // type by taking the element type, indexed by the position attribute for 1420 // structures. Check the position index before accessing, it is supposed to 1421 // be in bounds. 1422 for (Attribute subAttr : positionAttr) { 1423 auto positionElementAttr = subAttr.dyn_cast<IntegerAttr>(); 1424 if (!positionElementAttr) 1425 return parser.emitError(attributeLoc, 1426 "expected an array of integer literals"), 1427 nullptr; 1428 int position = positionElementAttr.getInt(); 1429 if (auto arrayType = llvmType.dyn_cast<LLVMArrayType>()) { 1430 if (position < 0 || 1431 static_cast<unsigned>(position) >= arrayType.getNumElements()) 1432 return parser.emitError(attributeLoc, "position out of bounds"), 1433 nullptr; 1434 llvmType = arrayType.getElementType(); 1435 } else if (auto structType = llvmType.dyn_cast<LLVMStructType>()) { 1436 if (position < 0 || 1437 static_cast<unsigned>(position) >= structType.getBody().size()) 1438 return parser.emitError(attributeLoc, "position out of bounds"), 1439 nullptr; 1440 llvmType = structType.getBody()[position]; 1441 } else { 1442 return parser.emitError(typeLoc, "expected LLVM IR structure/array type"), 1443 nullptr; 1444 } 1445 } 1446 return llvmType; 1447 } 1448 1449 // Extract the type at `position` in the wrapped LLVM IR aggregate type 1450 // `containerType`. Returns null on failure. 1451 static Type getInsertExtractValueElementType(Type containerType, 1452 ArrayAttr positionAttr, 1453 Operation *op) { 1454 Type llvmType = containerType; 1455 if (!isCompatibleType(containerType)) { 1456 op->emitError("expected LLVM IR Dialect type, got ") << containerType; 1457 return {}; 1458 } 1459 1460 // Infer the element type from the structure type: iteratively step inside the 1461 // type by taking the element type, indexed by the position attribute for 1462 // structures. Check the position index before accessing, it is supposed to 1463 // be in bounds. 1464 for (Attribute subAttr : positionAttr) { 1465 auto positionElementAttr = subAttr.dyn_cast<IntegerAttr>(); 1466 if (!positionElementAttr) { 1467 op->emitOpError("expected an array of integer literals, got: ") 1468 << subAttr; 1469 return {}; 1470 } 1471 int position = positionElementAttr.getInt(); 1472 if (auto arrayType = llvmType.dyn_cast<LLVMArrayType>()) { 1473 if (position < 0 || 1474 static_cast<unsigned>(position) >= arrayType.getNumElements()) { 1475 op->emitOpError("position out of bounds: ") << position; 1476 return {}; 1477 } 1478 llvmType = arrayType.getElementType(); 1479 } else if (auto structType = llvmType.dyn_cast<LLVMStructType>()) { 1480 if (position < 0 || 1481 static_cast<unsigned>(position) >= structType.getBody().size()) { 1482 op->emitOpError("position out of bounds") << position; 1483 return {}; 1484 } 1485 llvmType = structType.getBody()[position]; 1486 } else { 1487 op->emitOpError("expected LLVM IR structure/array type, got: ") 1488 << llvmType; 1489 return {}; 1490 } 1491 } 1492 return llvmType; 1493 } 1494 1495 // <operation> ::= `llvm.extractvalue` ssa-use 1496 // `[` integer-literal (`,` integer-literal)* `]` 1497 // attribute-dict? `:` type 1498 ParseResult ExtractValueOp::parse(OpAsmParser &parser, OperationState &result) { 1499 OpAsmParser::UnresolvedOperand container; 1500 Type containerType; 1501 ArrayAttr positionAttr; 1502 SMLoc attributeLoc, trailingTypeLoc; 1503 1504 if (parser.parseOperand(container) || 1505 parser.getCurrentLocation(&attributeLoc) || 1506 parser.parseAttribute(positionAttr, "position", result.attributes) || 1507 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 1508 parser.getCurrentLocation(&trailingTypeLoc) || 1509 parser.parseType(containerType) || 1510 parser.resolveOperand(container, containerType, result.operands)) 1511 return failure(); 1512 1513 auto elementType = getInsertExtractValueElementType( 1514 parser, containerType, positionAttr, attributeLoc, trailingTypeLoc); 1515 if (!elementType) 1516 return failure(); 1517 1518 result.addTypes(elementType); 1519 return success(); 1520 } 1521 1522 OpFoldResult LLVM::ExtractValueOp::fold(ArrayRef<Attribute> operands) { 1523 auto insertValueOp = getContainer().getDefiningOp<InsertValueOp>(); 1524 OpFoldResult result = {}; 1525 while (insertValueOp) { 1526 if (getPosition() == insertValueOp.getPosition()) 1527 return insertValueOp.getValue(); 1528 unsigned min = 1529 std::min(getPosition().size(), insertValueOp.getPosition().size()); 1530 // If one is fully prefix of the other, stop propagating back as it will 1531 // miss dependencies. For instance, %3 should not fold to %f0 in the 1532 // following example: 1533 // ``` 1534 // %1 = llvm.insertvalue %f0, %0[0, 0] : 1535 // !llvm.array<4 x !llvm.array<4xf32>> 1536 // %2 = llvm.insertvalue %arr, %1[0] : 1537 // !llvm.array<4 x !llvm.array<4xf32>> 1538 // %3 = llvm.extractvalue %2[0, 0] : !llvm.array<4 x !llvm.array<4xf32>> 1539 // ``` 1540 if (getPosition().getValue().take_front(min) == 1541 insertValueOp.getPosition().getValue().take_front(min)) 1542 return result; 1543 1544 // If neither a prefix, nor the exact position, we can extract out of the 1545 // value being inserted into. Moreover, we can try again if that operand 1546 // is itself an insertvalue expression. 1547 getContainerMutable().assign(insertValueOp.getContainer()); 1548 result = getResult(); 1549 insertValueOp = insertValueOp.getContainer().getDefiningOp<InsertValueOp>(); 1550 } 1551 return result; 1552 } 1553 1554 LogicalResult ExtractValueOp::verify() { 1555 Type valueType = getInsertExtractValueElementType(getContainer().getType(), 1556 getPositionAttr(), *this); 1557 if (!valueType) 1558 return failure(); 1559 1560 if (getRes().getType() != valueType) 1561 return emitOpError() << "Type mismatch: extracting from " 1562 << getContainer().getType() << " should produce " 1563 << valueType << " but this op returns " 1564 << getRes().getType(); 1565 return success(); 1566 } 1567 1568 //===----------------------------------------------------------------------===// 1569 // Printing/parsing for LLVM::InsertElementOp. 1570 //===----------------------------------------------------------------------===// 1571 1572 void InsertElementOp::print(OpAsmPrinter &p) { 1573 p << ' ' << getValue() << ", " << getVector() << "[" << getPosition() << " : " 1574 << getPosition().getType() << "]"; 1575 p.printOptionalAttrDict((*this)->getAttrs()); 1576 p << " : " << getVector().getType(); 1577 } 1578 1579 // <operation> ::= `llvm.insertelement` ssa-use `,` ssa-use `,` ssa-use 1580 // attribute-dict? `:` type 1581 ParseResult InsertElementOp::parse(OpAsmParser &parser, 1582 OperationState &result) { 1583 SMLoc loc; 1584 OpAsmParser::UnresolvedOperand vector, value, position; 1585 Type vectorType, positionType; 1586 if (parser.getCurrentLocation(&loc) || parser.parseOperand(value) || 1587 parser.parseComma() || parser.parseOperand(vector) || 1588 parser.parseLSquare() || parser.parseOperand(position) || 1589 parser.parseColonType(positionType) || parser.parseRSquare() || 1590 parser.parseOptionalAttrDict(result.attributes) || 1591 parser.parseColonType(vectorType)) 1592 return failure(); 1593 1594 if (!LLVM::isCompatibleVectorType(vectorType)) 1595 return parser.emitError( 1596 loc, "expected LLVM dialect-compatible vector type for operand #1"); 1597 Type valueType = LLVM::getVectorElementType(vectorType); 1598 if (!valueType) 1599 return failure(); 1600 1601 if (parser.resolveOperand(vector, vectorType, result.operands) || 1602 parser.resolveOperand(value, valueType, result.operands) || 1603 parser.resolveOperand(position, positionType, result.operands)) 1604 return failure(); 1605 1606 result.addTypes(vectorType); 1607 return success(); 1608 } 1609 1610 LogicalResult InsertElementOp::verify() { 1611 Type valueType = LLVM::getVectorElementType(getVector().getType()); 1612 if (valueType != getValue().getType()) 1613 return emitOpError() << "Type mismatch: cannot insert " 1614 << getValue().getType() << " into " 1615 << getVector().getType(); 1616 return success(); 1617 } 1618 1619 //===----------------------------------------------------------------------===// 1620 // Printing/parsing for LLVM::InsertValueOp. 1621 //===----------------------------------------------------------------------===// 1622 1623 void InsertValueOp::print(OpAsmPrinter &p) { 1624 p << ' ' << getValue() << ", " << getContainer() << getPosition(); 1625 p.printOptionalAttrDict((*this)->getAttrs(), {"position"}); 1626 p << " : " << getContainer().getType(); 1627 } 1628 1629 // <operation> ::= `llvm.insertvaluevalue` ssa-use `,` ssa-use 1630 // `[` integer-literal (`,` integer-literal)* `]` 1631 // attribute-dict? `:` type 1632 ParseResult InsertValueOp::parse(OpAsmParser &parser, OperationState &result) { 1633 OpAsmParser::UnresolvedOperand container, value; 1634 Type containerType; 1635 ArrayAttr positionAttr; 1636 SMLoc attributeLoc, trailingTypeLoc; 1637 1638 if (parser.parseOperand(value) || parser.parseComma() || 1639 parser.parseOperand(container) || 1640 parser.getCurrentLocation(&attributeLoc) || 1641 parser.parseAttribute(positionAttr, "position", result.attributes) || 1642 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 1643 parser.getCurrentLocation(&trailingTypeLoc) || 1644 parser.parseType(containerType)) 1645 return failure(); 1646 1647 auto valueType = getInsertExtractValueElementType( 1648 parser, containerType, positionAttr, attributeLoc, trailingTypeLoc); 1649 if (!valueType) 1650 return failure(); 1651 1652 if (parser.resolveOperand(container, containerType, result.operands) || 1653 parser.resolveOperand(value, valueType, result.operands)) 1654 return failure(); 1655 1656 result.addTypes(containerType); 1657 return success(); 1658 } 1659 1660 LogicalResult InsertValueOp::verify() { 1661 Type valueType = getInsertExtractValueElementType(getContainer().getType(), 1662 getPositionAttr(), *this); 1663 if (!valueType) 1664 return failure(); 1665 1666 if (getValue().getType() != valueType) 1667 return emitOpError() << "Type mismatch: cannot insert " 1668 << getValue().getType() << " into " 1669 << getContainer().getType(); 1670 1671 return success(); 1672 } 1673 1674 //===----------------------------------------------------------------------===// 1675 // Printing, parsing and verification for LLVM::ReturnOp. 1676 //===----------------------------------------------------------------------===// 1677 1678 LogicalResult ReturnOp::verify() { 1679 if (getNumOperands() > 1) 1680 return emitOpError("expected at most 1 operand"); 1681 1682 if (auto parent = (*this)->getParentOfType<LLVMFuncOp>()) { 1683 Type expectedType = parent.getFunctionType().getReturnType(); 1684 if (expectedType.isa<LLVMVoidType>()) { 1685 if (getNumOperands() == 0) 1686 return success(); 1687 InFlightDiagnostic diag = emitOpError("expected no operands"); 1688 diag.attachNote(parent->getLoc()) << "when returning from function"; 1689 return diag; 1690 } 1691 if (getNumOperands() == 0) { 1692 if (expectedType.isa<LLVMVoidType>()) 1693 return success(); 1694 InFlightDiagnostic diag = emitOpError("expected 1 operand"); 1695 diag.attachNote(parent->getLoc()) << "when returning from function"; 1696 return diag; 1697 } 1698 if (expectedType != getOperand(0).getType()) { 1699 InFlightDiagnostic diag = emitOpError("mismatching result types"); 1700 diag.attachNote(parent->getLoc()) << "when returning from function"; 1701 return diag; 1702 } 1703 } 1704 return success(); 1705 } 1706 1707 //===----------------------------------------------------------------------===// 1708 // ResumeOp 1709 //===----------------------------------------------------------------------===// 1710 1711 LogicalResult ResumeOp::verify() { 1712 if (!getValue().getDefiningOp<LandingpadOp>()) 1713 return emitOpError("expects landingpad value as operand"); 1714 // No check for personality of function - landingpad op verifies it. 1715 return success(); 1716 } 1717 1718 //===----------------------------------------------------------------------===// 1719 // Verifier for LLVM::AddressOfOp. 1720 //===----------------------------------------------------------------------===// 1721 1722 template <typename OpTy> 1723 static OpTy lookupSymbolInModule(Operation *parent, StringRef name) { 1724 Operation *module = parent; 1725 while (module && !satisfiesLLVMModule(module)) 1726 module = module->getParentOp(); 1727 assert(module && "unexpected operation outside of a module"); 1728 return dyn_cast_or_null<OpTy>( 1729 mlir::SymbolTable::lookupSymbolIn(module, name)); 1730 } 1731 1732 GlobalOp AddressOfOp::getGlobal() { 1733 return lookupSymbolInModule<LLVM::GlobalOp>((*this)->getParentOp(), 1734 getGlobalName()); 1735 } 1736 1737 LLVMFuncOp AddressOfOp::getFunction() { 1738 return lookupSymbolInModule<LLVM::LLVMFuncOp>((*this)->getParentOp(), 1739 getGlobalName()); 1740 } 1741 1742 LogicalResult AddressOfOp::verify() { 1743 auto global = getGlobal(); 1744 auto function = getFunction(); 1745 if (!global && !function) 1746 return emitOpError( 1747 "must reference a global defined by 'llvm.mlir.global' or 'llvm.func'"); 1748 1749 LLVMPointerType type = getType(); 1750 if (global && global.getAddrSpace() != type.getAddressSpace()) 1751 return emitOpError("pointer address space must match address space of the " 1752 "referenced global"); 1753 1754 if (type.isOpaque()) 1755 return success(); 1756 1757 if (global && type.getElementType() != global.getType()) 1758 return emitOpError( 1759 "the type must be a pointer to the type of the referenced global"); 1760 1761 if (function && type.getElementType() != function.getFunctionType()) 1762 return emitOpError( 1763 "the type must be a pointer to the type of the referenced function"); 1764 1765 return success(); 1766 } 1767 1768 //===----------------------------------------------------------------------===// 1769 // Builder, printer and verifier for LLVM::GlobalOp. 1770 //===----------------------------------------------------------------------===// 1771 1772 void GlobalOp::build(OpBuilder &builder, OperationState &result, Type type, 1773 bool isConstant, Linkage linkage, StringRef name, 1774 Attribute value, uint64_t alignment, unsigned addrSpace, 1775 bool dsoLocal, bool threadLocal, 1776 ArrayRef<NamedAttribute> attrs) { 1777 result.addAttribute(getSymNameAttrName(result.name), 1778 builder.getStringAttr(name)); 1779 result.addAttribute(getGlobalTypeAttrName(result.name), TypeAttr::get(type)); 1780 if (isConstant) 1781 result.addAttribute(getConstantAttrName(result.name), 1782 builder.getUnitAttr()); 1783 if (value) 1784 result.addAttribute(getValueAttrName(result.name), value); 1785 if (dsoLocal) 1786 result.addAttribute(getDsoLocalAttrName(result.name), 1787 builder.getUnitAttr()); 1788 if (threadLocal) 1789 result.addAttribute(getThreadLocal_AttrName(result.name), 1790 builder.getUnitAttr()); 1791 1792 // Only add an alignment attribute if the "alignment" input 1793 // is different from 0. The value must also be a power of two, but 1794 // this is tested in GlobalOp::verify, not here. 1795 if (alignment != 0) 1796 result.addAttribute(getAlignmentAttrName(result.name), 1797 builder.getI64IntegerAttr(alignment)); 1798 1799 result.addAttribute(getLinkageAttrName(result.name), 1800 LinkageAttr::get(builder.getContext(), linkage)); 1801 if (addrSpace != 0) 1802 result.addAttribute(getAddrSpaceAttrName(result.name), 1803 builder.getI32IntegerAttr(addrSpace)); 1804 result.attributes.append(attrs.begin(), attrs.end()); 1805 result.addRegion(); 1806 } 1807 1808 void GlobalOp::print(OpAsmPrinter &p) { 1809 p << ' ' << stringifyLinkage(getLinkage()) << ' '; 1810 if (auto unnamedAddr = getUnnamedAddr()) { 1811 StringRef str = stringifyUnnamedAddr(*unnamedAddr); 1812 if (!str.empty()) 1813 p << str << ' '; 1814 } 1815 if (getThreadLocal_()) 1816 p << "thread_local "; 1817 if (getConstant()) 1818 p << "constant "; 1819 p.printSymbolName(getSymName()); 1820 p << '('; 1821 if (auto value = getValueOrNull()) 1822 p.printAttribute(value); 1823 p << ')'; 1824 // Note that the alignment attribute is printed using the 1825 // default syntax here, even though it is an inherent attribute 1826 // (as defined in https://mlir.llvm.org/docs/LangRef/#attributes) 1827 p.printOptionalAttrDict( 1828 (*this)->getAttrs(), 1829 {SymbolTable::getSymbolAttrName(), getGlobalTypeAttrName(), 1830 getConstantAttrName(), getValueAttrName(), getLinkageAttrName(), 1831 getUnnamedAddrAttrName(), getThreadLocal_AttrName()}); 1832 1833 // Print the trailing type unless it's a string global. 1834 if (getValueOrNull().dyn_cast_or_null<StringAttr>()) 1835 return; 1836 p << " : " << getType(); 1837 1838 Region &initializer = getInitializerRegion(); 1839 if (!initializer.empty()) { 1840 p << ' '; 1841 p.printRegion(initializer, /*printEntryBlockArgs=*/false); 1842 } 1843 } 1844 1845 // Parses one of the keywords provided in the list `keywords` and returns the 1846 // position of the parsed keyword in the list. If none of the keywords from the 1847 // list is parsed, returns -1. 1848 static int parseOptionalKeywordAlternative(OpAsmParser &parser, 1849 ArrayRef<StringRef> keywords) { 1850 for (const auto &en : llvm::enumerate(keywords)) { 1851 if (succeeded(parser.parseOptionalKeyword(en.value()))) 1852 return en.index(); 1853 } 1854 return -1; 1855 } 1856 1857 namespace { 1858 template <typename Ty> 1859 struct EnumTraits {}; 1860 1861 #define REGISTER_ENUM_TYPE(Ty) \ 1862 template <> \ 1863 struct EnumTraits<Ty> { \ 1864 static StringRef stringify(Ty value) { return stringify##Ty(value); } \ 1865 static unsigned getMaxEnumVal() { return getMaxEnumValFor##Ty(); } \ 1866 } 1867 1868 REGISTER_ENUM_TYPE(Linkage); 1869 REGISTER_ENUM_TYPE(UnnamedAddr); 1870 REGISTER_ENUM_TYPE(CConv); 1871 } // namespace 1872 1873 /// Parse an enum from the keyword, or default to the provided default value. 1874 /// The return type is the enum type by default, unless overriden with the 1875 /// second template argument. 1876 template <typename EnumTy, typename RetTy = EnumTy> 1877 static RetTy parseOptionalLLVMKeyword(OpAsmParser &parser, 1878 OperationState &result, 1879 EnumTy defaultValue) { 1880 SmallVector<StringRef, 10> names; 1881 for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i) 1882 names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i))); 1883 1884 int index = parseOptionalKeywordAlternative(parser, names); 1885 if (index == -1) 1886 return static_cast<RetTy>(defaultValue); 1887 return static_cast<RetTy>(index); 1888 } 1889 1890 // operation ::= `llvm.mlir.global` linkage? `constant`? `@` identifier 1891 // `(` attribute? `)` align? attribute-list? (`:` type)? region? 1892 // align ::= `align` `=` UINT64 1893 // 1894 // The type can be omitted for string attributes, in which case it will be 1895 // inferred from the value of the string as [strlen(value) x i8]. 1896 ParseResult GlobalOp::parse(OpAsmParser &parser, OperationState &result) { 1897 MLIRContext *ctx = parser.getContext(); 1898 // Parse optional linkage, default to External. 1899 result.addAttribute(getLinkageAttrName(result.name), 1900 LLVM::LinkageAttr::get( 1901 ctx, parseOptionalLLVMKeyword<Linkage>( 1902 parser, result, LLVM::Linkage::External))); 1903 1904 if (succeeded(parser.parseOptionalKeyword("thread_local"))) 1905 result.addAttribute(getThreadLocal_AttrName(result.name), 1906 parser.getBuilder().getUnitAttr()); 1907 1908 // Parse optional UnnamedAddr, default to None. 1909 result.addAttribute(getUnnamedAddrAttrName(result.name), 1910 parser.getBuilder().getI64IntegerAttr( 1911 parseOptionalLLVMKeyword<UnnamedAddr, int64_t>( 1912 parser, result, LLVM::UnnamedAddr::None))); 1913 1914 if (succeeded(parser.parseOptionalKeyword("constant"))) 1915 result.addAttribute(getConstantAttrName(result.name), 1916 parser.getBuilder().getUnitAttr()); 1917 1918 StringAttr name; 1919 if (parser.parseSymbolName(name, getSymNameAttrName(result.name), 1920 result.attributes) || 1921 parser.parseLParen()) 1922 return failure(); 1923 1924 Attribute value; 1925 if (parser.parseOptionalRParen()) { 1926 if (parser.parseAttribute(value, getValueAttrName(result.name), 1927 result.attributes) || 1928 parser.parseRParen()) 1929 return failure(); 1930 } 1931 1932 SmallVector<Type, 1> types; 1933 if (parser.parseOptionalAttrDict(result.attributes) || 1934 parser.parseOptionalColonTypeList(types)) 1935 return failure(); 1936 1937 if (types.size() > 1) 1938 return parser.emitError(parser.getNameLoc(), "expected zero or one type"); 1939 1940 Region &initRegion = *result.addRegion(); 1941 if (types.empty()) { 1942 if (auto strAttr = value.dyn_cast_or_null<StringAttr>()) { 1943 MLIRContext *context = parser.getContext(); 1944 auto arrayType = LLVM::LLVMArrayType::get(IntegerType::get(context, 8), 1945 strAttr.getValue().size()); 1946 types.push_back(arrayType); 1947 } else { 1948 return parser.emitError(parser.getNameLoc(), 1949 "type can only be omitted for string globals"); 1950 } 1951 } else { 1952 OptionalParseResult parseResult = 1953 parser.parseOptionalRegion(initRegion, /*arguments=*/{}, 1954 /*argTypes=*/{}); 1955 if (parseResult.hasValue() && failed(*parseResult)) 1956 return failure(); 1957 } 1958 1959 result.addAttribute(getGlobalTypeAttrName(result.name), 1960 TypeAttr::get(types[0])); 1961 return success(); 1962 } 1963 1964 static bool isZeroAttribute(Attribute value) { 1965 if (auto intValue = value.dyn_cast<IntegerAttr>()) 1966 return intValue.getValue().isNullValue(); 1967 if (auto fpValue = value.dyn_cast<FloatAttr>()) 1968 return fpValue.getValue().isZero(); 1969 if (auto splatValue = value.dyn_cast<SplatElementsAttr>()) 1970 return isZeroAttribute(splatValue.getSplatValue<Attribute>()); 1971 if (auto elementsValue = value.dyn_cast<ElementsAttr>()) 1972 return llvm::all_of(elementsValue.getValues<Attribute>(), isZeroAttribute); 1973 if (auto arrayValue = value.dyn_cast<ArrayAttr>()) 1974 return llvm::all_of(arrayValue.getValue(), isZeroAttribute); 1975 return false; 1976 } 1977 1978 LogicalResult GlobalOp::verify() { 1979 if (!LLVMPointerType::isValidElementType(getType())) 1980 return emitOpError( 1981 "expects type to be a valid element type for an LLVM pointer"); 1982 if ((*this)->getParentOp() && !satisfiesLLVMModule((*this)->getParentOp())) 1983 return emitOpError("must appear at the module level"); 1984 1985 if (auto strAttr = getValueOrNull().dyn_cast_or_null<StringAttr>()) { 1986 auto type = getType().dyn_cast<LLVMArrayType>(); 1987 IntegerType elementType = 1988 type ? type.getElementType().dyn_cast<IntegerType>() : nullptr; 1989 if (!elementType || elementType.getWidth() != 8 || 1990 type.getNumElements() != strAttr.getValue().size()) 1991 return emitOpError( 1992 "requires an i8 array type of the length equal to that of the string " 1993 "attribute"); 1994 } 1995 1996 if (getLinkage() == Linkage::Common) { 1997 if (Attribute value = getValueOrNull()) { 1998 if (!isZeroAttribute(value)) { 1999 return emitOpError() 2000 << "expected zero value for '" 2001 << stringifyLinkage(Linkage::Common) << "' linkage"; 2002 } 2003 } 2004 } 2005 2006 if (getLinkage() == Linkage::Appending) { 2007 if (!getType().isa<LLVMArrayType>()) { 2008 return emitOpError() << "expected array type for '" 2009 << stringifyLinkage(Linkage::Appending) 2010 << "' linkage"; 2011 } 2012 } 2013 2014 Optional<uint64_t> alignAttr = getAlignment(); 2015 if (alignAttr.has_value()) { 2016 uint64_t value = alignAttr.value(); 2017 if (!llvm::isPowerOf2_64(value)) 2018 return emitError() << "alignment attribute is not a power of 2"; 2019 } 2020 2021 return success(); 2022 } 2023 2024 LogicalResult GlobalOp::verifyRegions() { 2025 if (Block *b = getInitializerBlock()) { 2026 ReturnOp ret = cast<ReturnOp>(b->getTerminator()); 2027 if (ret.operand_type_begin() == ret.operand_type_end()) 2028 return emitOpError("initializer region cannot return void"); 2029 if (*ret.operand_type_begin() != getType()) 2030 return emitOpError("initializer region type ") 2031 << *ret.operand_type_begin() << " does not match global type " 2032 << getType(); 2033 2034 for (Operation &op : *b) { 2035 auto iface = dyn_cast<MemoryEffectOpInterface>(op); 2036 if (!iface || !iface.hasNoEffect()) 2037 return op.emitError() 2038 << "ops with side effects not allowed in global initializers"; 2039 } 2040 2041 if (getValueOrNull()) 2042 return emitOpError("cannot have both initializer value and region"); 2043 } 2044 2045 return success(); 2046 } 2047 2048 //===----------------------------------------------------------------------===// 2049 // LLVM::GlobalCtorsOp 2050 //===----------------------------------------------------------------------===// 2051 2052 LogicalResult 2053 GlobalCtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 2054 for (Attribute ctor : getCtors()) { 2055 if (failed(verifySymbolAttrUse(ctor.cast<FlatSymbolRefAttr>(), *this, 2056 symbolTable))) 2057 return failure(); 2058 } 2059 return success(); 2060 } 2061 2062 LogicalResult GlobalCtorsOp::verify() { 2063 if (getCtors().size() != getPriorities().size()) 2064 return emitError( 2065 "mismatch between the number of ctors and the number of priorities"); 2066 return success(); 2067 } 2068 2069 //===----------------------------------------------------------------------===// 2070 // LLVM::GlobalDtorsOp 2071 //===----------------------------------------------------------------------===// 2072 2073 LogicalResult 2074 GlobalDtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 2075 for (Attribute dtor : getDtors()) { 2076 if (failed(verifySymbolAttrUse(dtor.cast<FlatSymbolRefAttr>(), *this, 2077 symbolTable))) 2078 return failure(); 2079 } 2080 return success(); 2081 } 2082 2083 LogicalResult GlobalDtorsOp::verify() { 2084 if (getDtors().size() != getPriorities().size()) 2085 return emitError( 2086 "mismatch between the number of dtors and the number of priorities"); 2087 return success(); 2088 } 2089 2090 //===----------------------------------------------------------------------===// 2091 // Printing/parsing for LLVM::ShuffleVectorOp. 2092 //===----------------------------------------------------------------------===// 2093 // Expects vector to be of wrapped LLVM vector type and position to be of 2094 // wrapped LLVM i32 type. 2095 void LLVM::ShuffleVectorOp::build(OpBuilder &b, OperationState &result, 2096 Value v1, Value v2, ArrayAttr mask, 2097 ArrayRef<NamedAttribute> attrs) { 2098 auto containerType = v1.getType(); 2099 auto vType = LLVM::getVectorType(LLVM::getVectorElementType(containerType), 2100 mask.size(), 2101 LLVM::isScalableVectorType(containerType)); 2102 build(b, result, vType, v1, v2, mask); 2103 result.addAttributes(attrs); 2104 } 2105 2106 void ShuffleVectorOp::print(OpAsmPrinter &p) { 2107 p << ' ' << getV1() << ", " << getV2() << " " << getMask(); 2108 p.printOptionalAttrDict((*this)->getAttrs(), {"mask"}); 2109 p << " : " << getV1().getType() << ", " << getV2().getType(); 2110 } 2111 2112 // <operation> ::= `llvm.shufflevector` ssa-use `, ` ssa-use 2113 // `[` integer-literal (`,` integer-literal)* `]` 2114 // attribute-dict? `:` type 2115 ParseResult ShuffleVectorOp::parse(OpAsmParser &parser, 2116 OperationState &result) { 2117 SMLoc loc; 2118 OpAsmParser::UnresolvedOperand v1, v2; 2119 ArrayAttr maskAttr; 2120 Type typeV1, typeV2; 2121 if (parser.getCurrentLocation(&loc) || parser.parseOperand(v1) || 2122 parser.parseComma() || parser.parseOperand(v2) || 2123 parser.parseAttribute(maskAttr, "mask", result.attributes) || 2124 parser.parseOptionalAttrDict(result.attributes) || 2125 parser.parseColonType(typeV1) || parser.parseComma() || 2126 parser.parseType(typeV2) || 2127 parser.resolveOperand(v1, typeV1, result.operands) || 2128 parser.resolveOperand(v2, typeV2, result.operands)) 2129 return failure(); 2130 if (!LLVM::isCompatibleVectorType(typeV1)) 2131 return parser.emitError( 2132 loc, "expected LLVM IR dialect vector type for operand #1"); 2133 auto vType = 2134 LLVM::getVectorType(LLVM::getVectorElementType(typeV1), maskAttr.size(), 2135 LLVM::isScalableVectorType(typeV1)); 2136 result.addTypes(vType); 2137 return success(); 2138 } 2139 2140 LogicalResult ShuffleVectorOp::verify() { 2141 Type type1 = getV1().getType(); 2142 Type type2 = getV2().getType(); 2143 if (LLVM::getVectorElementType(type1) != LLVM::getVectorElementType(type2)) 2144 return emitOpError("expected matching LLVM IR Dialect element types"); 2145 if (LLVM::isScalableVectorType(type1)) 2146 if (llvm::any_of(getMask(), [](Attribute attr) { 2147 return attr.cast<IntegerAttr>().getInt() != 0; 2148 })) 2149 return emitOpError("expected a splat operation for scalable vectors"); 2150 return success(); 2151 } 2152 2153 //===----------------------------------------------------------------------===// 2154 // Implementations for LLVM::LLVMFuncOp. 2155 //===----------------------------------------------------------------------===// 2156 2157 // Add the entry block to the function. 2158 Block *LLVMFuncOp::addEntryBlock() { 2159 assert(empty() && "function already has an entry block"); 2160 2161 auto *entry = new Block; 2162 push_back(entry); 2163 2164 // FIXME: Allow passing in proper locations for the entry arguments. 2165 LLVMFunctionType type = getFunctionType(); 2166 for (unsigned i = 0, e = type.getNumParams(); i < e; ++i) 2167 entry->addArgument(type.getParamType(i), getLoc()); 2168 return entry; 2169 } 2170 2171 void LLVMFuncOp::build(OpBuilder &builder, OperationState &result, 2172 StringRef name, Type type, LLVM::Linkage linkage, 2173 bool dsoLocal, CConv cconv, 2174 ArrayRef<NamedAttribute> attrs, 2175 ArrayRef<DictionaryAttr> argAttrs) { 2176 result.addRegion(); 2177 result.addAttribute(SymbolTable::getSymbolAttrName(), 2178 builder.getStringAttr(name)); 2179 result.addAttribute(getFunctionTypeAttrName(result.name), 2180 TypeAttr::get(type)); 2181 result.addAttribute(getLinkageAttrName(result.name), 2182 LinkageAttr::get(builder.getContext(), linkage)); 2183 result.addAttribute(getCConvAttrName(result.name), 2184 CConvAttr::get(builder.getContext(), cconv)); 2185 result.attributes.append(attrs.begin(), attrs.end()); 2186 if (dsoLocal) 2187 result.addAttribute("dso_local", builder.getUnitAttr()); 2188 if (argAttrs.empty()) 2189 return; 2190 2191 assert(type.cast<LLVMFunctionType>().getNumParams() == argAttrs.size() && 2192 "expected as many argument attribute lists as arguments"); 2193 function_interface_impl::addArgAndResultAttrs(builder, result, argAttrs, 2194 /*resultAttrs=*/llvm::None); 2195 } 2196 2197 // Builds an LLVM function type from the given lists of input and output types. 2198 // Returns a null type if any of the types provided are non-LLVM types, or if 2199 // there is more than one output type. 2200 static Type 2201 buildLLVMFunctionType(OpAsmParser &parser, SMLoc loc, ArrayRef<Type> inputs, 2202 ArrayRef<Type> outputs, 2203 function_interface_impl::VariadicFlag variadicFlag) { 2204 Builder &b = parser.getBuilder(); 2205 if (outputs.size() > 1) { 2206 parser.emitError(loc, "failed to construct function type: expected zero or " 2207 "one function result"); 2208 return {}; 2209 } 2210 2211 // Convert inputs to LLVM types, exit early on error. 2212 SmallVector<Type, 4> llvmInputs; 2213 for (auto t : inputs) { 2214 if (!isCompatibleType(t)) { 2215 parser.emitError(loc, "failed to construct function type: expected LLVM " 2216 "type for function arguments"); 2217 return {}; 2218 } 2219 llvmInputs.push_back(t); 2220 } 2221 2222 // No output is denoted as "void" in LLVM type system. 2223 Type llvmOutput = 2224 outputs.empty() ? LLVMVoidType::get(b.getContext()) : outputs.front(); 2225 if (!isCompatibleType(llvmOutput)) { 2226 parser.emitError(loc, "failed to construct function type: expected LLVM " 2227 "type for function results") 2228 << llvmOutput; 2229 return {}; 2230 } 2231 return LLVMFunctionType::get(llvmOutput, llvmInputs, 2232 variadicFlag.isVariadic()); 2233 } 2234 2235 // Parses an LLVM function. 2236 // 2237 // operation ::= `llvm.func` linkage? cconv? function-signature 2238 // function-attributes? 2239 // function-body 2240 // 2241 ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) { 2242 // Default to external linkage if no keyword is provided. 2243 result.addAttribute( 2244 getLinkageAttrName(result.name), 2245 LinkageAttr::get(parser.getContext(), 2246 parseOptionalLLVMKeyword<Linkage>( 2247 parser, result, LLVM::Linkage::External))); 2248 2249 // Default to C Calling Convention if no keyword is provided. 2250 result.addAttribute( 2251 getCConvAttrName(result.name), 2252 CConvAttr::get(parser.getContext(), parseOptionalLLVMKeyword<CConv>( 2253 parser, result, LLVM::CConv::C))); 2254 2255 StringAttr nameAttr; 2256 SmallVector<OpAsmParser::Argument> entryArgs; 2257 SmallVector<DictionaryAttr> resultAttrs; 2258 SmallVector<Type> resultTypes; 2259 bool isVariadic; 2260 2261 auto signatureLocation = parser.getCurrentLocation(); 2262 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), 2263 result.attributes) || 2264 function_interface_impl::parseFunctionSignature( 2265 parser, /*allowVariadic=*/true, entryArgs, isVariadic, resultTypes, 2266 resultAttrs)) 2267 return failure(); 2268 2269 SmallVector<Type> argTypes; 2270 for (auto &arg : entryArgs) 2271 argTypes.push_back(arg.type); 2272 auto type = 2273 buildLLVMFunctionType(parser, signatureLocation, argTypes, resultTypes, 2274 function_interface_impl::VariadicFlag(isVariadic)); 2275 if (!type) 2276 return failure(); 2277 result.addAttribute(FunctionOpInterface::getTypeAttrName(), 2278 TypeAttr::get(type)); 2279 2280 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) 2281 return failure(); 2282 function_interface_impl::addArgAndResultAttrs(parser.getBuilder(), result, 2283 entryArgs, resultAttrs); 2284 2285 auto *body = result.addRegion(); 2286 OptionalParseResult parseResult = 2287 parser.parseOptionalRegion(*body, entryArgs); 2288 return failure(parseResult.hasValue() && failed(*parseResult)); 2289 } 2290 2291 // Print the LLVMFuncOp. Collects argument and result types and passes them to 2292 // helper functions. Drops "void" result since it cannot be parsed back. Skips 2293 // the external linkage since it is the default value. 2294 void LLVMFuncOp::print(OpAsmPrinter &p) { 2295 p << ' '; 2296 if (getLinkage() != LLVM::Linkage::External) 2297 p << stringifyLinkage(getLinkage()) << ' '; 2298 if (getCConv() != LLVM::CConv::C) 2299 p << stringifyCConv(getCConv()) << ' '; 2300 2301 p.printSymbolName(getName()); 2302 2303 LLVMFunctionType fnType = getFunctionType(); 2304 SmallVector<Type, 8> argTypes; 2305 SmallVector<Type, 1> resTypes; 2306 argTypes.reserve(fnType.getNumParams()); 2307 for (unsigned i = 0, e = fnType.getNumParams(); i < e; ++i) 2308 argTypes.push_back(fnType.getParamType(i)); 2309 2310 Type returnType = fnType.getReturnType(); 2311 if (!returnType.isa<LLVMVoidType>()) 2312 resTypes.push_back(returnType); 2313 2314 function_interface_impl::printFunctionSignature(p, *this, argTypes, 2315 isVarArg(), resTypes); 2316 function_interface_impl::printFunctionAttributes( 2317 p, *this, argTypes.size(), resTypes.size(), 2318 {getLinkageAttrName(), getCConvAttrName()}); 2319 2320 // Print the body if this is not an external function. 2321 Region &body = getBody(); 2322 if (!body.empty()) { 2323 p << ' '; 2324 p.printRegion(body, /*printEntryBlockArgs=*/false, 2325 /*printBlockTerminators=*/true); 2326 } 2327 } 2328 2329 // Verifies LLVM- and implementation-specific properties of the LLVM func Op: 2330 // - functions don't have 'common' linkage 2331 // - external functions have 'external' or 'extern_weak' linkage; 2332 // - vararg is (currently) only supported for external functions; 2333 LogicalResult LLVMFuncOp::verify() { 2334 if (getLinkage() == LLVM::Linkage::Common) 2335 return emitOpError() << "functions cannot have '" 2336 << stringifyLinkage(LLVM::Linkage::Common) 2337 << "' linkage"; 2338 2339 // Check to see if this function has a void return with a result attribute to 2340 // it. It isn't clear what semantics we would assign to that. 2341 if (getFunctionType().getReturnType().isa<LLVMVoidType>() && 2342 !getResultAttrs(0).empty()) { 2343 return emitOpError() 2344 << "cannot attach result attributes to functions with a void return"; 2345 } 2346 2347 if (isExternal()) { 2348 if (getLinkage() != LLVM::Linkage::External && 2349 getLinkage() != LLVM::Linkage::ExternWeak) 2350 return emitOpError() << "external functions must have '" 2351 << stringifyLinkage(LLVM::Linkage::External) 2352 << "' or '" 2353 << stringifyLinkage(LLVM::Linkage::ExternWeak) 2354 << "' linkage"; 2355 return success(); 2356 } 2357 2358 return success(); 2359 } 2360 2361 /// Verifies LLVM- and implementation-specific properties of the LLVM func Op: 2362 /// - entry block arguments are of LLVM types. 2363 LogicalResult LLVMFuncOp::verifyRegions() { 2364 if (isExternal()) 2365 return success(); 2366 2367 unsigned numArguments = getFunctionType().getNumParams(); 2368 Block &entryBlock = front(); 2369 for (unsigned i = 0; i < numArguments; ++i) { 2370 Type argType = entryBlock.getArgument(i).getType(); 2371 if (!isCompatibleType(argType)) 2372 return emitOpError("entry block argument #") 2373 << i << " is not of LLVM type"; 2374 } 2375 2376 return success(); 2377 } 2378 2379 //===----------------------------------------------------------------------===// 2380 // Verification for LLVM::ConstantOp. 2381 //===----------------------------------------------------------------------===// 2382 2383 LogicalResult LLVM::ConstantOp::verify() { 2384 if (StringAttr sAttr = getValue().dyn_cast<StringAttr>()) { 2385 auto arrayType = getType().dyn_cast<LLVMArrayType>(); 2386 if (!arrayType || arrayType.getNumElements() != sAttr.getValue().size() || 2387 !arrayType.getElementType().isInteger(8)) { 2388 return emitOpError() << "expected array type of " 2389 << sAttr.getValue().size() 2390 << " i8 elements for the string constant"; 2391 } 2392 return success(); 2393 } 2394 if (auto structType = getType().dyn_cast<LLVMStructType>()) { 2395 if (structType.getBody().size() != 2 || 2396 structType.getBody()[0] != structType.getBody()[1]) { 2397 return emitError() << "expected struct type with two elements of the " 2398 "same type, the type of a complex constant"; 2399 } 2400 2401 auto arrayAttr = getValue().dyn_cast<ArrayAttr>(); 2402 if (!arrayAttr || arrayAttr.size() != 2 || 2403 arrayAttr[0].getType() != arrayAttr[1].getType()) { 2404 return emitOpError() << "expected array attribute with two elements, " 2405 "representing a complex constant"; 2406 } 2407 2408 Type elementType = structType.getBody()[0]; 2409 if (!elementType 2410 .isa<IntegerType, Float16Type, Float32Type, Float64Type>()) { 2411 return emitError() 2412 << "expected struct element types to be floating point type or " 2413 "integer type"; 2414 } 2415 return success(); 2416 } 2417 if (!getValue().isa<IntegerAttr, ArrayAttr, FloatAttr, ElementsAttr>()) 2418 return emitOpError() 2419 << "only supports integer, float, string or elements attributes"; 2420 return success(); 2421 } 2422 2423 // Constant op constant-folds to its value. 2424 OpFoldResult LLVM::ConstantOp::fold(ArrayRef<Attribute>) { return getValue(); } 2425 2426 //===----------------------------------------------------------------------===// 2427 // Utility functions for parsing atomic ops 2428 //===----------------------------------------------------------------------===// 2429 2430 // Helper function to parse a keyword into the specified attribute named by 2431 // `attrName`. The keyword must match one of the string values defined by the 2432 // AtomicBinOp enum. The resulting I64 attribute is added to the `result` 2433 // state. 2434 static ParseResult parseAtomicBinOp(OpAsmParser &parser, OperationState &result, 2435 StringRef attrName) { 2436 SMLoc loc; 2437 StringRef keyword; 2438 if (parser.getCurrentLocation(&loc) || parser.parseKeyword(&keyword)) 2439 return failure(); 2440 2441 // Replace the keyword `keyword` with an integer attribute. 2442 auto kind = symbolizeAtomicBinOp(keyword); 2443 if (!kind) { 2444 return parser.emitError(loc) 2445 << "'" << keyword << "' is an incorrect value of the '" << attrName 2446 << "' attribute"; 2447 } 2448 2449 auto value = static_cast<int64_t>(*kind); 2450 auto attr = parser.getBuilder().getI64IntegerAttr(value); 2451 result.addAttribute(attrName, attr); 2452 2453 return success(); 2454 } 2455 2456 // Helper function to parse a keyword into the specified attribute named by 2457 // `attrName`. The keyword must match one of the string values defined by the 2458 // AtomicOrdering enum. The resulting I64 attribute is added to the `result` 2459 // state. 2460 static ParseResult parseAtomicOrdering(OpAsmParser &parser, 2461 OperationState &result, 2462 StringRef attrName) { 2463 SMLoc loc; 2464 StringRef ordering; 2465 if (parser.getCurrentLocation(&loc) || parser.parseKeyword(&ordering)) 2466 return failure(); 2467 2468 // Replace the keyword `ordering` with an integer attribute. 2469 auto kind = symbolizeAtomicOrdering(ordering); 2470 if (!kind) { 2471 return parser.emitError(loc) 2472 << "'" << ordering << "' is an incorrect value of the '" << attrName 2473 << "' attribute"; 2474 } 2475 2476 auto value = static_cast<int64_t>(*kind); 2477 auto attr = parser.getBuilder().getI64IntegerAttr(value); 2478 result.addAttribute(attrName, attr); 2479 2480 return success(); 2481 } 2482 2483 //===----------------------------------------------------------------------===// 2484 // Printer, parser and verifier for LLVM::AtomicRMWOp. 2485 //===----------------------------------------------------------------------===// 2486 2487 void AtomicRMWOp::print(OpAsmPrinter &p) { 2488 p << ' ' << stringifyAtomicBinOp(getBinOp()) << ' ' << getPtr() << ", " 2489 << getVal() << ' ' << stringifyAtomicOrdering(getOrdering()) << ' '; 2490 p.printOptionalAttrDict((*this)->getAttrs(), {"bin_op", "ordering"}); 2491 p << " : " << getRes().getType(); 2492 } 2493 2494 // <operation> ::= `llvm.atomicrmw` keyword ssa-use `,` ssa-use keyword 2495 // attribute-dict? `:` type 2496 ParseResult AtomicRMWOp::parse(OpAsmParser &parser, OperationState &result) { 2497 Type type; 2498 OpAsmParser::UnresolvedOperand ptr, val; 2499 if (parseAtomicBinOp(parser, result, "bin_op") || parser.parseOperand(ptr) || 2500 parser.parseComma() || parser.parseOperand(val) || 2501 parseAtomicOrdering(parser, result, "ordering") || 2502 parser.parseOptionalAttrDict(result.attributes) || 2503 parser.parseColonType(type) || 2504 parser.resolveOperand(ptr, LLVM::LLVMPointerType::get(type), 2505 result.operands) || 2506 parser.resolveOperand(val, type, result.operands)) 2507 return failure(); 2508 2509 result.addTypes(type); 2510 return success(); 2511 } 2512 2513 LogicalResult AtomicRMWOp::verify() { 2514 auto ptrType = getPtr().getType().cast<LLVM::LLVMPointerType>(); 2515 auto valType = getVal().getType(); 2516 if (valType != ptrType.getElementType()) 2517 return emitOpError("expected LLVM IR element type for operand #0 to " 2518 "match type for operand #1"); 2519 auto resType = getRes().getType(); 2520 if (resType != valType) 2521 return emitOpError( 2522 "expected LLVM IR result type to match type for operand #1"); 2523 if (getBinOp() == AtomicBinOp::fadd || getBinOp() == AtomicBinOp::fsub) { 2524 if (!mlir::LLVM::isCompatibleFloatingPointType(valType)) 2525 return emitOpError("expected LLVM IR floating point type"); 2526 } else if (getBinOp() == AtomicBinOp::xchg) { 2527 auto intType = valType.dyn_cast<IntegerType>(); 2528 unsigned intBitWidth = intType ? intType.getWidth() : 0; 2529 if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 && 2530 intBitWidth != 64 && !valType.isa<BFloat16Type>() && 2531 !valType.isa<Float16Type>() && !valType.isa<Float32Type>() && 2532 !valType.isa<Float64Type>()) 2533 return emitOpError("unexpected LLVM IR type for 'xchg' bin_op"); 2534 } else { 2535 auto intType = valType.dyn_cast<IntegerType>(); 2536 unsigned intBitWidth = intType ? intType.getWidth() : 0; 2537 if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 && 2538 intBitWidth != 64) 2539 return emitOpError("expected LLVM IR integer type"); 2540 } 2541 2542 if (static_cast<unsigned>(getOrdering()) < 2543 static_cast<unsigned>(AtomicOrdering::monotonic)) 2544 return emitOpError() << "expected at least '" 2545 << stringifyAtomicOrdering(AtomicOrdering::monotonic) 2546 << "' ordering"; 2547 2548 return success(); 2549 } 2550 2551 //===----------------------------------------------------------------------===// 2552 // Printer, parser and verifier for LLVM::AtomicCmpXchgOp. 2553 //===----------------------------------------------------------------------===// 2554 2555 void AtomicCmpXchgOp::print(OpAsmPrinter &p) { 2556 p << ' ' << getPtr() << ", " << getCmp() << ", " << getVal() << ' ' 2557 << stringifyAtomicOrdering(getSuccessOrdering()) << ' ' 2558 << stringifyAtomicOrdering(getFailureOrdering()); 2559 p.printOptionalAttrDict((*this)->getAttrs(), 2560 {"success_ordering", "failure_ordering"}); 2561 p << " : " << getVal().getType(); 2562 } 2563 2564 // <operation> ::= `llvm.cmpxchg` ssa-use `,` ssa-use `,` ssa-use 2565 // keyword keyword attribute-dict? `:` type 2566 ParseResult AtomicCmpXchgOp::parse(OpAsmParser &parser, 2567 OperationState &result) { 2568 auto &builder = parser.getBuilder(); 2569 Type type; 2570 OpAsmParser::UnresolvedOperand ptr, cmp, val; 2571 if (parser.parseOperand(ptr) || parser.parseComma() || 2572 parser.parseOperand(cmp) || parser.parseComma() || 2573 parser.parseOperand(val) || 2574 parseAtomicOrdering(parser, result, "success_ordering") || 2575 parseAtomicOrdering(parser, result, "failure_ordering") || 2576 parser.parseOptionalAttrDict(result.attributes) || 2577 parser.parseColonType(type) || 2578 parser.resolveOperand(ptr, LLVM::LLVMPointerType::get(type), 2579 result.operands) || 2580 parser.resolveOperand(cmp, type, result.operands) || 2581 parser.resolveOperand(val, type, result.operands)) 2582 return failure(); 2583 2584 auto boolType = IntegerType::get(builder.getContext(), 1); 2585 auto resultType = 2586 LLVMStructType::getLiteral(builder.getContext(), {type, boolType}); 2587 result.addTypes(resultType); 2588 2589 return success(); 2590 } 2591 2592 LogicalResult AtomicCmpXchgOp::verify() { 2593 auto ptrType = getPtr().getType().cast<LLVM::LLVMPointerType>(); 2594 if (!ptrType) 2595 return emitOpError("expected LLVM IR pointer type for operand #0"); 2596 auto cmpType = getCmp().getType(); 2597 auto valType = getVal().getType(); 2598 if (cmpType != ptrType.getElementType() || cmpType != valType) 2599 return emitOpError("expected LLVM IR element type for operand #0 to " 2600 "match type for all other operands"); 2601 auto intType = valType.dyn_cast<IntegerType>(); 2602 unsigned intBitWidth = intType ? intType.getWidth() : 0; 2603 if (!valType.isa<LLVMPointerType>() && intBitWidth != 8 && 2604 intBitWidth != 16 && intBitWidth != 32 && intBitWidth != 64 && 2605 !valType.isa<BFloat16Type>() && !valType.isa<Float16Type>() && 2606 !valType.isa<Float32Type>() && !valType.isa<Float64Type>()) 2607 return emitOpError("unexpected LLVM IR type"); 2608 if (getSuccessOrdering() < AtomicOrdering::monotonic || 2609 getFailureOrdering() < AtomicOrdering::monotonic) 2610 return emitOpError("ordering must be at least 'monotonic'"); 2611 if (getFailureOrdering() == AtomicOrdering::release || 2612 getFailureOrdering() == AtomicOrdering::acq_rel) 2613 return emitOpError("failure ordering cannot be 'release' or 'acq_rel'"); 2614 return success(); 2615 } 2616 2617 //===----------------------------------------------------------------------===// 2618 // Printer, parser and verifier for LLVM::FenceOp. 2619 //===----------------------------------------------------------------------===// 2620 2621 // <operation> ::= `llvm.fence` (`syncscope(`strAttr`)`)? keyword 2622 // attribute-dict? 2623 ParseResult FenceOp::parse(OpAsmParser &parser, OperationState &result) { 2624 StringAttr sScope; 2625 StringRef syncscopeKeyword = "syncscope"; 2626 if (!failed(parser.parseOptionalKeyword(syncscopeKeyword))) { 2627 if (parser.parseLParen() || 2628 parser.parseAttribute(sScope, syncscopeKeyword, result.attributes) || 2629 parser.parseRParen()) 2630 return failure(); 2631 } else { 2632 result.addAttribute(syncscopeKeyword, 2633 parser.getBuilder().getStringAttr("")); 2634 } 2635 if (parseAtomicOrdering(parser, result, "ordering") || 2636 parser.parseOptionalAttrDict(result.attributes)) 2637 return failure(); 2638 return success(); 2639 } 2640 2641 void FenceOp::print(OpAsmPrinter &p) { 2642 StringRef syncscopeKeyword = "syncscope"; 2643 p << ' '; 2644 if (!(*this)->getAttr(syncscopeKeyword).cast<StringAttr>().getValue().empty()) 2645 p << "syncscope(" << (*this)->getAttr(syncscopeKeyword) << ") "; 2646 p << stringifyAtomicOrdering(getOrdering()); 2647 } 2648 2649 LogicalResult FenceOp::verify() { 2650 if (getOrdering() == AtomicOrdering::not_atomic || 2651 getOrdering() == AtomicOrdering::unordered || 2652 getOrdering() == AtomicOrdering::monotonic) 2653 return emitOpError("can be given only acquire, release, acq_rel, " 2654 "and seq_cst orderings"); 2655 return success(); 2656 } 2657 2658 //===----------------------------------------------------------------------===// 2659 // Folder for LLVM::BitcastOp 2660 //===----------------------------------------------------------------------===// 2661 2662 OpFoldResult LLVM::BitcastOp::fold(ArrayRef<Attribute> operands) { 2663 // bitcast(x : T0, T0) -> x 2664 if (getArg().getType() == getType()) 2665 return getArg(); 2666 // bitcast(bitcast(x : T0, T1), T0) -> x 2667 if (auto prev = getArg().getDefiningOp<BitcastOp>()) 2668 if (prev.getArg().getType() == getType()) 2669 return prev.getArg(); 2670 return {}; 2671 } 2672 2673 //===----------------------------------------------------------------------===// 2674 // Folder for LLVM::AddrSpaceCastOp 2675 //===----------------------------------------------------------------------===// 2676 2677 OpFoldResult LLVM::AddrSpaceCastOp::fold(ArrayRef<Attribute> operands) { 2678 // addrcast(x : T0, T0) -> x 2679 if (getArg().getType() == getType()) 2680 return getArg(); 2681 // addrcast(addrcast(x : T0, T1), T0) -> x 2682 if (auto prev = getArg().getDefiningOp<AddrSpaceCastOp>()) 2683 if (prev.getArg().getType() == getType()) 2684 return prev.getArg(); 2685 return {}; 2686 } 2687 2688 //===----------------------------------------------------------------------===// 2689 // Folder for LLVM::GEPOp 2690 //===----------------------------------------------------------------------===// 2691 2692 OpFoldResult LLVM::GEPOp::fold(ArrayRef<Attribute> operands) { 2693 // gep %x:T, 0 -> %x 2694 if (getBase().getType() == getType() && getIndices().size() == 1 && 2695 matchPattern(getIndices()[0], m_Zero())) 2696 return getBase(); 2697 return {}; 2698 } 2699 2700 //===----------------------------------------------------------------------===// 2701 // LLVMDialect initialization, type parsing, and registration. 2702 //===----------------------------------------------------------------------===// 2703 2704 void LLVMDialect::initialize() { 2705 addAttributes<FMFAttr, LinkageAttr, CConvAttr, LoopOptionsAttr>(); 2706 2707 // clang-format off 2708 addTypes<LLVMVoidType, 2709 LLVMPPCFP128Type, 2710 LLVMX86MMXType, 2711 LLVMTokenType, 2712 LLVMLabelType, 2713 LLVMMetadataType, 2714 LLVMFunctionType, 2715 LLVMPointerType, 2716 LLVMFixedVectorType, 2717 LLVMScalableVectorType, 2718 LLVMArrayType, 2719 LLVMStructType>(); 2720 // clang-format on 2721 addOperations< 2722 #define GET_OP_LIST 2723 #include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc" 2724 , 2725 #define GET_OP_LIST 2726 #include "mlir/Dialect/LLVMIR/LLVMIntrinsicOps.cpp.inc" 2727 >(); 2728 2729 // Support unknown operations because not all LLVM operations are registered. 2730 allowUnknownOperations(); 2731 } 2732 2733 #define GET_OP_CLASSES 2734 #include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc" 2735 2736 /// Parse a type registered to this dialect. 2737 Type LLVMDialect::parseType(DialectAsmParser &parser) const { 2738 return detail::parseType(parser); 2739 } 2740 2741 /// Print a type registered to this dialect. 2742 void LLVMDialect::printType(Type type, DialectAsmPrinter &os) const { 2743 return detail::printType(type, os); 2744 } 2745 2746 LogicalResult LLVMDialect::verifyDataLayoutString( 2747 StringRef descr, llvm::function_ref<void(const Twine &)> reportError) { 2748 llvm::Expected<llvm::DataLayout> maybeDataLayout = 2749 llvm::DataLayout::parse(descr); 2750 if (maybeDataLayout) 2751 return success(); 2752 2753 std::string message; 2754 llvm::raw_string_ostream messageStream(message); 2755 llvm::logAllUnhandledErrors(maybeDataLayout.takeError(), messageStream); 2756 reportError("invalid data layout descriptor: " + messageStream.str()); 2757 return failure(); 2758 } 2759 2760 /// Verify LLVM dialect attributes. 2761 LogicalResult LLVMDialect::verifyOperationAttribute(Operation *op, 2762 NamedAttribute attr) { 2763 // If the `llvm.loop` attribute is present, enforce the following structure, 2764 // which the module translation can assume. 2765 if (attr.getName() == LLVMDialect::getLoopAttrName()) { 2766 auto loopAttr = attr.getValue().dyn_cast<DictionaryAttr>(); 2767 if (!loopAttr) 2768 return op->emitOpError() << "expected '" << LLVMDialect::getLoopAttrName() 2769 << "' to be a dictionary attribute"; 2770 Optional<NamedAttribute> parallelAccessGroup = 2771 loopAttr.getNamed(LLVMDialect::getParallelAccessAttrName()); 2772 if (parallelAccessGroup) { 2773 auto accessGroups = parallelAccessGroup->getValue().dyn_cast<ArrayAttr>(); 2774 if (!accessGroups) 2775 return op->emitOpError() 2776 << "expected '" << LLVMDialect::getParallelAccessAttrName() 2777 << "' to be an array attribute"; 2778 for (Attribute attr : accessGroups) { 2779 auto accessGroupRef = attr.dyn_cast<SymbolRefAttr>(); 2780 if (!accessGroupRef) 2781 return op->emitOpError() 2782 << "expected '" << attr << "' to be a symbol reference"; 2783 StringAttr metadataName = accessGroupRef.getRootReference(); 2784 auto metadataOp = 2785 SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>( 2786 op->getParentOp(), metadataName); 2787 if (!metadataOp) 2788 return op->emitOpError() 2789 << "expected '" << attr << "' to reference a metadata op"; 2790 StringAttr accessGroupName = accessGroupRef.getLeafReference(); 2791 Operation *accessGroupOp = 2792 SymbolTable::lookupNearestSymbolFrom(metadataOp, accessGroupName); 2793 if (!accessGroupOp) 2794 return op->emitOpError() 2795 << "expected '" << attr << "' to reference an access_group op"; 2796 } 2797 } 2798 2799 Optional<NamedAttribute> loopOptions = 2800 loopAttr.getNamed(LLVMDialect::getLoopOptionsAttrName()); 2801 if (loopOptions && !loopOptions->getValue().isa<LoopOptionsAttr>()) 2802 return op->emitOpError() 2803 << "expected '" << LLVMDialect::getLoopOptionsAttrName() 2804 << "' to be a `loopopts` attribute"; 2805 } 2806 2807 if (attr.getName() == LLVMDialect::getStructAttrsAttrName()) { 2808 return op->emitOpError() 2809 << "'" << LLVM::LLVMDialect::getStructAttrsAttrName() 2810 << "' is permitted only in argument or result attributes"; 2811 } 2812 2813 // If the data layout attribute is present, it must use the LLVM data layout 2814 // syntax. Try parsing it and report errors in case of failure. Users of this 2815 // attribute may assume it is well-formed and can pass it to the (asserting) 2816 // llvm::DataLayout constructor. 2817 if (attr.getName() != LLVM::LLVMDialect::getDataLayoutAttrName()) 2818 return success(); 2819 if (auto stringAttr = attr.getValue().dyn_cast<StringAttr>()) 2820 return verifyDataLayoutString( 2821 stringAttr.getValue(), 2822 [op](const Twine &message) { op->emitOpError() << message.str(); }); 2823 2824 return op->emitOpError() << "expected '" 2825 << LLVM::LLVMDialect::getDataLayoutAttrName() 2826 << "' to be a string attributes"; 2827 } 2828 2829 LogicalResult LLVMDialect::verifyStructAttr(Operation *op, Attribute attr, 2830 Type annotatedType) { 2831 auto structType = annotatedType.dyn_cast<LLVMStructType>(); 2832 if (!structType) { 2833 const auto emitIncorrectAnnotatedType = [&op]() { 2834 return op->emitError() 2835 << "expected '" << LLVMDialect::getStructAttrsAttrName() 2836 << "' to annotate '!llvm.struct' or '!llvm.ptr<struct<...>>'"; 2837 }; 2838 const auto ptrType = annotatedType.dyn_cast<LLVMPointerType>(); 2839 if (!ptrType) 2840 return emitIncorrectAnnotatedType(); 2841 structType = ptrType.getElementType().dyn_cast<LLVMStructType>(); 2842 if (!structType) 2843 return emitIncorrectAnnotatedType(); 2844 } 2845 2846 const auto arrAttrs = attr.dyn_cast<ArrayAttr>(); 2847 if (!arrAttrs) 2848 return op->emitError() << "expected '" 2849 << LLVMDialect::getStructAttrsAttrName() 2850 << "' to be an array attribute"; 2851 2852 if (structType.getBody().size() != arrAttrs.size()) 2853 return op->emitError() 2854 << "size of '" << LLVMDialect::getStructAttrsAttrName() 2855 << "' must match the size of the annotated '!llvm.struct'"; 2856 return success(); 2857 } 2858 2859 static LogicalResult verifyFuncOpInterfaceStructAttr( 2860 Operation *op, Attribute attr, 2861 const std::function<Type(FunctionOpInterface)> &getAnnotatedType) { 2862 if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) 2863 return LLVMDialect::verifyStructAttr(op, attr, getAnnotatedType(funcOp)); 2864 return op->emitError() << "expected '" 2865 << LLVMDialect::getStructAttrsAttrName() 2866 << "' to be used on function-like operations"; 2867 } 2868 2869 /// Verify LLVMIR function argument attributes. 2870 LogicalResult LLVMDialect::verifyRegionArgAttribute(Operation *op, 2871 unsigned regionIdx, 2872 unsigned argIdx, 2873 NamedAttribute argAttr) { 2874 // Check that llvm.noalias is a unit attribute. 2875 if (argAttr.getName() == LLVMDialect::getNoAliasAttrName() && 2876 !argAttr.getValue().isa<UnitAttr>()) 2877 return op->emitError() 2878 << "expected llvm.noalias argument attribute to be a unit attribute"; 2879 // Check that llvm.align is an integer attribute. 2880 if (argAttr.getName() == LLVMDialect::getAlignAttrName() && 2881 !argAttr.getValue().isa<IntegerAttr>()) 2882 return op->emitError() 2883 << "llvm.align argument attribute of non integer type"; 2884 if (argAttr.getName() == LLVMDialect::getStructAttrsAttrName()) { 2885 return verifyFuncOpInterfaceStructAttr( 2886 op, argAttr.getValue(), [argIdx](FunctionOpInterface funcOp) { 2887 return funcOp.getArgumentTypes()[argIdx]; 2888 }); 2889 } 2890 return success(); 2891 } 2892 2893 LogicalResult LLVMDialect::verifyRegionResultAttribute(Operation *op, 2894 unsigned regionIdx, 2895 unsigned resIdx, 2896 NamedAttribute resAttr) { 2897 if (resAttr.getName() == LLVMDialect::getStructAttrsAttrName()) { 2898 return verifyFuncOpInterfaceStructAttr( 2899 op, resAttr.getValue(), [resIdx](FunctionOpInterface funcOp) { 2900 return funcOp.getResultTypes()[resIdx]; 2901 }); 2902 } 2903 return success(); 2904 } 2905 2906 //===----------------------------------------------------------------------===// 2907 // Utility functions. 2908 //===----------------------------------------------------------------------===// 2909 2910 Value mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder, 2911 StringRef name, StringRef value, 2912 LLVM::Linkage linkage) { 2913 assert(builder.getInsertionBlock() && 2914 builder.getInsertionBlock()->getParentOp() && 2915 "expected builder to point to a block constrained in an op"); 2916 auto module = 2917 builder.getInsertionBlock()->getParentOp()->getParentOfType<ModuleOp>(); 2918 assert(module && "builder points to an op outside of a module"); 2919 2920 // Create the global at the entry of the module. 2921 OpBuilder moduleBuilder(module.getBodyRegion(), builder.getListener()); 2922 MLIRContext *ctx = builder.getContext(); 2923 auto type = LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), value.size()); 2924 auto global = moduleBuilder.create<LLVM::GlobalOp>( 2925 loc, type, /*isConstant=*/true, linkage, name, 2926 builder.getStringAttr(value), /*alignment=*/0); 2927 2928 // Get the pointer to the first character in the global string. 2929 Value globalPtr = builder.create<LLVM::AddressOfOp>(loc, global); 2930 Value cst0 = builder.create<LLVM::ConstantOp>( 2931 loc, IntegerType::get(ctx, 64), 2932 builder.getIntegerAttr(builder.getIndexType(), 0)); 2933 return builder.create<LLVM::GEPOp>( 2934 loc, LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)), globalPtr, 2935 ValueRange{cst0, cst0}); 2936 } 2937 2938 bool mlir::LLVM::satisfiesLLVMModule(Operation *op) { 2939 return op->hasTrait<OpTrait::SymbolTable>() && 2940 op->hasTrait<OpTrait::IsIsolatedFromAbove>(); 2941 } 2942 2943 void FMFAttr::print(AsmPrinter &printer) const { 2944 printer << "<"; 2945 printer << stringifyFastmathFlags(this->getFlags()); 2946 printer << ">"; 2947 } 2948 2949 Attribute FMFAttr::parse(AsmParser &parser, Type type) { 2950 if (failed(parser.parseLess())) 2951 return {}; 2952 2953 FastmathFlags flags = {}; 2954 if (failed(parser.parseOptionalGreater())) { 2955 auto parseFlags = [&]() -> ParseResult { 2956 StringRef elemName; 2957 if (failed(parser.parseKeyword(&elemName))) 2958 return failure(); 2959 2960 auto elem = symbolizeFastmathFlags(elemName); 2961 if (!elem) 2962 return parser.emitError(parser.getNameLoc(), "Unknown fastmath flag: ") 2963 << elemName; 2964 2965 flags = flags | *elem; 2966 return success(); 2967 }; 2968 if (failed(parser.parseCommaSeparatedList(parseFlags)) || 2969 parser.parseGreater()) 2970 return {}; 2971 } 2972 2973 return FMFAttr::get(parser.getContext(), flags); 2974 } 2975 2976 void LinkageAttr::print(AsmPrinter &printer) const { 2977 printer << "<"; 2978 if (static_cast<uint64_t>(getLinkage()) <= getMaxEnumValForLinkage()) 2979 printer << stringifyEnum(getLinkage()); 2980 else 2981 printer << static_cast<uint64_t>(getLinkage()); 2982 printer << ">"; 2983 } 2984 2985 Attribute LinkageAttr::parse(AsmParser &parser, Type type) { 2986 StringRef elemName; 2987 if (parser.parseLess() || parser.parseKeyword(&elemName) || 2988 parser.parseGreater()) 2989 return {}; 2990 auto elem = linkage::symbolizeLinkage(elemName); 2991 if (!elem) { 2992 parser.emitError(parser.getNameLoc(), "Unknown linkage: ") << elemName; 2993 return {}; 2994 } 2995 Linkage linkage = *elem; 2996 return LinkageAttr::get(parser.getContext(), linkage); 2997 } 2998 2999 void CConvAttr::print(AsmPrinter &printer) const { 3000 printer << "<"; 3001 if (static_cast<uint64_t>(getCallingConv()) <= cconv::getMaxEnumValForCConv()) 3002 printer << stringifyEnum(getCallingConv()); 3003 else 3004 printer << "INVALID_cc_" << static_cast<uint64_t>(getCallingConv()); 3005 printer << ">"; 3006 } 3007 3008 Attribute CConvAttr::parse(AsmParser &parser, Type type) { 3009 StringRef convName; 3010 3011 if (parser.parseLess() || parser.parseKeyword(&convName) || 3012 parser.parseGreater()) 3013 return {}; 3014 auto cconv = cconv::symbolizeCConv(convName); 3015 if (!cconv) { 3016 parser.emitError(parser.getNameLoc(), "unknown calling convention: ") 3017 << convName; 3018 return {}; 3019 } 3020 CConv cconvVal = *cconv; 3021 return CConvAttr::get(parser.getContext(), cconvVal); 3022 } 3023 3024 LoopOptionsAttrBuilder::LoopOptionsAttrBuilder(LoopOptionsAttr attr) 3025 : options(attr.getOptions().begin(), attr.getOptions().end()) {} 3026 3027 template <typename T> 3028 LoopOptionsAttrBuilder &LoopOptionsAttrBuilder::setOption(LoopOptionCase tag, 3029 Optional<T> value) { 3030 auto option = llvm::find_if( 3031 options, [tag](auto option) { return option.first == tag; }); 3032 if (option != options.end()) { 3033 if (value) 3034 option->second = *value; 3035 else 3036 options.erase(option); 3037 } else { 3038 options.push_back(LoopOptionsAttr::OptionValuePair(tag, *value)); 3039 } 3040 return *this; 3041 } 3042 3043 LoopOptionsAttrBuilder & 3044 LoopOptionsAttrBuilder::setDisableLICM(Optional<bool> value) { 3045 return setOption(LoopOptionCase::disable_licm, value); 3046 } 3047 3048 /// Set the `interleave_count` option to the provided value. If no value 3049 /// is provided the option is deleted. 3050 LoopOptionsAttrBuilder & 3051 LoopOptionsAttrBuilder::setInterleaveCount(Optional<uint64_t> count) { 3052 return setOption(LoopOptionCase::interleave_count, count); 3053 } 3054 3055 /// Set the `disable_unroll` option to the provided value. If no value 3056 /// is provided the option is deleted. 3057 LoopOptionsAttrBuilder & 3058 LoopOptionsAttrBuilder::setDisableUnroll(Optional<bool> value) { 3059 return setOption(LoopOptionCase::disable_unroll, value); 3060 } 3061 3062 /// Set the `disable_pipeline` option to the provided value. If no value 3063 /// is provided the option is deleted. 3064 LoopOptionsAttrBuilder & 3065 LoopOptionsAttrBuilder::setDisablePipeline(Optional<bool> value) { 3066 return setOption(LoopOptionCase::disable_pipeline, value); 3067 } 3068 3069 /// Set the `pipeline_initiation_interval` option to the provided value. 3070 /// If no value is provided the option is deleted. 3071 LoopOptionsAttrBuilder &LoopOptionsAttrBuilder::setPipelineInitiationInterval( 3072 Optional<uint64_t> count) { 3073 return setOption(LoopOptionCase::pipeline_initiation_interval, count); 3074 } 3075 3076 template <typename T> 3077 static Optional<T> 3078 getOption(ArrayRef<std::pair<LoopOptionCase, int64_t>> options, 3079 LoopOptionCase option) { 3080 auto it = 3081 lower_bound(options, option, [](auto optionPair, LoopOptionCase option) { 3082 return optionPair.first < option; 3083 }); 3084 if (it == options.end()) 3085 return {}; 3086 return static_cast<T>(it->second); 3087 } 3088 3089 Optional<bool> LoopOptionsAttr::disableUnroll() { 3090 return getOption<bool>(getOptions(), LoopOptionCase::disable_unroll); 3091 } 3092 3093 Optional<bool> LoopOptionsAttr::disableLICM() { 3094 return getOption<bool>(getOptions(), LoopOptionCase::disable_licm); 3095 } 3096 3097 Optional<int64_t> LoopOptionsAttr::interleaveCount() { 3098 return getOption<int64_t>(getOptions(), LoopOptionCase::interleave_count); 3099 } 3100 3101 /// Build the LoopOptions Attribute from a sorted array of individual options. 3102 LoopOptionsAttr LoopOptionsAttr::get( 3103 MLIRContext *context, 3104 ArrayRef<std::pair<LoopOptionCase, int64_t>> sortedOptions) { 3105 assert(llvm::is_sorted(sortedOptions, llvm::less_first()) && 3106 "LoopOptionsAttr ctor expects a sorted options array"); 3107 return Base::get(context, sortedOptions); 3108 } 3109 3110 /// Build the LoopOptions Attribute from a sorted array of individual options. 3111 LoopOptionsAttr LoopOptionsAttr::get(MLIRContext *context, 3112 LoopOptionsAttrBuilder &optionBuilders) { 3113 llvm::sort(optionBuilders.options, llvm::less_first()); 3114 return Base::get(context, optionBuilders.options); 3115 } 3116 3117 void LoopOptionsAttr::print(AsmPrinter &printer) const { 3118 printer << "<"; 3119 llvm::interleaveComma(getOptions(), printer, [&](auto option) { 3120 printer << stringifyEnum(option.first) << " = "; 3121 switch (option.first) { 3122 case LoopOptionCase::disable_licm: 3123 case LoopOptionCase::disable_unroll: 3124 case LoopOptionCase::disable_pipeline: 3125 printer << (option.second ? "true" : "false"); 3126 break; 3127 case LoopOptionCase::interleave_count: 3128 case LoopOptionCase::pipeline_initiation_interval: 3129 printer << option.second; 3130 break; 3131 } 3132 }); 3133 printer << ">"; 3134 } 3135 3136 Attribute LoopOptionsAttr::parse(AsmParser &parser, Type type) { 3137 if (failed(parser.parseLess())) 3138 return {}; 3139 3140 SmallVector<std::pair<LoopOptionCase, int64_t>> options; 3141 llvm::SmallDenseSet<LoopOptionCase> seenOptions; 3142 auto parseLoopOptions = [&]() -> ParseResult { 3143 StringRef optionName; 3144 if (parser.parseKeyword(&optionName)) 3145 return failure(); 3146 3147 auto option = symbolizeLoopOptionCase(optionName); 3148 if (!option) 3149 return parser.emitError(parser.getNameLoc(), "unknown loop option: ") 3150 << optionName; 3151 if (!seenOptions.insert(*option).second) 3152 return parser.emitError(parser.getNameLoc(), "loop option present twice"); 3153 if (failed(parser.parseEqual())) 3154 return failure(); 3155 3156 int64_t value; 3157 switch (*option) { 3158 case LoopOptionCase::disable_licm: 3159 case LoopOptionCase::disable_unroll: 3160 case LoopOptionCase::disable_pipeline: 3161 if (succeeded(parser.parseOptionalKeyword("true"))) 3162 value = 1; 3163 else if (succeeded(parser.parseOptionalKeyword("false"))) 3164 value = 0; 3165 else { 3166 return parser.emitError(parser.getNameLoc(), 3167 "expected boolean value 'true' or 'false'"); 3168 } 3169 break; 3170 case LoopOptionCase::interleave_count: 3171 case LoopOptionCase::pipeline_initiation_interval: 3172 if (failed(parser.parseInteger(value))) 3173 return parser.emitError(parser.getNameLoc(), "expected integer value"); 3174 break; 3175 } 3176 options.push_back(std::make_pair(*option, value)); 3177 return success(); 3178 }; 3179 if (parser.parseCommaSeparatedList(parseLoopOptions) || parser.parseGreater()) 3180 return {}; 3181 3182 llvm::sort(options, llvm::less_first()); 3183 return get(parser.getContext(), options); 3184 } 3185