1 //===- LLVMDialect.cpp - LLVM IR Ops and Dialect registration -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the types and operation details for the LLVM IR dialect in 10 // MLIR, and the LLVM IR dialect. It also registers the dialect. 11 // 12 //===----------------------------------------------------------------------===// 13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 14 #include "TypeDetail.h" 15 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 16 #include "mlir/IR/Builders.h" 17 #include "mlir/IR/BuiltinOps.h" 18 #include "mlir/IR/BuiltinTypes.h" 19 #include "mlir/IR/DialectImplementation.h" 20 #include "mlir/IR/FunctionImplementation.h" 21 #include "mlir/IR/MLIRContext.h" 22 #include "mlir/IR/Matchers.h" 23 24 #include "llvm/ADT/StringSwitch.h" 25 #include "llvm/ADT/TypeSwitch.h" 26 #include "llvm/AsmParser/Parser.h" 27 #include "llvm/Bitcode/BitcodeReader.h" 28 #include "llvm/Bitcode/BitcodeWriter.h" 29 #include "llvm/IR/Attributes.h" 30 #include "llvm/IR/Function.h" 31 #include "llvm/IR/Type.h" 32 #include "llvm/Support/Error.h" 33 #include "llvm/Support/Mutex.h" 34 #include "llvm/Support/SourceMgr.h" 35 36 #include <numeric> 37 38 using namespace mlir; 39 using namespace mlir::LLVM; 40 using mlir::LLVM::cconv::getMaxEnumValForCConv; 41 using mlir::LLVM::linkage::getMaxEnumValForLinkage; 42 43 #include "mlir/Dialect/LLVMIR/LLVMOpsDialect.cpp.inc" 44 45 static constexpr const char kVolatileAttrName[] = "volatile_"; 46 static constexpr const char kNonTemporalAttrName[] = "nontemporal"; 47 static constexpr const char kElemTypeAttrName[] = "elem_type"; 48 49 #include "mlir/Dialect/LLVMIR/LLVMOpsEnums.cpp.inc" 50 #include "mlir/Dialect/LLVMIR/LLVMOpsInterfaces.cpp.inc" 51 
#define GET_ATTRDEF_CLASSES 52 #include "mlir/Dialect/LLVMIR/LLVMOpsAttrDefs.cpp.inc" 53 54 static auto processFMFAttr(ArrayRef<NamedAttribute> attrs) { 55 SmallVector<NamedAttribute, 8> filteredAttrs( 56 llvm::make_filter_range(attrs, [&](NamedAttribute attr) { 57 if (attr.getName() == "fastmathFlags") { 58 auto defAttr = FMFAttr::get(attr.getValue().getContext(), {}); 59 return defAttr != attr.getValue(); 60 } 61 return true; 62 })); 63 return filteredAttrs; 64 } 65 66 static ParseResult parseLLVMOpAttrs(OpAsmParser &parser, 67 NamedAttrList &result) { 68 return parser.parseOptionalAttrDict(result); 69 } 70 71 static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op, 72 DictionaryAttr attrs) { 73 printer.printOptionalAttrDict(processFMFAttr(attrs.getValue())); 74 } 75 76 /// Verifies `symbol`'s use in `op` to ensure the symbol is a valid and 77 /// fully defined llvm.func. 78 static LogicalResult verifySymbolAttrUse(FlatSymbolRefAttr symbol, 79 Operation *op, 80 SymbolTableCollection &symbolTable) { 81 StringRef name = symbol.getValue(); 82 auto func = 83 symbolTable.lookupNearestSymbolFrom<LLVMFuncOp>(op, symbol.getAttr()); 84 if (!func) 85 return op->emitOpError("'") 86 << name << "' does not reference a valid LLVM function"; 87 if (func.isExternal()) 88 return op->emitOpError("'") << name << "' does not have a definition"; 89 return success(); 90 } 91 92 //===----------------------------------------------------------------------===// 93 // Printing, parsing and builder for LLVM::CmpOp. 
94 //===----------------------------------------------------------------------===// 95 96 void ICmpOp::build(OpBuilder &builder, OperationState &result, 97 ICmpPredicate predicate, Value lhs, Value rhs) { 98 auto boolType = IntegerType::get(lhs.getType().getContext(), 1); 99 if (LLVM::isCompatibleVectorType(lhs.getType()) || 100 LLVM::isCompatibleVectorType(rhs.getType())) { 101 int64_t numLHSElements = 1, numRHSElements = 1; 102 if (LLVM::isCompatibleVectorType(lhs.getType())) 103 numLHSElements = 104 LLVM::getVectorNumElements(lhs.getType()).getFixedValue(); 105 if (LLVM::isCompatibleVectorType(rhs.getType())) 106 numRHSElements = 107 LLVM::getVectorNumElements(rhs.getType()).getFixedValue(); 108 build(builder, result, 109 VectorType::get({std::max(numLHSElements, numRHSElements)}, boolType), 110 predicate, lhs, rhs); 111 } else { 112 build(builder, result, boolType, predicate, lhs, rhs); 113 } 114 } 115 116 void ICmpOp::print(OpAsmPrinter &p) { 117 p << " \"" << stringifyICmpPredicate(getPredicate()) << "\" " << getOperand(0) 118 << ", " << getOperand(1); 119 p.printOptionalAttrDict((*this)->getAttrs(), {"predicate"}); 120 p << " : " << getLhs().getType(); 121 } 122 123 void FCmpOp::print(OpAsmPrinter &p) { 124 p << " \"" << stringifyFCmpPredicate(getPredicate()) << "\" " << getOperand(0) 125 << ", " << getOperand(1); 126 p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()), {"predicate"}); 127 p << " : " << getLhs().getType(); 128 } 129 130 // <operation> ::= `llvm.icmp` string-literal ssa-use `,` ssa-use 131 // attribute-dict? `:` type 132 // <operation> ::= `llvm.fcmp` string-literal ssa-use `,` ssa-use 133 // attribute-dict? 
`:` type 134 template <typename CmpPredicateType> 135 static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) { 136 Builder &builder = parser.getBuilder(); 137 138 StringAttr predicateAttr; 139 OpAsmParser::UnresolvedOperand lhs, rhs; 140 Type type; 141 SMLoc predicateLoc, trailingTypeLoc; 142 if (parser.getCurrentLocation(&predicateLoc) || 143 parser.parseAttribute(predicateAttr, "predicate", result.attributes) || 144 parser.parseOperand(lhs) || parser.parseComma() || 145 parser.parseOperand(rhs) || 146 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 147 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) || 148 parser.resolveOperand(lhs, type, result.operands) || 149 parser.resolveOperand(rhs, type, result.operands)) 150 return failure(); 151 152 // Replace the string attribute `predicate` with an integer attribute. 153 int64_t predicateValue = 0; 154 if (std::is_same<CmpPredicateType, ICmpPredicate>()) { 155 Optional<ICmpPredicate> predicate = 156 symbolizeICmpPredicate(predicateAttr.getValue()); 157 if (!predicate) 158 return parser.emitError(predicateLoc) 159 << "'" << predicateAttr.getValue() 160 << "' is an incorrect value of the 'predicate' attribute"; 161 predicateValue = static_cast<int64_t>(*predicate); 162 } else { 163 Optional<FCmpPredicate> predicate = 164 symbolizeFCmpPredicate(predicateAttr.getValue()); 165 if (!predicate) 166 return parser.emitError(predicateLoc) 167 << "'" << predicateAttr.getValue() 168 << "' is an incorrect value of the 'predicate' attribute"; 169 predicateValue = static_cast<int64_t>(*predicate); 170 } 171 172 result.attributes.set("predicate", 173 parser.getBuilder().getI64IntegerAttr(predicateValue)); 174 175 // The result type is either i1 or a vector type <? x i1> if the inputs are 176 // vectors. 
177 Type resultType = IntegerType::get(builder.getContext(), 1); 178 if (!isCompatibleType(type)) 179 return parser.emitError(trailingTypeLoc, 180 "expected LLVM dialect-compatible type"); 181 if (LLVM::isCompatibleVectorType(type)) { 182 if (LLVM::isScalableVectorType(type)) { 183 resultType = LLVM::getVectorType( 184 resultType, LLVM::getVectorNumElements(type).getKnownMinValue(), 185 /*isScalable=*/true); 186 } else { 187 resultType = LLVM::getVectorType( 188 resultType, LLVM::getVectorNumElements(type).getFixedValue(), 189 /*isScalable=*/false); 190 } 191 } 192 193 result.addTypes({resultType}); 194 return success(); 195 } 196 197 ParseResult ICmpOp::parse(OpAsmParser &parser, OperationState &result) { 198 return parseCmpOp<ICmpPredicate>(parser, result); 199 } 200 201 ParseResult FCmpOp::parse(OpAsmParser &parser, OperationState &result) { 202 return parseCmpOp<FCmpPredicate>(parser, result); 203 } 204 205 //===----------------------------------------------------------------------===// 206 // Printing, parsing and verification for LLVM::AllocaOp. 207 //===----------------------------------------------------------------------===// 208 209 void AllocaOp::print(OpAsmPrinter &p) { 210 Type elemTy = getType().cast<LLVM::LLVMPointerType>().getElementType(); 211 if (!elemTy) 212 elemTy = *getElemType(); 213 214 auto funcTy = 215 FunctionType::get(getContext(), {getArraySize().getType()}, {getType()}); 216 217 p << ' ' << getArraySize() << " x " << elemTy; 218 if (getAlignment() && *getAlignment() != 0) 219 p.printOptionalAttrDict((*this)->getAttrs(), {kElemTypeAttrName}); 220 else 221 p.printOptionalAttrDict((*this)->getAttrs(), 222 {"alignment", kElemTypeAttrName}); 223 p << " : " << funcTy; 224 } 225 226 // <operation> ::= `llvm.alloca` ssa-use `x` type attribute-dict? 
227 // `:` type `,` type 228 ParseResult AllocaOp::parse(OpAsmParser &parser, OperationState &result) { 229 OpAsmParser::UnresolvedOperand arraySize; 230 Type type, elemType; 231 SMLoc trailingTypeLoc; 232 if (parser.parseOperand(arraySize) || parser.parseKeyword("x") || 233 parser.parseType(elemType) || 234 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 235 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type)) 236 return failure(); 237 238 Optional<NamedAttribute> alignmentAttr = 239 result.attributes.getNamed("alignment"); 240 if (alignmentAttr.hasValue()) { 241 auto alignmentInt = 242 alignmentAttr.getValue().getValue().dyn_cast<IntegerAttr>(); 243 if (!alignmentInt) 244 return parser.emitError(parser.getNameLoc(), 245 "expected integer alignment"); 246 if (alignmentInt.getValue().isNullValue()) 247 result.attributes.erase("alignment"); 248 } 249 250 // Extract the result type from the trailing function type. 251 auto funcType = type.dyn_cast<FunctionType>(); 252 if (!funcType || funcType.getNumInputs() != 1 || 253 funcType.getNumResults() != 1) 254 return parser.emitError( 255 trailingTypeLoc, 256 "expected trailing function type with one argument and one result"); 257 258 if (parser.resolveOperand(arraySize, funcType.getInput(0), result.operands)) 259 return failure(); 260 261 Type resultType = funcType.getResult(0); 262 if (auto ptrResultType = resultType.dyn_cast<LLVMPointerType>()) { 263 if (ptrResultType.isOpaque()) 264 result.addAttribute(kElemTypeAttrName, TypeAttr::get(elemType)); 265 } 266 267 result.addTypes({funcType.getResult(0)}); 268 return success(); 269 } 270 271 /// Checks that the elemental type is present in either the pointer type or 272 /// the attribute, but not both. 
273 static LogicalResult verifyOpaquePtr(Operation *op, LLVMPointerType ptrType, 274 Optional<Type> ptrElementType) { 275 if (ptrType.isOpaque() && !ptrElementType.hasValue()) { 276 return op->emitOpError() << "expected '" << kElemTypeAttrName 277 << "' attribute if opaque pointer type is used"; 278 } 279 if (!ptrType.isOpaque() && ptrElementType.hasValue()) { 280 return op->emitOpError() 281 << "unexpected '" << kElemTypeAttrName 282 << "' attribute when non-opaque pointer type is used"; 283 } 284 return success(); 285 } 286 287 LogicalResult AllocaOp::verify() { 288 return verifyOpaquePtr(getOperation(), getType().cast<LLVMPointerType>(), 289 getElemType()); 290 } 291 292 //===----------------------------------------------------------------------===// 293 // LLVM::BrOp 294 //===----------------------------------------------------------------------===// 295 296 SuccessorOperands BrOp::getSuccessorOperands(unsigned index) { 297 assert(index == 0 && "invalid successor index"); 298 return SuccessorOperands(getDestOperandsMutable()); 299 } 300 301 //===----------------------------------------------------------------------===// 302 // LLVM::CondBrOp 303 //===----------------------------------------------------------------------===// 304 305 SuccessorOperands CondBrOp::getSuccessorOperands(unsigned index) { 306 assert(index < getNumSuccessors() && "invalid successor index"); 307 return SuccessorOperands(index == 0 ? 
getTrueDestOperandsMutable() 308 : getFalseDestOperandsMutable()); 309 } 310 311 //===----------------------------------------------------------------------===// 312 // LLVM::SwitchOp 313 //===----------------------------------------------------------------------===// 314 315 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value, 316 Block *defaultDestination, ValueRange defaultOperands, 317 ArrayRef<int32_t> caseValues, BlockRange caseDestinations, 318 ArrayRef<ValueRange> caseOperands, 319 ArrayRef<int32_t> branchWeights) { 320 ElementsAttr caseValuesAttr; 321 if (!caseValues.empty()) 322 caseValuesAttr = builder.getI32VectorAttr(caseValues); 323 324 ElementsAttr weightsAttr; 325 if (!branchWeights.empty()) 326 weightsAttr = builder.getI32VectorAttr(llvm::to_vector<4>(branchWeights)); 327 328 build(builder, result, value, defaultOperands, caseOperands, caseValuesAttr, 329 weightsAttr, defaultDestination, caseDestinations); 330 } 331 332 /// <cases> ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)? 333 /// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )? 
/// Parses the (possibly empty) case list of a switch. On success `caseValues`
/// holds a vector attribute with element type `flagType`, and the three output
/// vectors have one entry per parsed case.
static ParseResult parseSwitchOpCases(
    OpAsmParser &parser, Type flagType, ElementsAttr &caseValues,
    SmallVectorImpl<Block *> &caseDestinations,
    SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> &caseOperands,
    SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
  SmallVector<APInt> values;
  unsigned bitWidth = flagType.getIntOrFloatBitWidth();
  do {
    int64_t value = 0;
    OptionalParseResult integerParseResult = parser.parseOptionalInteger(value);
    // No integer at the very first position means there are no cases at all;
    // `caseValues` is left untouched in that case.
    if (values.empty() && !integerParseResult.hasValue())
      return success();

    // After a comma an integer is mandatory, and it must have parsed cleanly.
    if (!integerParseResult.hasValue() || integerParseResult.getValue())
      return failure();
    values.push_back(APInt(bitWidth, value));

    Block *destination;
    SmallVector<OpAsmParser::UnresolvedOperand> operands;
    SmallVector<Type> operandTypes;
    if (parser.parseColon() || parser.parseSuccessor(destination))
      return failure();
    // The parenthesized operand list of a case is optional.
    if (!parser.parseOptionalLParen()) {
      if (parser.parseOperandList(operands, OpAsmParser::Delimiter::None,
                                  /*allowResultNumber=*/false) ||
          parser.parseColonTypeList(operandTypes) || parser.parseRParen())
        return failure();
    }
    caseDestinations.push_back(destination);
    caseOperands.emplace_back(operands);
    caseOperandTypes.emplace_back(operandTypes);
  } while (!parser.parseOptionalComma());

  ShapedType caseValueType =
      VectorType::get(static_cast<int64_t>(values.size()), flagType);
  caseValues = DenseIntElementsAttr::get(caseValueType, values);
  return success();
}

/// Prints the case list of a switch, one `value: ^dest(operands)` entry per
/// line, separated by commas. Prints nothing when `caseValues` is null.
static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type flagType,
                               ElementsAttr caseValues,
                               SuccessorRange caseDestinations,
                               OperandRangeRange caseOperands,
                               const TypeRangeRange &caseOperandTypes) {
  if (!caseValues)
    return;

  // `index` tracks the current case so the matching operand range is printed
  // alongside each (value, destination) pair.
  size_t index = 0;
  llvm::interleave(
      llvm::zip(caseValues.cast<DenseIntElementsAttr>(), caseDestinations),
      [&](auto i) {
        p << " ";
        // NOTE(review): getLimitedValue() yields an unsigned value, so case
        // values with the top bit set print as large positive numbers —
        // confirm this is the intended textual form.
        p << std::get<0>(i).getLimitedValue();
        p << ": ";
        p.printSuccessorAndUseList(std::get<1>(i), caseOperands[index++]);
      },
      [&] {
        p << ',';
        p.printNewline();
      });
  p.printNewline();
}

/// Verifies that the number of case values matches the number of case
/// destinations and, when branch weights are present, that there is one
/// weight per successor.
LogicalResult SwitchOp::verify() {
  if ((!getCaseValues() && !getCaseDestinations().empty()) ||
      (getCaseValues() &&
       getCaseValues()->size() !=
           static_cast<int64_t>(getCaseDestinations().size())))
    return emitOpError("expects number of case values to match number of "
                       "case destinations");
  if (getBranchWeights() && getBranchWeights()->size() != getNumSuccessors())
    return emitError("expects number of branch weights to match number of "
                     "successors: ")
           << getBranchWeights()->size() << " vs " << getNumSuccessors();
  return success();
}

/// Returns the operands forwarded to successor `index`: index 0 is the
/// default destination, index `i > 0` is case destination `i - 1`.
SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) {
  assert(index < getNumSuccessors() && "invalid successor index");
  return SuccessorOperands(index == 0 ? getDefaultOperandsMutable()
                                      : getCaseOperandsMutable(index - 1));
}

//===----------------------------------------------------------------------===//
// Code for LLVM::GEPOp.
//===----------------------------------------------------------------------===//

// Out-of-line definition of the static constexpr member; the code below binds
// it to const references (e.g. in llvm::count), which requires a definition.
constexpr int GEPOp::kDynamicIndex;

namespace {
/// Base class for llvm::Error related to GEP index.
class GEPIndexError : public llvm::ErrorInfo<GEPIndexError> {
protected:
  /// Position of the offending index in the GEP index list.
  unsigned indexPos;

public:
  static char ID;

  std::error_code convertToErrorCode() const override {
    return llvm::inconvertibleErrorCode();
  }

  explicit GEPIndexError(unsigned pos) : indexPos(pos) {}
};

/// llvm::Error for out-of-bound GEP index.
struct GEPIndexOutOfBoundError
    : public llvm::ErrorInfo<GEPIndexOutOfBoundError, GEPIndexError> {
  static char ID;

  // Inherit the `unsigned pos` constructor of the base error.
  using ErrorInfo::ErrorInfo;

  void log(llvm::raw_ostream &os) const override {
    os << "index " << indexPos << " indexing a struct is out of bounds";
  }
};

/// llvm::Error for non-static GEP index indexing a struct.
struct GEPStaticIndexError
    : public llvm::ErrorInfo<GEPStaticIndexError, GEPIndexError> {
  static char ID;

  // Inherit the `unsigned pos` constructor of the base error.
  using ErrorInfo::ErrorInfo;

  void log(llvm::raw_ostream &os) const override {
    os << "expected index " << indexPos << " indexing a struct "
       << "to be constant";
  }
};
} // end anonymous namespace

char GEPIndexError::ID = 0;
char GEPIndexOutOfBoundError::ID = 0;
char GEPStaticIndexError::ID = 0;

/// For the given `structIndices` and `indices`, check whether they comply
/// with `baseGEPType`, especially against LLVMStructTypes nested within,
/// and refine/promote struct index from `indices` to `updatedStructIndices`
/// if the latter argument is not null.
///
/// `indexPos` is the position in `structIndices` examined by this call; the
/// function recurses with `indexPos + 1` into the indexed element type. When
/// `remainingIndices` is not null, the index values used to index non-struct
/// containers are appended to it. Returns a GEPStaticIndexError when a struct
/// is indexed by a value that is not (or cannot be refined to) a constant,
/// and a GEPIndexOutOfBoundError when a constant struct index does not name a
/// struct element.
static llvm::Error
recordStructIndices(Type baseGEPType, unsigned indexPos,
                    ArrayRef<int32_t> structIndices, ValueRange indices,
                    SmallVectorImpl<int32_t> *updatedStructIndices,
                    SmallVectorImpl<Value> *remainingIndices) {
  if (indexPos >= structIndices.size())
    // Stop searching
    return llvm::Error::success();

  int32_t gepIndex = structIndices[indexPos];
  bool isStaticIndex = gepIndex != GEPOp::kDynamicIndex;

  // For a dynamic index, its position within `indices` is the number of
  // dynamic markers up to and including `indexPos`, minus one.
  unsigned dynamicIndexPos = indexPos;
  if (!isStaticIndex)
    dynamicIndexPos = llvm::count(structIndices.take_front(indexPos + 1),
                                  LLVM::GEPOp::kDynamicIndex) - 1;

  return llvm::TypeSwitch<Type, llvm::Error>(baseGEPType)
      .Case<LLVMStructType>([&](LLVMStructType structType) -> llvm::Error {
        // We don't always want to refine the index (e.g. when performing
        // verification), so we only refine when updatedStructIndices is not
        // null.
        if (!isStaticIndex && updatedStructIndices) {
          // Try to refine.
          APInt staticIndexValue;
          isStaticIndex = matchPattern(indices[dynamicIndexPos],
                                       m_ConstantInt(&staticIndexValue));
          if (isStaticIndex) {
            assert(staticIndexValue.getBitWidth() <= 64 &&
                   llvm::isInt<32>(staticIndexValue.getLimitedValue()) &&
                   "struct index can't fit within int32_t");
            gepIndex = static_cast<int32_t>(staticIndexValue.getSExtValue());
          }
        }
        // Struct elements can only be selected by a constant.
        if (!isStaticIndex)
          return llvm::make_error<GEPStaticIndexError>(indexPos);

        ArrayRef<Type> elementTypes = structType.getBody();
        if (gepIndex < 0 ||
            static_cast<size_t>(gepIndex) >= elementTypes.size())
          return llvm::make_error<GEPIndexOutOfBoundError>(indexPos);

        if (updatedStructIndices)
          (*updatedStructIndices)[indexPos] = gepIndex;

        // Instead of recursively going into every child type, we only
        // dive into the one indexed by gepIndex.
        return recordStructIndices(elementTypes[gepIndex], indexPos + 1,
                                   structIndices, indices, updatedStructIndices,
                                   remainingIndices);
      })
      .Case<VectorType, LLVMScalableVectorType, LLVMFixedVectorType,
            LLVMArrayType>([&](auto containerType) -> llvm::Error {
        // Currently we don't refine non-struct index even if it's static.
        // NOTE(review): for a static entry `dynamicIndexPos` equals
        // `indexPos`, so this reads `indices[indexPos]` — confirm callers
        // pass `indices` in a form for which that lookup is meaningful.
        if (remainingIndices)
          remainingIndices->push_back(indices[dynamicIndexPos]);
        return recordStructIndices(containerType.getElementType(), indexPos + 1,
                                   structIndices, indices, updatedStructIndices,
                                   remainingIndices);
      })
      .Default(
          // Any other type ends the walk without error.
          [](auto otherType) -> llvm::Error { return llvm::Error::success(); });
}

/// Driver function around `recordStructIndices`. Note that we always check
/// from the second GEP index since the first one is always dynamic.
539 static llvm::Error 540 findStructIndices(Type baseGEPType, ArrayRef<int32_t> structIndices, 541 ValueRange indices, 542 SmallVectorImpl<int32_t> *updatedStructIndices = nullptr, 543 SmallVectorImpl<Value> *remainingIndices = nullptr) { 544 if (remainingIndices) 545 // The first GEP index is always dynamic. 546 remainingIndices->push_back(indices[0]); 547 return recordStructIndices(baseGEPType, /*indexPos=*/1, structIndices, 548 indices, updatedStructIndices, remainingIndices); 549 } 550 551 void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType, 552 Value basePtr, ValueRange operands, 553 ArrayRef<NamedAttribute> attributes) { 554 build(builder, result, resultType, basePtr, operands, 555 SmallVector<int32_t>(operands.size(), kDynamicIndex), attributes); 556 } 557 558 /// Returns the elemental type of any LLVM-compatible vector type or self. 559 static Type extractVectorElementType(Type type) { 560 if (auto vectorType = type.dyn_cast<VectorType>()) 561 return vectorType.getElementType(); 562 if (auto scalableVectorType = type.dyn_cast<LLVMScalableVectorType>()) 563 return scalableVectorType.getElementType(); 564 if (auto fixedVectorType = type.dyn_cast<LLVMFixedVectorType>()) 565 return fixedVectorType.getElementType(); 566 return type; 567 } 568 569 void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType, 570 Value basePtr, ValueRange indices, 571 ArrayRef<int32_t> structIndices, 572 ArrayRef<NamedAttribute> attributes) { 573 auto ptrType = 574 extractVectorElementType(basePtr.getType()).cast<LLVMPointerType>(); 575 assert(!ptrType.isOpaque() && 576 "expected non-opaque pointer, provide elementType explicitly when " 577 "opaque pointers are used"); 578 build(builder, result, resultType, ptrType.getElementType(), basePtr, indices, 579 structIndices, attributes); 580 } 581 582 void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType, 583 Type elementType, Value basePtr, ValueRange indices, 584 
ArrayRef<int32_t> structIndices, 585 ArrayRef<NamedAttribute> attributes) { 586 SmallVector<Value> remainingIndices; 587 SmallVector<int32_t> updatedStructIndices(structIndices.begin(), 588 structIndices.end()); 589 if (llvm::Error err = 590 findStructIndices(elementType, structIndices, indices, 591 &updatedStructIndices, &remainingIndices)) 592 llvm::report_fatal_error(StringRef(llvm::toString(std::move(err)))); 593 594 assert(remainingIndices.size() == static_cast<size_t>(llvm::count( 595 updatedStructIndices, kDynamicIndex)) && 596 "expected as many index operands as dynamic index attr elements"); 597 598 result.addTypes(resultType); 599 result.addAttributes(attributes); 600 result.addAttribute("structIndices", 601 builder.getI32TensorAttr(updatedStructIndices)); 602 if (extractVectorElementType(basePtr.getType()) 603 .cast<LLVMPointerType>() 604 .isOpaque()) 605 result.addAttribute(kElemTypeAttrName, TypeAttr::get(elementType)); 606 result.addOperands(basePtr); 607 result.addOperands(remainingIndices); 608 } 609 610 static ParseResult 611 parseGEPIndices(OpAsmParser &parser, 612 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &indices, 613 DenseIntElementsAttr &structIndices) { 614 SmallVector<int32_t> constantIndices; 615 616 auto idxParser = [&]() -> ParseResult { 617 int32_t constantIndex; 618 OptionalParseResult parsedInteger = 619 parser.parseOptionalInteger(constantIndex); 620 if (parsedInteger.hasValue()) { 621 if (failed(parsedInteger.getValue())) 622 return failure(); 623 constantIndices.push_back(constantIndex); 624 return success(); 625 } 626 627 constantIndices.push_back(LLVM::GEPOp::kDynamicIndex); 628 return parser.parseOperand(indices.emplace_back()); 629 }; 630 if (parser.parseCommaSeparatedList(idxParser)) 631 return failure(); 632 633 structIndices = parser.getBuilder().getI32TensorAttr(constantIndices); 634 return success(); 635 } 636 637 static void printGEPIndices(OpAsmPrinter &printer, LLVM::GEPOp gepOp, 638 OperandRange indices, 639 
DenseIntElementsAttr structIndices) { 640 unsigned operandIdx = 0; 641 llvm::interleaveComma(structIndices.getValues<int32_t>(), printer, 642 [&](int32_t cst) { 643 if (cst == LLVM::GEPOp::kDynamicIndex) 644 printer.printOperand(indices[operandIdx++]); 645 else 646 printer << cst; 647 }); 648 } 649 650 LogicalResult LLVM::GEPOp::verify() { 651 if (failed(verifyOpaquePtr( 652 getOperation(), 653 extractVectorElementType(getType()).cast<LLVMPointerType>(), 654 getElemType()))) 655 return failure(); 656 657 auto structIndexRange = getStructIndices().getValues<int32_t>(); 658 // structIndexRange is a kind of iterator, which cannot be converted 659 // to ArrayRef directly. 660 SmallVector<int32_t> structIndices(structIndexRange.size()); 661 for (unsigned i : llvm::seq<unsigned>(0, structIndexRange.size())) 662 structIndices[i] = structIndexRange[i]; 663 if (llvm::Error err = findStructIndices(getSourceElementType(), structIndices, 664 getIndices())) 665 return emitOpError() << llvm::toString(std::move(err)); 666 667 return success(); 668 } 669 670 Type LLVM::GEPOp::getSourceElementType() { 671 if (Optional<Type> elemType = getElemType()) 672 return *elemType; 673 674 return extractVectorElementType(getBase().getType()) 675 .cast<LLVMPointerType>() 676 .getElementType(); 677 } 678 679 //===----------------------------------------------------------------------===// 680 // Builder, printer and parser for for LLVM::LoadOp. 681 //===----------------------------------------------------------------------===// 682 683 LogicalResult verifySymbolAttribute( 684 Operation *op, StringRef attributeName, 685 llvm::function_ref<LogicalResult(Operation *, SymbolRefAttr)> 686 verifySymbolType) { 687 if (Attribute attribute = op->getAttr(attributeName)) { 688 // The attribute is already verified to be a symbol ref array attribute via 689 // a constraint in the operation definition. 
690 for (SymbolRefAttr symbolRef : 691 attribute.cast<ArrayAttr>().getAsRange<SymbolRefAttr>()) { 692 StringAttr metadataName = symbolRef.getRootReference(); 693 StringAttr symbolName = symbolRef.getLeafReference(); 694 // We want @metadata::@symbol, not just @symbol 695 if (metadataName == symbolName) { 696 return op->emitOpError() << "expected '" << symbolRef 697 << "' to specify a fully qualified reference"; 698 } 699 auto metadataOp = SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>( 700 op->getParentOp(), metadataName); 701 if (!metadataOp) 702 return op->emitOpError() 703 << "expected '" << symbolRef << "' to reference a metadata op"; 704 Operation *symbolOp = 705 SymbolTable::lookupNearestSymbolFrom(metadataOp, symbolName); 706 if (!symbolOp) 707 return op->emitOpError() 708 << "expected '" << symbolRef << "' to be a valid reference"; 709 if (failed(verifySymbolType(symbolOp, symbolRef))) { 710 return failure(); 711 } 712 } 713 } 714 return success(); 715 } 716 717 // Verifies that metadata ops are wired up properly. 
718 template <typename OpTy> 719 static LogicalResult verifyOpMetadata(Operation *op, StringRef attributeName) { 720 auto verifySymbolType = [op](Operation *symbolOp, 721 SymbolRefAttr symbolRef) -> LogicalResult { 722 if (!isa<OpTy>(symbolOp)) { 723 return op->emitOpError() 724 << "expected '" << symbolRef << "' to resolve to a " 725 << OpTy::getOperationName(); 726 } 727 return success(); 728 }; 729 730 return verifySymbolAttribute(op, attributeName, verifySymbolType); 731 } 732 733 static LogicalResult verifyMemoryOpMetadata(Operation *op) { 734 // access_groups 735 if (failed(verifyOpMetadata<LLVM::AccessGroupMetadataOp>( 736 op, LLVMDialect::getAccessGroupsAttrName()))) 737 return failure(); 738 739 // alias_scopes 740 if (failed(verifyOpMetadata<LLVM::AliasScopeMetadataOp>( 741 op, LLVMDialect::getAliasScopesAttrName()))) 742 return failure(); 743 744 // noalias_scopes 745 if (failed(verifyOpMetadata<LLVM::AliasScopeMetadataOp>( 746 op, LLVMDialect::getNoAliasScopesAttrName()))) 747 return failure(); 748 749 return success(); 750 } 751 752 LogicalResult LoadOp::verify() { return verifyMemoryOpMetadata(*this); } 753 754 void LoadOp::build(OpBuilder &builder, OperationState &result, Type t, 755 Value addr, unsigned alignment, bool isVolatile, 756 bool isNonTemporal) { 757 result.addOperands(addr); 758 result.addTypes(t); 759 if (isVolatile) 760 result.addAttribute(kVolatileAttrName, builder.getUnitAttr()); 761 if (isNonTemporal) 762 result.addAttribute(kNonTemporalAttrName, builder.getUnitAttr()); 763 if (alignment != 0) 764 result.addAttribute("alignment", builder.getI64IntegerAttr(alignment)); 765 } 766 767 void LoadOp::print(OpAsmPrinter &p) { 768 p << ' '; 769 if (getVolatile_()) 770 p << "volatile "; 771 p << getAddr(); 772 p.printOptionalAttrDict((*this)->getAttrs(), 773 {kVolatileAttrName, kElemTypeAttrName}); 774 p << " : " << getAddr().getType(); 775 if (getAddr().getType().cast<LLVMPointerType>().isOpaque()) 776 p << " -> " << getType(); 777 } 778 779 
// Extract the pointee type from the LLVM pointer type wrapped in MLIR. Return 780 // the resulting type if any, null type if opaque pointers are used, and None 781 // if the given type is not the pointer type. 782 static Optional<Type> getLoadStoreElementType(OpAsmParser &parser, Type type, 783 SMLoc trailingTypeLoc) { 784 auto llvmTy = type.dyn_cast<LLVM::LLVMPointerType>(); 785 if (!llvmTy) { 786 parser.emitError(trailingTypeLoc, "expected LLVM pointer type"); 787 return llvm::None; 788 } 789 return llvmTy.getElementType(); 790 } 791 792 // <operation> ::= `llvm.load` `volatile` ssa-use attribute-dict? `:` type 793 // (`->` type)? 794 ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) { 795 OpAsmParser::UnresolvedOperand addr; 796 Type type; 797 SMLoc trailingTypeLoc; 798 799 if (succeeded(parser.parseOptionalKeyword("volatile"))) 800 result.addAttribute(kVolatileAttrName, parser.getBuilder().getUnitAttr()); 801 802 if (parser.parseOperand(addr) || 803 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 804 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) || 805 parser.resolveOperand(addr, type, result.operands)) 806 return failure(); 807 808 Optional<Type> elemTy = 809 getLoadStoreElementType(parser, type, trailingTypeLoc); 810 if (!elemTy) 811 return failure(); 812 if (*elemTy) { 813 result.addTypes(*elemTy); 814 return success(); 815 } 816 817 Type trailingType; 818 if (parser.parseArrow() || parser.parseType(trailingType)) 819 return failure(); 820 result.addTypes(trailingType); 821 return success(); 822 } 823 824 //===----------------------------------------------------------------------===// 825 // Builder, printer and parser for LLVM::StoreOp. 
826 //===----------------------------------------------------------------------===// 827 828 LogicalResult StoreOp::verify() { return verifyMemoryOpMetadata(*this); } 829 830 void StoreOp::build(OpBuilder &builder, OperationState &result, Value value, 831 Value addr, unsigned alignment, bool isVolatile, 832 bool isNonTemporal) { 833 result.addOperands({value, addr}); 834 result.addTypes({}); 835 if (isVolatile) 836 result.addAttribute(kVolatileAttrName, builder.getUnitAttr()); 837 if (isNonTemporal) 838 result.addAttribute(kNonTemporalAttrName, builder.getUnitAttr()); 839 if (alignment != 0) 840 result.addAttribute("alignment", builder.getI64IntegerAttr(alignment)); 841 } 842 843 void StoreOp::print(OpAsmPrinter &p) { 844 p << ' '; 845 if (getVolatile_()) 846 p << "volatile "; 847 p << getValue() << ", " << getAddr(); 848 p.printOptionalAttrDict((*this)->getAttrs(), {kVolatileAttrName}); 849 p << " : "; 850 if (getAddr().getType().cast<LLVMPointerType>().isOpaque()) 851 p << getValue().getType() << ", "; 852 p << getAddr().getType(); 853 } 854 855 // <operation> ::= `llvm.store` `volatile` ssa-use `,` ssa-use 856 // attribute-dict? `:` type (`,` type)? 
857 ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) { 858 OpAsmParser::UnresolvedOperand addr, value; 859 Type type; 860 SMLoc trailingTypeLoc; 861 862 if (succeeded(parser.parseOptionalKeyword("volatile"))) 863 result.addAttribute(kVolatileAttrName, parser.getBuilder().getUnitAttr()); 864 865 if (parser.parseOperand(value) || parser.parseComma() || 866 parser.parseOperand(addr) || 867 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 868 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type)) 869 return failure(); 870 871 Type operandType; 872 if (succeeded(parser.parseOptionalComma())) { 873 operandType = type; 874 if (parser.parseType(type)) 875 return failure(); 876 } else { 877 Optional<Type> maybeOperandType = 878 getLoadStoreElementType(parser, type, trailingTypeLoc); 879 if (!maybeOperandType) 880 return failure(); 881 operandType = *maybeOperandType; 882 } 883 884 if (parser.resolveOperand(value, operandType, result.operands) || 885 parser.resolveOperand(addr, type, result.operands)) 886 return failure(); 887 888 return success(); 889 } 890 891 ///===---------------------------------------------------------------------===// 892 /// LLVM::InvokeOp 893 ///===---------------------------------------------------------------------===// 894 895 SuccessorOperands InvokeOp::getSuccessorOperands(unsigned index) { 896 assert(index < getNumSuccessors() && "invalid successor index"); 897 return SuccessorOperands(index == 0 ? getNormalDestOperandsMutable() 898 : getUnwindDestOperandsMutable()); 899 } 900 901 CallInterfaceCallable InvokeOp::getCallableForCallee() { 902 // Direct call. 903 if (FlatSymbolRefAttr calleeAttr = getCalleeAttr()) 904 return calleeAttr; 905 // Indirect call, callee Value is the first operand. 906 return getOperand(0); 907 } 908 909 Operation::operand_range InvokeOp::getArgOperands() { 910 return getOperands().drop_front(getCallee().hasValue() ? 
0 : 1); 911 } 912 913 LogicalResult InvokeOp::verify() { 914 if (getNumResults() > 1) 915 return emitOpError("must have 0 or 1 result"); 916 917 Block *unwindDest = getUnwindDest(); 918 if (unwindDest->empty()) 919 return emitError("must have at least one operation in unwind destination"); 920 921 // In unwind destination, first operation must be LandingpadOp 922 if (!isa<LandingpadOp>(unwindDest->front())) 923 return emitError("first operation in unwind destination should be a " 924 "llvm.landingpad operation"); 925 926 return success(); 927 } 928 929 void InvokeOp::print(OpAsmPrinter &p) { 930 auto callee = getCallee(); 931 bool isDirect = callee.hasValue(); 932 933 p << ' '; 934 935 // Either function name or pointer 936 if (isDirect) 937 p.printSymbolName(callee.getValue()); 938 else 939 p << getOperand(0); 940 941 p << '(' << getOperands().drop_front(isDirect ? 0 : 1) << ')'; 942 p << " to "; 943 p.printSuccessorAndUseList(getNormalDest(), getNormalDestOperands()); 944 p << " unwind "; 945 p.printSuccessorAndUseList(getUnwindDest(), getUnwindDestOperands()); 946 947 p.printOptionalAttrDict((*this)->getAttrs(), 948 {InvokeOp::getOperandSegmentSizeAttr(), "callee"}); 949 p << " : "; 950 p.printFunctionalType(llvm::drop_begin(getOperandTypes(), isDirect ? 0 : 1), 951 getResultTypes()); 952 } 953 954 /// <operation> ::= `llvm.invoke` (function-id | ssa-use) `(` ssa-use-list `)` 955 /// `to` bb-id (`[` ssa-use-and-type-list `]`)? 956 /// `unwind` bb-id (`[` ssa-use-and-type-list `]`)? 957 /// attribute-dict? `:` function-type 958 ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) { 959 SmallVector<OpAsmParser::UnresolvedOperand, 8> operands; 960 FunctionType funcType; 961 SymbolRefAttr funcAttr; 962 SMLoc trailingTypeLoc; 963 Block *normalDest, *unwindDest; 964 SmallVector<Value, 4> normalOperands, unwindOperands; 965 Builder &builder = parser.getBuilder(); 966 967 // Parse an operand list that will, in practice, contain 0 or 1 operand. 
In 968 // case of an indirect call, there will be 1 operand before `(`. In case of a 969 // direct call, there will be no operands and the parser will stop at the 970 // function identifier without complaining. 971 if (parser.parseOperandList(operands)) 972 return failure(); 973 bool isDirect = operands.empty(); 974 975 // Optionally parse a function identifier. 976 if (isDirect && parser.parseAttribute(funcAttr, "callee", result.attributes)) 977 return failure(); 978 979 if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || 980 parser.parseKeyword("to") || 981 parser.parseSuccessorAndUseList(normalDest, normalOperands) || 982 parser.parseKeyword("unwind") || 983 parser.parseSuccessorAndUseList(unwindDest, unwindOperands) || 984 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 985 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(funcType)) 986 return failure(); 987 988 if (isDirect) { 989 // Make sure types match. 990 if (parser.resolveOperands(operands, funcType.getInputs(), 991 parser.getNameLoc(), result.operands)) 992 return failure(); 993 result.addTypes(funcType.getResults()); 994 } else { 995 // Construct the LLVM IR Dialect function type that the first operand 996 // should match. 
997 if (funcType.getNumResults() > 1) 998 return parser.emitError(trailingTypeLoc, 999 "expected function with 0 or 1 result"); 1000 1001 Type llvmResultType; 1002 if (funcType.getNumResults() == 0) { 1003 llvmResultType = LLVM::LLVMVoidType::get(builder.getContext()); 1004 } else { 1005 llvmResultType = funcType.getResult(0); 1006 if (!isCompatibleType(llvmResultType)) 1007 return parser.emitError(trailingTypeLoc, 1008 "expected result to have LLVM type"); 1009 } 1010 1011 SmallVector<Type, 8> argTypes; 1012 argTypes.reserve(funcType.getNumInputs()); 1013 for (Type ty : funcType.getInputs()) { 1014 if (isCompatibleType(ty)) 1015 argTypes.push_back(ty); 1016 else 1017 return parser.emitError(trailingTypeLoc, 1018 "expected LLVM types as inputs"); 1019 } 1020 1021 auto llvmFuncType = LLVM::LLVMFunctionType::get(llvmResultType, argTypes); 1022 auto wrappedFuncType = LLVM::LLVMPointerType::get(llvmFuncType); 1023 1024 auto funcArguments = llvm::makeArrayRef(operands).drop_front(); 1025 1026 // Make sure that the first operand (indirect callee) matches the wrapped 1027 // LLVM IR function type, and that the types of the other call operands 1028 // match the types of the function arguments. 
1029 if (parser.resolveOperand(operands[0], wrappedFuncType, result.operands) || 1030 parser.resolveOperands(funcArguments, funcType.getInputs(), 1031 parser.getNameLoc(), result.operands)) 1032 return failure(); 1033 1034 result.addTypes(llvmResultType); 1035 } 1036 result.addSuccessors({normalDest, unwindDest}); 1037 result.addOperands(normalOperands); 1038 result.addOperands(unwindOperands); 1039 1040 result.addAttribute( 1041 InvokeOp::getOperandSegmentSizeAttr(), 1042 builder.getI32VectorAttr({static_cast<int32_t>(operands.size()), 1043 static_cast<int32_t>(normalOperands.size()), 1044 static_cast<int32_t>(unwindOperands.size())})); 1045 return success(); 1046 } 1047 1048 ///===----------------------------------------------------------------------===// 1049 /// Verifying/Printing/Parsing for LLVM::LandingpadOp. 1050 ///===----------------------------------------------------------------------===// 1051 1052 LogicalResult LandingpadOp::verify() { 1053 Value value; 1054 if (LLVMFuncOp func = (*this)->getParentOfType<LLVMFuncOp>()) { 1055 if (!func.getPersonality()) 1056 return emitError( 1057 "llvm.landingpad needs to be in a function with a personality"); 1058 } 1059 1060 if (!getCleanup() && getOperands().empty()) 1061 return emitError("landingpad instruction expects at least one clause or " 1062 "cleanup attribute"); 1063 1064 for (unsigned idx = 0, ie = getNumOperands(); idx < ie; idx++) { 1065 value = getOperand(idx); 1066 bool isFilter = value.getType().isa<LLVMArrayType>(); 1067 if (isFilter) { 1068 // FIXME: Verify filter clauses when arrays are appropriately handled 1069 } else { 1070 // catch - global addresses only. 1071 // Bitcast ops should have global addresses as their args. 
1072 if (auto bcOp = value.getDefiningOp<BitcastOp>()) { 1073 if (auto addrOp = bcOp.getArg().getDefiningOp<AddressOfOp>()) 1074 continue; 1075 return emitError("constant clauses expected").attachNote(bcOp.getLoc()) 1076 << "global addresses expected as operand to " 1077 "bitcast used in clauses for landingpad"; 1078 } 1079 // NullOp and AddressOfOp allowed 1080 if (value.getDefiningOp<NullOp>()) 1081 continue; 1082 if (value.getDefiningOp<AddressOfOp>()) 1083 continue; 1084 return emitError("clause #") 1085 << idx << " is not a known constant - null, addressof, bitcast"; 1086 } 1087 } 1088 return success(); 1089 } 1090 1091 void LandingpadOp::print(OpAsmPrinter &p) { 1092 p << (getCleanup() ? " cleanup " : " "); 1093 1094 // Clauses 1095 for (auto value : getOperands()) { 1096 // Similar to llvm - if clause is an array type then it is filter 1097 // clause else catch clause 1098 bool isArrayTy = value.getType().isa<LLVMArrayType>(); 1099 p << '(' << (isArrayTy ? "filter " : "catch ") << value << " : " 1100 << value.getType() << ") "; 1101 } 1102 1103 p.printOptionalAttrDict((*this)->getAttrs(), {"cleanup"}); 1104 1105 p << ": " << getType(); 1106 } 1107 1108 /// <operation> ::= `llvm.landingpad` `cleanup`? 1109 /// ((`catch` | `filter`) operand-type ssa-use)* attribute-dict? 
1110 ParseResult LandingpadOp::parse(OpAsmParser &parser, OperationState &result) { 1111 // Check for cleanup 1112 if (succeeded(parser.parseOptionalKeyword("cleanup"))) 1113 result.addAttribute("cleanup", parser.getBuilder().getUnitAttr()); 1114 1115 // Parse clauses with types 1116 while (succeeded(parser.parseOptionalLParen()) && 1117 (succeeded(parser.parseOptionalKeyword("filter")) || 1118 succeeded(parser.parseOptionalKeyword("catch")))) { 1119 OpAsmParser::UnresolvedOperand operand; 1120 Type ty; 1121 if (parser.parseOperand(operand) || parser.parseColon() || 1122 parser.parseType(ty) || 1123 parser.resolveOperand(operand, ty, result.operands) || 1124 parser.parseRParen()) 1125 return failure(); 1126 } 1127 1128 Type type; 1129 if (parser.parseColon() || parser.parseType(type)) 1130 return failure(); 1131 1132 result.addTypes(type); 1133 return success(); 1134 } 1135 1136 //===----------------------------------------------------------------------===// 1137 // Verifying/Printing/parsing for LLVM::CallOp. 1138 //===----------------------------------------------------------------------===// 1139 1140 CallInterfaceCallable CallOp::getCallableForCallee() { 1141 // Direct call. 1142 if (FlatSymbolRefAttr calleeAttr = getCalleeAttr()) 1143 return calleeAttr; 1144 // Indirect call, callee Value is the first operand. 1145 return getOperand(0); 1146 } 1147 1148 Operation::operand_range CallOp::getArgOperands() { 1149 return getOperands().drop_front(getCallee().hasValue() ? 0 : 1); 1150 } 1151 1152 LogicalResult CallOp::verify() { 1153 if (getNumResults() > 1) 1154 return emitOpError("must have 0 or 1 result"); 1155 1156 // Type for the callee, we'll get it differently depending if it is a direct 1157 // or indirect call. 1158 Type fnType; 1159 1160 bool isIndirect = false; 1161 1162 // If this is an indirect call, the callee attribute is missing. 
1163 FlatSymbolRefAttr calleeName = getCalleeAttr(); 1164 if (!calleeName) { 1165 isIndirect = true; 1166 if (!getNumOperands()) 1167 return emitOpError( 1168 "must have either a `callee` attribute or at least an operand"); 1169 auto ptrType = getOperand(0).getType().dyn_cast<LLVMPointerType>(); 1170 if (!ptrType) 1171 return emitOpError("indirect call expects a pointer as callee: ") 1172 << ptrType; 1173 fnType = ptrType.getElementType(); 1174 } else { 1175 Operation *callee = 1176 SymbolTable::lookupNearestSymbolFrom(*this, calleeName.getAttr()); 1177 if (!callee) 1178 return emitOpError() 1179 << "'" << calleeName.getValue() 1180 << "' does not reference a symbol in the current scope"; 1181 auto fn = dyn_cast<LLVMFuncOp>(callee); 1182 if (!fn) 1183 return emitOpError() << "'" << calleeName.getValue() 1184 << "' does not reference a valid LLVM function"; 1185 1186 fnType = fn.getFunctionType(); 1187 } 1188 1189 LLVMFunctionType funcType = fnType.dyn_cast<LLVMFunctionType>(); 1190 if (!funcType) 1191 return emitOpError("callee does not have a functional type: ") << fnType; 1192 1193 // Verify that the operand and result types match the callee. 
1194 1195 if (!funcType.isVarArg() && 1196 funcType.getNumParams() != (getNumOperands() - isIndirect)) 1197 return emitOpError() << "incorrect number of operands (" 1198 << (getNumOperands() - isIndirect) 1199 << ") for callee (expecting: " 1200 << funcType.getNumParams() << ")"; 1201 1202 if (funcType.getNumParams() > (getNumOperands() - isIndirect)) 1203 return emitOpError() << "incorrect number of operands (" 1204 << (getNumOperands() - isIndirect) 1205 << ") for varargs callee (expecting at least: " 1206 << funcType.getNumParams() << ")"; 1207 1208 for (unsigned i = 0, e = funcType.getNumParams(); i != e; ++i) 1209 if (getOperand(i + isIndirect).getType() != funcType.getParamType(i)) 1210 return emitOpError() << "operand type mismatch for operand " << i << ": " 1211 << getOperand(i + isIndirect).getType() 1212 << " != " << funcType.getParamType(i); 1213 1214 if (getNumResults() == 0 && 1215 !funcType.getReturnType().isa<LLVM::LLVMVoidType>()) 1216 return emitOpError() << "expected function call to produce a value"; 1217 1218 if (getNumResults() != 0 && 1219 funcType.getReturnType().isa<LLVM::LLVMVoidType>()) 1220 return emitOpError() 1221 << "calling function with void result must not produce values"; 1222 1223 if (getNumResults() > 1) 1224 return emitOpError() 1225 << "expected LLVM function call to produce 0 or 1 result"; 1226 1227 if (getNumResults() && getResult(0).getType() != funcType.getReturnType()) 1228 return emitOpError() << "result type mismatch: " << getResult(0).getType() 1229 << " != " << funcType.getReturnType(); 1230 1231 return success(); 1232 } 1233 1234 void CallOp::print(OpAsmPrinter &p) { 1235 auto callee = getCallee(); 1236 bool isDirect = callee.hasValue(); 1237 1238 // Print the direct callee if present as a function attribute, or an indirect 1239 // callee (first operand) otherwise. 
1240 p << ' '; 1241 if (isDirect) 1242 p.printSymbolName(callee.getValue()); 1243 else 1244 p << getOperand(0); 1245 1246 auto args = getOperands().drop_front(isDirect ? 0 : 1); 1247 p << '(' << args << ')'; 1248 p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()), {"callee"}); 1249 1250 // Reconstruct the function MLIR function type from operand and result types. 1251 p << " : "; 1252 p.printFunctionalType(args.getTypes(), getResultTypes()); 1253 } 1254 1255 // <operation> ::= `llvm.call` (function-id | ssa-use) `(` ssa-use-list `)` 1256 // attribute-dict? `:` function-type 1257 ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) { 1258 SmallVector<OpAsmParser::UnresolvedOperand, 8> operands; 1259 Type type; 1260 SymbolRefAttr funcAttr; 1261 SMLoc trailingTypeLoc; 1262 1263 // Parse an operand list that will, in practice, contain 0 or 1 operand. In 1264 // case of an indirect call, there will be 1 operand before `(`. In case of a 1265 // direct call, there will be no operands and the parser will stop at the 1266 // function identifier without complaining. 1267 if (parser.parseOperandList(operands)) 1268 return failure(); 1269 bool isDirect = operands.empty(); 1270 1271 // Optionally parse a function identifier. 1272 if (isDirect) 1273 if (parser.parseAttribute(funcAttr, "callee", result.attributes)) 1274 return failure(); 1275 1276 if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || 1277 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 1278 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type)) 1279 return failure(); 1280 1281 auto funcType = type.dyn_cast<FunctionType>(); 1282 if (!funcType) 1283 return parser.emitError(trailingTypeLoc, "expected function type"); 1284 if (funcType.getNumResults() > 1) 1285 return parser.emitError(trailingTypeLoc, 1286 "expected function with 0 or 1 result"); 1287 if (isDirect) { 1288 // Make sure types match. 
1289 if (parser.resolveOperands(operands, funcType.getInputs(), 1290 parser.getNameLoc(), result.operands)) 1291 return failure(); 1292 if (funcType.getNumResults() != 0 && 1293 !funcType.getResult(0).isa<LLVM::LLVMVoidType>()) 1294 result.addTypes(funcType.getResults()); 1295 } else { 1296 Builder &builder = parser.getBuilder(); 1297 Type llvmResultType; 1298 if (funcType.getNumResults() == 0) { 1299 llvmResultType = LLVM::LLVMVoidType::get(builder.getContext()); 1300 } else { 1301 llvmResultType = funcType.getResult(0); 1302 if (!isCompatibleType(llvmResultType)) 1303 return parser.emitError(trailingTypeLoc, 1304 "expected result to have LLVM type"); 1305 } 1306 1307 SmallVector<Type, 8> argTypes; 1308 argTypes.reserve(funcType.getNumInputs()); 1309 for (int i = 0, e = funcType.getNumInputs(); i < e; ++i) { 1310 auto argType = funcType.getInput(i); 1311 if (!isCompatibleType(argType)) 1312 return parser.emitError(trailingTypeLoc, 1313 "expected LLVM types as inputs"); 1314 argTypes.push_back(argType); 1315 } 1316 auto llvmFuncType = LLVM::LLVMFunctionType::get(llvmResultType, argTypes); 1317 auto wrappedFuncType = LLVM::LLVMPointerType::get(llvmFuncType); 1318 1319 auto funcArguments = 1320 ArrayRef<OpAsmParser::UnresolvedOperand>(operands).drop_front(); 1321 1322 // Make sure that the first operand (indirect callee) matches the wrapped 1323 // LLVM IR function type, and that the types of the other call operands 1324 // match the types of the function arguments. 1325 if (parser.resolveOperand(operands[0], wrappedFuncType, result.operands) || 1326 parser.resolveOperands(funcArguments, funcType.getInputs(), 1327 parser.getNameLoc(), result.operands)) 1328 return failure(); 1329 1330 if (!llvmResultType.isa<LLVM::LLVMVoidType>()) 1331 result.addTypes(llvmResultType); 1332 } 1333 1334 return success(); 1335 } 1336 1337 //===----------------------------------------------------------------------===// 1338 // Printing/parsing for LLVM::ExtractElementOp. 
1339 //===----------------------------------------------------------------------===// 1340 // Expects vector to be of wrapped LLVM vector type and position to be of 1341 // wrapped LLVM i32 type. 1342 void LLVM::ExtractElementOp::build(OpBuilder &b, OperationState &result, 1343 Value vector, Value position, 1344 ArrayRef<NamedAttribute> attrs) { 1345 auto vectorType = vector.getType(); 1346 auto llvmType = LLVM::getVectorElementType(vectorType); 1347 build(b, result, llvmType, vector, position); 1348 result.addAttributes(attrs); 1349 } 1350 1351 void ExtractElementOp::print(OpAsmPrinter &p) { 1352 p << ' ' << getVector() << "[" << getPosition() << " : " 1353 << getPosition().getType() << "]"; 1354 p.printOptionalAttrDict((*this)->getAttrs()); 1355 p << " : " << getVector().getType(); 1356 } 1357 1358 // <operation> ::= `llvm.extractelement` ssa-use `, ` ssa-use 1359 // attribute-dict? `:` type 1360 ParseResult ExtractElementOp::parse(OpAsmParser &parser, 1361 OperationState &result) { 1362 SMLoc loc; 1363 OpAsmParser::UnresolvedOperand vector, position; 1364 Type type, positionType; 1365 if (parser.getCurrentLocation(&loc) || parser.parseOperand(vector) || 1366 parser.parseLSquare() || parser.parseOperand(position) || 1367 parser.parseColonType(positionType) || parser.parseRSquare() || 1368 parser.parseOptionalAttrDict(result.attributes) || 1369 parser.parseColonType(type) || 1370 parser.resolveOperand(vector, type, result.operands) || 1371 parser.resolveOperand(position, positionType, result.operands)) 1372 return failure(); 1373 if (!LLVM::isCompatibleVectorType(type)) 1374 return parser.emitError( 1375 loc, "expected LLVM dialect-compatible vector type for operand #1"); 1376 result.addTypes(LLVM::getVectorElementType(type)); 1377 return success(); 1378 } 1379 1380 LogicalResult ExtractElementOp::verify() { 1381 Type vectorType = getVector().getType(); 1382 if (!LLVM::isCompatibleVectorType(vectorType)) 1383 return emitOpError("expected LLVM dialect-compatible 
vector type for " 1384 "operand #1, got") 1385 << vectorType; 1386 Type valueType = LLVM::getVectorElementType(vectorType); 1387 if (valueType != getRes().getType()) 1388 return emitOpError() << "Type mismatch: extracting from " << vectorType 1389 << " should produce " << valueType 1390 << " but this op returns " << getRes().getType(); 1391 return success(); 1392 } 1393 1394 //===----------------------------------------------------------------------===// 1395 // Printing/parsing for LLVM::ExtractValueOp. 1396 //===----------------------------------------------------------------------===// 1397 1398 void ExtractValueOp::print(OpAsmPrinter &p) { 1399 p << ' ' << getContainer() << getPosition(); 1400 p.printOptionalAttrDict((*this)->getAttrs(), {"position"}); 1401 p << " : " << getContainer().getType(); 1402 } 1403 1404 // Extract the type at `position` in the wrapped LLVM IR aggregate type 1405 // `containerType`. Position is an integer array attribute where each value 1406 // is a zero-based position of the element in the aggregate type. Return the 1407 // resulting type wrapped in MLIR, or nullptr on error. 1408 static Type getInsertExtractValueElementType(OpAsmParser &parser, 1409 Type containerType, 1410 ArrayAttr positionAttr, 1411 SMLoc attributeLoc, 1412 SMLoc typeLoc) { 1413 Type llvmType = containerType; 1414 if (!isCompatibleType(containerType)) 1415 return parser.emitError(typeLoc, "expected LLVM IR Dialect type"), nullptr; 1416 1417 // Infer the element type from the structure type: iteratively step inside the 1418 // type by taking the element type, indexed by the position attribute for 1419 // structures. Check the position index before accessing, it is supposed to 1420 // be in bounds. 
1421 for (Attribute subAttr : positionAttr) { 1422 auto positionElementAttr = subAttr.dyn_cast<IntegerAttr>(); 1423 if (!positionElementAttr) 1424 return parser.emitError(attributeLoc, 1425 "expected an array of integer literals"), 1426 nullptr; 1427 int position = positionElementAttr.getInt(); 1428 if (auto arrayType = llvmType.dyn_cast<LLVMArrayType>()) { 1429 if (position < 0 || 1430 static_cast<unsigned>(position) >= arrayType.getNumElements()) 1431 return parser.emitError(attributeLoc, "position out of bounds"), 1432 nullptr; 1433 llvmType = arrayType.getElementType(); 1434 } else if (auto structType = llvmType.dyn_cast<LLVMStructType>()) { 1435 if (position < 0 || 1436 static_cast<unsigned>(position) >= structType.getBody().size()) 1437 return parser.emitError(attributeLoc, "position out of bounds"), 1438 nullptr; 1439 llvmType = structType.getBody()[position]; 1440 } else { 1441 return parser.emitError(typeLoc, "expected LLVM IR structure/array type"), 1442 nullptr; 1443 } 1444 } 1445 return llvmType; 1446 } 1447 1448 // Extract the type at `position` in the wrapped LLVM IR aggregate type 1449 // `containerType`. Returns null on failure. 1450 static Type getInsertExtractValueElementType(Type containerType, 1451 ArrayAttr positionAttr, 1452 Operation *op) { 1453 Type llvmType = containerType; 1454 if (!isCompatibleType(containerType)) { 1455 op->emitError("expected LLVM IR Dialect type, got ") << containerType; 1456 return {}; 1457 } 1458 1459 // Infer the element type from the structure type: iteratively step inside the 1460 // type by taking the element type, indexed by the position attribute for 1461 // structures. Check the position index before accessing, it is supposed to 1462 // be in bounds. 
1463 for (Attribute subAttr : positionAttr) { 1464 auto positionElementAttr = subAttr.dyn_cast<IntegerAttr>(); 1465 if (!positionElementAttr) { 1466 op->emitOpError("expected an array of integer literals, got: ") 1467 << subAttr; 1468 return {}; 1469 } 1470 int position = positionElementAttr.getInt(); 1471 if (auto arrayType = llvmType.dyn_cast<LLVMArrayType>()) { 1472 if (position < 0 || 1473 static_cast<unsigned>(position) >= arrayType.getNumElements()) { 1474 op->emitOpError("position out of bounds: ") << position; 1475 return {}; 1476 } 1477 llvmType = arrayType.getElementType(); 1478 } else if (auto structType = llvmType.dyn_cast<LLVMStructType>()) { 1479 if (position < 0 || 1480 static_cast<unsigned>(position) >= structType.getBody().size()) { 1481 op->emitOpError("position out of bounds") << position; 1482 return {}; 1483 } 1484 llvmType = structType.getBody()[position]; 1485 } else { 1486 op->emitOpError("expected LLVM IR structure/array type, got: ") 1487 << llvmType; 1488 return {}; 1489 } 1490 } 1491 return llvmType; 1492 } 1493 1494 // <operation> ::= `llvm.extractvalue` ssa-use 1495 // `[` integer-literal (`,` integer-literal)* `]` 1496 // attribute-dict? 
`:` type 1497 ParseResult ExtractValueOp::parse(OpAsmParser &parser, OperationState &result) { 1498 OpAsmParser::UnresolvedOperand container; 1499 Type containerType; 1500 ArrayAttr positionAttr; 1501 SMLoc attributeLoc, trailingTypeLoc; 1502 1503 if (parser.parseOperand(container) || 1504 parser.getCurrentLocation(&attributeLoc) || 1505 parser.parseAttribute(positionAttr, "position", result.attributes) || 1506 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 1507 parser.getCurrentLocation(&trailingTypeLoc) || 1508 parser.parseType(containerType) || 1509 parser.resolveOperand(container, containerType, result.operands)) 1510 return failure(); 1511 1512 auto elementType = getInsertExtractValueElementType( 1513 parser, containerType, positionAttr, attributeLoc, trailingTypeLoc); 1514 if (!elementType) 1515 return failure(); 1516 1517 result.addTypes(elementType); 1518 return success(); 1519 } 1520 1521 OpFoldResult LLVM::ExtractValueOp::fold(ArrayRef<Attribute> operands) { 1522 auto insertValueOp = getContainer().getDefiningOp<InsertValueOp>(); 1523 OpFoldResult result = {}; 1524 while (insertValueOp) { 1525 if (getPosition() == insertValueOp.getPosition()) 1526 return insertValueOp.getValue(); 1527 unsigned min = 1528 std::min(getPosition().size(), insertValueOp.getPosition().size()); 1529 // If one is fully prefix of the other, stop propagating back as it will 1530 // miss dependencies. 
For instance, %3 should not fold to %f0 in the 1531 // following example: 1532 // ``` 1533 // %1 = llvm.insertvalue %f0, %0[0, 0] : 1534 // !llvm.array<4 x !llvm.array<4xf32>> 1535 // %2 = llvm.insertvalue %arr, %1[0] : 1536 // !llvm.array<4 x !llvm.array<4xf32>> 1537 // %3 = llvm.extractvalue %2[0, 0] : !llvm.array<4 x !llvm.array<4xf32>> 1538 // ``` 1539 if (getPosition().getValue().take_front(min) == 1540 insertValueOp.getPosition().getValue().take_front(min)) 1541 return result; 1542 1543 // If neither a prefix, nor the exact position, we can extract out of the 1544 // value being inserted into. Moreover, we can try again if that operand 1545 // is itself an insertvalue expression. 1546 getContainerMutable().assign(insertValueOp.getContainer()); 1547 result = getResult(); 1548 insertValueOp = insertValueOp.getContainer().getDefiningOp<InsertValueOp>(); 1549 } 1550 return result; 1551 } 1552 1553 LogicalResult ExtractValueOp::verify() { 1554 Type valueType = getInsertExtractValueElementType(getContainer().getType(), 1555 getPositionAttr(), *this); 1556 if (!valueType) 1557 return failure(); 1558 1559 if (getRes().getType() != valueType) 1560 return emitOpError() << "Type mismatch: extracting from " 1561 << getContainer().getType() << " should produce " 1562 << valueType << " but this op returns " 1563 << getRes().getType(); 1564 return success(); 1565 } 1566 1567 //===----------------------------------------------------------------------===// 1568 // Printing/parsing for LLVM::InsertElementOp. 1569 //===----------------------------------------------------------------------===// 1570 1571 void InsertElementOp::print(OpAsmPrinter &p) { 1572 p << ' ' << getValue() << ", " << getVector() << "[" << getPosition() << " : " 1573 << getPosition().getType() << "]"; 1574 p.printOptionalAttrDict((*this)->getAttrs()); 1575 p << " : " << getVector().getType(); 1576 } 1577 1578 // <operation> ::= `llvm.insertelement` ssa-use `,` ssa-use `,` ssa-use 1579 // attribute-dict? 
`:` type 1580 ParseResult InsertElementOp::parse(OpAsmParser &parser, 1581 OperationState &result) { 1582 SMLoc loc; 1583 OpAsmParser::UnresolvedOperand vector, value, position; 1584 Type vectorType, positionType; 1585 if (parser.getCurrentLocation(&loc) || parser.parseOperand(value) || 1586 parser.parseComma() || parser.parseOperand(vector) || 1587 parser.parseLSquare() || parser.parseOperand(position) || 1588 parser.parseColonType(positionType) || parser.parseRSquare() || 1589 parser.parseOptionalAttrDict(result.attributes) || 1590 parser.parseColonType(vectorType)) 1591 return failure(); 1592 1593 if (!LLVM::isCompatibleVectorType(vectorType)) 1594 return parser.emitError( 1595 loc, "expected LLVM dialect-compatible vector type for operand #1"); 1596 Type valueType = LLVM::getVectorElementType(vectorType); 1597 if (!valueType) 1598 return failure(); 1599 1600 if (parser.resolveOperand(vector, vectorType, result.operands) || 1601 parser.resolveOperand(value, valueType, result.operands) || 1602 parser.resolveOperand(position, positionType, result.operands)) 1603 return failure(); 1604 1605 result.addTypes(vectorType); 1606 return success(); 1607 } 1608 1609 LogicalResult InsertElementOp::verify() { 1610 Type valueType = LLVM::getVectorElementType(getVector().getType()); 1611 if (valueType != getValue().getType()) 1612 return emitOpError() << "Type mismatch: cannot insert " 1613 << getValue().getType() << " into " 1614 << getVector().getType(); 1615 return success(); 1616 } 1617 1618 //===----------------------------------------------------------------------===// 1619 // Printing/parsing for LLVM::InsertValueOp. 
1620 //===----------------------------------------------------------------------===// 1621 1622 void InsertValueOp::print(OpAsmPrinter &p) { 1623 p << ' ' << getValue() << ", " << getContainer() << getPosition(); 1624 p.printOptionalAttrDict((*this)->getAttrs(), {"position"}); 1625 p << " : " << getContainer().getType(); 1626 } 1627 1628 // <operation> ::= `llvm.insertvaluevalue` ssa-use `,` ssa-use 1629 // `[` integer-literal (`,` integer-literal)* `]` 1630 // attribute-dict? `:` type 1631 ParseResult InsertValueOp::parse(OpAsmParser &parser, OperationState &result) { 1632 OpAsmParser::UnresolvedOperand container, value; 1633 Type containerType; 1634 ArrayAttr positionAttr; 1635 SMLoc attributeLoc, trailingTypeLoc; 1636 1637 if (parser.parseOperand(value) || parser.parseComma() || 1638 parser.parseOperand(container) || 1639 parser.getCurrentLocation(&attributeLoc) || 1640 parser.parseAttribute(positionAttr, "position", result.attributes) || 1641 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 1642 parser.getCurrentLocation(&trailingTypeLoc) || 1643 parser.parseType(containerType)) 1644 return failure(); 1645 1646 auto valueType = getInsertExtractValueElementType( 1647 parser, containerType, positionAttr, attributeLoc, trailingTypeLoc); 1648 if (!valueType) 1649 return failure(); 1650 1651 if (parser.resolveOperand(container, containerType, result.operands) || 1652 parser.resolveOperand(value, valueType, result.operands)) 1653 return failure(); 1654 1655 result.addTypes(containerType); 1656 return success(); 1657 } 1658 1659 LogicalResult InsertValueOp::verify() { 1660 Type valueType = getInsertExtractValueElementType(getContainer().getType(), 1661 getPositionAttr(), *this); 1662 if (!valueType) 1663 return failure(); 1664 1665 if (getValue().getType() != valueType) 1666 return emitOpError() << "Type mismatch: cannot insert " 1667 << getValue().getType() << " into " 1668 << getContainer().getType(); 1669 1670 return success(); 1671 } 1672 
1673 //===----------------------------------------------------------------------===// 1674 // Printing, parsing and verification for LLVM::ReturnOp. 1675 //===----------------------------------------------------------------------===// 1676 1677 LogicalResult ReturnOp::verify() { 1678 if (getNumOperands() > 1) 1679 return emitOpError("expected at most 1 operand"); 1680 1681 if (auto parent = (*this)->getParentOfType<LLVMFuncOp>()) { 1682 Type expectedType = parent.getFunctionType().getReturnType(); 1683 if (expectedType.isa<LLVMVoidType>()) { 1684 if (getNumOperands() == 0) 1685 return success(); 1686 InFlightDiagnostic diag = emitOpError("expected no operands"); 1687 diag.attachNote(parent->getLoc()) << "when returning from function"; 1688 return diag; 1689 } 1690 if (getNumOperands() == 0) { 1691 if (expectedType.isa<LLVMVoidType>()) 1692 return success(); 1693 InFlightDiagnostic diag = emitOpError("expected 1 operand"); 1694 diag.attachNote(parent->getLoc()) << "when returning from function"; 1695 return diag; 1696 } 1697 if (expectedType != getOperand(0).getType()) { 1698 InFlightDiagnostic diag = emitOpError("mismatching result types"); 1699 diag.attachNote(parent->getLoc()) << "when returning from function"; 1700 return diag; 1701 } 1702 } 1703 return success(); 1704 } 1705 1706 //===----------------------------------------------------------------------===// 1707 // ResumeOp 1708 //===----------------------------------------------------------------------===// 1709 1710 LogicalResult ResumeOp::verify() { 1711 if (!getValue().getDefiningOp<LandingpadOp>()) 1712 return emitOpError("expects landingpad value as operand"); 1713 // No check for personality of function - landingpad op verifies it. 1714 return success(); 1715 } 1716 1717 //===----------------------------------------------------------------------===// 1718 // Verifier for LLVM::AddressOfOp. 
1719 //===----------------------------------------------------------------------===// 1720 1721 template <typename OpTy> 1722 static OpTy lookupSymbolInModule(Operation *parent, StringRef name) { 1723 Operation *module = parent; 1724 while (module && !satisfiesLLVMModule(module)) 1725 module = module->getParentOp(); 1726 assert(module && "unexpected operation outside of a module"); 1727 return dyn_cast_or_null<OpTy>( 1728 mlir::SymbolTable::lookupSymbolIn(module, name)); 1729 } 1730 1731 GlobalOp AddressOfOp::getGlobal() { 1732 return lookupSymbolInModule<LLVM::GlobalOp>((*this)->getParentOp(), 1733 getGlobalName()); 1734 } 1735 1736 LLVMFuncOp AddressOfOp::getFunction() { 1737 return lookupSymbolInModule<LLVM::LLVMFuncOp>((*this)->getParentOp(), 1738 getGlobalName()); 1739 } 1740 1741 LogicalResult AddressOfOp::verify() { 1742 auto global = getGlobal(); 1743 auto function = getFunction(); 1744 if (!global && !function) 1745 return emitOpError( 1746 "must reference a global defined by 'llvm.mlir.global' or 'llvm.func'"); 1747 1748 LLVMPointerType type = getType(); 1749 if (global && global.getAddrSpace() != type.getAddressSpace()) 1750 return emitOpError("pointer address space must match address space of the " 1751 "referenced global"); 1752 1753 if (type.isOpaque()) 1754 return success(); 1755 1756 if (global && type.getElementType() != global.getType()) 1757 return emitOpError( 1758 "the type must be a pointer to the type of the referenced global"); 1759 1760 if (function && type.getElementType() != function.getFunctionType()) 1761 return emitOpError( 1762 "the type must be a pointer to the type of the referenced function"); 1763 1764 return success(); 1765 } 1766 1767 //===----------------------------------------------------------------------===// 1768 // Builder, printer and verifier for LLVM::GlobalOp. 
1769 //===----------------------------------------------------------------------===// 1770 1771 void GlobalOp::build(OpBuilder &builder, OperationState &result, Type type, 1772 bool isConstant, Linkage linkage, StringRef name, 1773 Attribute value, uint64_t alignment, unsigned addrSpace, 1774 bool dsoLocal, bool threadLocal, 1775 ArrayRef<NamedAttribute> attrs) { 1776 result.addAttribute(getSymNameAttrName(result.name), 1777 builder.getStringAttr(name)); 1778 result.addAttribute(getGlobalTypeAttrName(result.name), TypeAttr::get(type)); 1779 if (isConstant) 1780 result.addAttribute(getConstantAttrName(result.name), 1781 builder.getUnitAttr()); 1782 if (value) 1783 result.addAttribute(getValueAttrName(result.name), value); 1784 if (dsoLocal) 1785 result.addAttribute(getDsoLocalAttrName(result.name), 1786 builder.getUnitAttr()); 1787 if (threadLocal) 1788 result.addAttribute(getThreadLocal_AttrName(result.name), 1789 builder.getUnitAttr()); 1790 1791 // Only add an alignment attribute if the "alignment" input 1792 // is different from 0. The value must also be a power of two, but 1793 // this is tested in GlobalOp::verify, not here. 
1794 if (alignment != 0) 1795 result.addAttribute(getAlignmentAttrName(result.name), 1796 builder.getI64IntegerAttr(alignment)); 1797 1798 result.addAttribute(getLinkageAttrName(result.name), 1799 LinkageAttr::get(builder.getContext(), linkage)); 1800 if (addrSpace != 0) 1801 result.addAttribute(getAddrSpaceAttrName(result.name), 1802 builder.getI32IntegerAttr(addrSpace)); 1803 result.attributes.append(attrs.begin(), attrs.end()); 1804 result.addRegion(); 1805 } 1806 1807 void GlobalOp::print(OpAsmPrinter &p) { 1808 p << ' ' << stringifyLinkage(getLinkage()) << ' '; 1809 if (auto unnamedAddr = getUnnamedAddr()) { 1810 StringRef str = stringifyUnnamedAddr(*unnamedAddr); 1811 if (!str.empty()) 1812 p << str << ' '; 1813 } 1814 if (getThreadLocal_()) 1815 p << "thread_local "; 1816 if (getConstant()) 1817 p << "constant "; 1818 p.printSymbolName(getSymName()); 1819 p << '('; 1820 if (auto value = getValueOrNull()) 1821 p.printAttribute(value); 1822 p << ')'; 1823 // Note that the alignment attribute is printed using the 1824 // default syntax here, even though it is an inherent attribute 1825 // (as defined in https://mlir.llvm.org/docs/LangRef/#attributes) 1826 p.printOptionalAttrDict( 1827 (*this)->getAttrs(), 1828 {SymbolTable::getSymbolAttrName(), getGlobalTypeAttrName(), 1829 getConstantAttrName(), getValueAttrName(), getLinkageAttrName(), 1830 getUnnamedAddrAttrName(), getThreadLocal_AttrName()}); 1831 1832 // Print the trailing type unless it's a string global. 1833 if (getValueOrNull().dyn_cast_or_null<StringAttr>()) 1834 return; 1835 p << " : " << getType(); 1836 1837 Region &initializer = getInitializerRegion(); 1838 if (!initializer.empty()) { 1839 p << ' '; 1840 p.printRegion(initializer, /*printEntryBlockArgs=*/false); 1841 } 1842 } 1843 1844 // Parses one of the keywords provided in the list `keywords` and returns the 1845 // position of the parsed keyword in the list. If none of the keywords from the 1846 // list is parsed, returns -1. 
1847 static int parseOptionalKeywordAlternative(OpAsmParser &parser, 1848 ArrayRef<StringRef> keywords) { 1849 for (const auto &en : llvm::enumerate(keywords)) { 1850 if (succeeded(parser.parseOptionalKeyword(en.value()))) 1851 return en.index(); 1852 } 1853 return -1; 1854 } 1855 1856 namespace { 1857 template <typename Ty> 1858 struct EnumTraits {}; 1859 1860 #define REGISTER_ENUM_TYPE(Ty) \ 1861 template <> \ 1862 struct EnumTraits<Ty> { \ 1863 static StringRef stringify(Ty value) { return stringify##Ty(value); } \ 1864 static unsigned getMaxEnumVal() { return getMaxEnumValFor##Ty(); } \ 1865 } 1866 1867 REGISTER_ENUM_TYPE(Linkage); 1868 REGISTER_ENUM_TYPE(UnnamedAddr); 1869 REGISTER_ENUM_TYPE(CConv); 1870 } // namespace 1871 1872 /// Parse an enum from the keyword, or default to the provided default value. 1873 /// The return type is the enum type by default, unless overriden with the 1874 /// second template argument. 1875 template <typename EnumTy, typename RetTy = EnumTy> 1876 static RetTy parseOptionalLLVMKeyword(OpAsmParser &parser, 1877 OperationState &result, 1878 EnumTy defaultValue) { 1879 SmallVector<StringRef, 10> names; 1880 for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i) 1881 names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i))); 1882 1883 int index = parseOptionalKeywordAlternative(parser, names); 1884 if (index == -1) 1885 return static_cast<RetTy>(defaultValue); 1886 return static_cast<RetTy>(index); 1887 } 1888 1889 // operation ::= `llvm.mlir.global` linkage? `constant`? `@` identifier 1890 // `(` attribute? `)` align? attribute-list? (`:` type)? region? 1891 // align ::= `align` `=` UINT64 1892 // 1893 // The type can be omitted for string attributes, in which case it will be 1894 // inferred from the value of the string as [strlen(value) x i8]. 
1895 ParseResult GlobalOp::parse(OpAsmParser &parser, OperationState &result) { 1896 MLIRContext *ctx = parser.getContext(); 1897 // Parse optional linkage, default to External. 1898 result.addAttribute(getLinkageAttrName(result.name), 1899 LLVM::LinkageAttr::get( 1900 ctx, parseOptionalLLVMKeyword<Linkage>( 1901 parser, result, LLVM::Linkage::External))); 1902 1903 if (succeeded(parser.parseOptionalKeyword("thread_local"))) 1904 result.addAttribute(getThreadLocal_AttrName(result.name), 1905 parser.getBuilder().getUnitAttr()); 1906 1907 // Parse optional UnnamedAddr, default to None. 1908 result.addAttribute(getUnnamedAddrAttrName(result.name), 1909 parser.getBuilder().getI64IntegerAttr( 1910 parseOptionalLLVMKeyword<UnnamedAddr, int64_t>( 1911 parser, result, LLVM::UnnamedAddr::None))); 1912 1913 if (succeeded(parser.parseOptionalKeyword("constant"))) 1914 result.addAttribute(getConstantAttrName(result.name), 1915 parser.getBuilder().getUnitAttr()); 1916 1917 StringAttr name; 1918 if (parser.parseSymbolName(name, getSymNameAttrName(result.name), 1919 result.attributes) || 1920 parser.parseLParen()) 1921 return failure(); 1922 1923 Attribute value; 1924 if (parser.parseOptionalRParen()) { 1925 if (parser.parseAttribute(value, getValueAttrName(result.name), 1926 result.attributes) || 1927 parser.parseRParen()) 1928 return failure(); 1929 } 1930 1931 SmallVector<Type, 1> types; 1932 if (parser.parseOptionalAttrDict(result.attributes) || 1933 parser.parseOptionalColonTypeList(types)) 1934 return failure(); 1935 1936 if (types.size() > 1) 1937 return parser.emitError(parser.getNameLoc(), "expected zero or one type"); 1938 1939 Region &initRegion = *result.addRegion(); 1940 if (types.empty()) { 1941 if (auto strAttr = value.dyn_cast_or_null<StringAttr>()) { 1942 MLIRContext *context = parser.getContext(); 1943 auto arrayType = LLVM::LLVMArrayType::get(IntegerType::get(context, 8), 1944 strAttr.getValue().size()); 1945 types.push_back(arrayType); 1946 } else { 1947 
return parser.emitError(parser.getNameLoc(), 1948 "type can only be omitted for string globals"); 1949 } 1950 } else { 1951 OptionalParseResult parseResult = 1952 parser.parseOptionalRegion(initRegion, /*arguments=*/{}, 1953 /*argTypes=*/{}); 1954 if (parseResult.hasValue() && failed(*parseResult)) 1955 return failure(); 1956 } 1957 1958 result.addAttribute(getGlobalTypeAttrName(result.name), 1959 TypeAttr::get(types[0])); 1960 return success(); 1961 } 1962 1963 static bool isZeroAttribute(Attribute value) { 1964 if (auto intValue = value.dyn_cast<IntegerAttr>()) 1965 return intValue.getValue().isNullValue(); 1966 if (auto fpValue = value.dyn_cast<FloatAttr>()) 1967 return fpValue.getValue().isZero(); 1968 if (auto splatValue = value.dyn_cast<SplatElementsAttr>()) 1969 return isZeroAttribute(splatValue.getSplatValue<Attribute>()); 1970 if (auto elementsValue = value.dyn_cast<ElementsAttr>()) 1971 return llvm::all_of(elementsValue.getValues<Attribute>(), isZeroAttribute); 1972 if (auto arrayValue = value.dyn_cast<ArrayAttr>()) 1973 return llvm::all_of(arrayValue.getValue(), isZeroAttribute); 1974 return false; 1975 } 1976 1977 LogicalResult GlobalOp::verify() { 1978 if (!LLVMPointerType::isValidElementType(getType())) 1979 return emitOpError( 1980 "expects type to be a valid element type for an LLVM pointer"); 1981 if ((*this)->getParentOp() && !satisfiesLLVMModule((*this)->getParentOp())) 1982 return emitOpError("must appear at the module level"); 1983 1984 if (auto strAttr = getValueOrNull().dyn_cast_or_null<StringAttr>()) { 1985 auto type = getType().dyn_cast<LLVMArrayType>(); 1986 IntegerType elementType = 1987 type ? 
type.getElementType().dyn_cast<IntegerType>() : nullptr; 1988 if (!elementType || elementType.getWidth() != 8 || 1989 type.getNumElements() != strAttr.getValue().size()) 1990 return emitOpError( 1991 "requires an i8 array type of the length equal to that of the string " 1992 "attribute"); 1993 } 1994 1995 if (getLinkage() == Linkage::Common) { 1996 if (Attribute value = getValueOrNull()) { 1997 if (!isZeroAttribute(value)) { 1998 return emitOpError() 1999 << "expected zero value for '" 2000 << stringifyLinkage(Linkage::Common) << "' linkage"; 2001 } 2002 } 2003 } 2004 2005 if (getLinkage() == Linkage::Appending) { 2006 if (!getType().isa<LLVMArrayType>()) { 2007 return emitOpError() << "expected array type for '" 2008 << stringifyLinkage(Linkage::Appending) 2009 << "' linkage"; 2010 } 2011 } 2012 2013 Optional<uint64_t> alignAttr = getAlignment(); 2014 if (alignAttr.hasValue()) { 2015 uint64_t value = alignAttr.getValue(); 2016 if (!llvm::isPowerOf2_64(value)) 2017 return emitError() << "alignment attribute is not a power of 2"; 2018 } 2019 2020 return success(); 2021 } 2022 2023 LogicalResult GlobalOp::verifyRegions() { 2024 if (Block *b = getInitializerBlock()) { 2025 ReturnOp ret = cast<ReturnOp>(b->getTerminator()); 2026 if (ret.operand_type_begin() == ret.operand_type_end()) 2027 return emitOpError("initializer region cannot return void"); 2028 if (*ret.operand_type_begin() != getType()) 2029 return emitOpError("initializer region type ") 2030 << *ret.operand_type_begin() << " does not match global type " 2031 << getType(); 2032 2033 for (Operation &op : *b) { 2034 auto iface = dyn_cast<MemoryEffectOpInterface>(op); 2035 if (!iface || !iface.hasNoEffect()) 2036 return op.emitError() 2037 << "ops with side effects not allowed in global initializers"; 2038 } 2039 2040 if (getValueOrNull()) 2041 return emitOpError("cannot have both initializer value and region"); 2042 } 2043 2044 return success(); 2045 } 2046 2047 
//===----------------------------------------------------------------------===// 2048 // LLVM::GlobalCtorsOp 2049 //===----------------------------------------------------------------------===// 2050 2051 LogicalResult 2052 GlobalCtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 2053 for (Attribute ctor : getCtors()) { 2054 if (failed(verifySymbolAttrUse(ctor.cast<FlatSymbolRefAttr>(), *this, 2055 symbolTable))) 2056 return failure(); 2057 } 2058 return success(); 2059 } 2060 2061 LogicalResult GlobalCtorsOp::verify() { 2062 if (getCtors().size() != getPriorities().size()) 2063 return emitError( 2064 "mismatch between the number of ctors and the number of priorities"); 2065 return success(); 2066 } 2067 2068 //===----------------------------------------------------------------------===// 2069 // LLVM::GlobalDtorsOp 2070 //===----------------------------------------------------------------------===// 2071 2072 LogicalResult 2073 GlobalDtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 2074 for (Attribute dtor : getDtors()) { 2075 if (failed(verifySymbolAttrUse(dtor.cast<FlatSymbolRefAttr>(), *this, 2076 symbolTable))) 2077 return failure(); 2078 } 2079 return success(); 2080 } 2081 2082 LogicalResult GlobalDtorsOp::verify() { 2083 if (getDtors().size() != getPriorities().size()) 2084 return emitError( 2085 "mismatch between the number of dtors and the number of priorities"); 2086 return success(); 2087 } 2088 2089 //===----------------------------------------------------------------------===// 2090 // Printing/parsing for LLVM::ShuffleVectorOp. 2091 //===----------------------------------------------------------------------===// 2092 // Expects vector to be of wrapped LLVM vector type and position to be of 2093 // wrapped LLVM i32 type. 
2094 void LLVM::ShuffleVectorOp::build(OpBuilder &b, OperationState &result, 2095 Value v1, Value v2, ArrayAttr mask, 2096 ArrayRef<NamedAttribute> attrs) { 2097 auto containerType = v1.getType(); 2098 auto vType = LLVM::getVectorType(LLVM::getVectorElementType(containerType), 2099 mask.size(), 2100 LLVM::isScalableVectorType(containerType)); 2101 build(b, result, vType, v1, v2, mask); 2102 result.addAttributes(attrs); 2103 } 2104 2105 void ShuffleVectorOp::print(OpAsmPrinter &p) { 2106 p << ' ' << getV1() << ", " << getV2() << " " << getMask(); 2107 p.printOptionalAttrDict((*this)->getAttrs(), {"mask"}); 2108 p << " : " << getV1().getType() << ", " << getV2().getType(); 2109 } 2110 2111 // <operation> ::= `llvm.shufflevector` ssa-use `, ` ssa-use 2112 // `[` integer-literal (`,` integer-literal)* `]` 2113 // attribute-dict? `:` type 2114 ParseResult ShuffleVectorOp::parse(OpAsmParser &parser, 2115 OperationState &result) { 2116 SMLoc loc; 2117 OpAsmParser::UnresolvedOperand v1, v2; 2118 ArrayAttr maskAttr; 2119 Type typeV1, typeV2; 2120 if (parser.getCurrentLocation(&loc) || parser.parseOperand(v1) || 2121 parser.parseComma() || parser.parseOperand(v2) || 2122 parser.parseAttribute(maskAttr, "mask", result.attributes) || 2123 parser.parseOptionalAttrDict(result.attributes) || 2124 parser.parseColonType(typeV1) || parser.parseComma() || 2125 parser.parseType(typeV2) || 2126 parser.resolveOperand(v1, typeV1, result.operands) || 2127 parser.resolveOperand(v2, typeV2, result.operands)) 2128 return failure(); 2129 if (!LLVM::isCompatibleVectorType(typeV1)) 2130 return parser.emitError( 2131 loc, "expected LLVM IR dialect vector type for operand #1"); 2132 auto vType = 2133 LLVM::getVectorType(LLVM::getVectorElementType(typeV1), maskAttr.size(), 2134 LLVM::isScalableVectorType(typeV1)); 2135 result.addTypes(vType); 2136 return success(); 2137 } 2138 2139 LogicalResult ShuffleVectorOp::verify() { 2140 Type type1 = getV1().getType(); 2141 Type type2 = getV2().getType(); 
2142 if (LLVM::getVectorElementType(type1) != LLVM::getVectorElementType(type2)) 2143 return emitOpError("expected matching LLVM IR Dialect element types"); 2144 if (LLVM::isScalableVectorType(type1)) 2145 if (llvm::any_of(getMask(), [](Attribute attr) { 2146 return attr.cast<IntegerAttr>().getInt() != 0; 2147 })) 2148 return emitOpError("expected a splat operation for scalable vectors"); 2149 return success(); 2150 } 2151 2152 //===----------------------------------------------------------------------===// 2153 // Implementations for LLVM::LLVMFuncOp. 2154 //===----------------------------------------------------------------------===// 2155 2156 // Add the entry block to the function. 2157 Block *LLVMFuncOp::addEntryBlock() { 2158 assert(empty() && "function already has an entry block"); 2159 2160 auto *entry = new Block; 2161 push_back(entry); 2162 2163 // FIXME: Allow passing in proper locations for the entry arguments. 2164 LLVMFunctionType type = getFunctionType(); 2165 for (unsigned i = 0, e = type.getNumParams(); i < e; ++i) 2166 entry->addArgument(type.getParamType(i), getLoc()); 2167 return entry; 2168 } 2169 2170 void LLVMFuncOp::build(OpBuilder &builder, OperationState &result, 2171 StringRef name, Type type, LLVM::Linkage linkage, 2172 bool dsoLocal, CConv cconv, 2173 ArrayRef<NamedAttribute> attrs, 2174 ArrayRef<DictionaryAttr> argAttrs) { 2175 result.addRegion(); 2176 result.addAttribute(SymbolTable::getSymbolAttrName(), 2177 builder.getStringAttr(name)); 2178 result.addAttribute(getFunctionTypeAttrName(result.name), 2179 TypeAttr::get(type)); 2180 result.addAttribute(getLinkageAttrName(result.name), 2181 LinkageAttr::get(builder.getContext(), linkage)); 2182 result.addAttribute(getCConvAttrName(result.name), 2183 CConvAttr::get(builder.getContext(), cconv)); 2184 result.attributes.append(attrs.begin(), attrs.end()); 2185 if (dsoLocal) 2186 result.addAttribute("dso_local", builder.getUnitAttr()); 2187 if (argAttrs.empty()) 2188 return; 2189 2190 
assert(type.cast<LLVMFunctionType>().getNumParams() == argAttrs.size() && 2191 "expected as many argument attribute lists as arguments"); 2192 function_interface_impl::addArgAndResultAttrs(builder, result, argAttrs, 2193 /*resultAttrs=*/llvm::None); 2194 } 2195 2196 // Builds an LLVM function type from the given lists of input and output types. 2197 // Returns a null type if any of the types provided are non-LLVM types, or if 2198 // there is more than one output type. 2199 static Type 2200 buildLLVMFunctionType(OpAsmParser &parser, SMLoc loc, ArrayRef<Type> inputs, 2201 ArrayRef<Type> outputs, 2202 function_interface_impl::VariadicFlag variadicFlag) { 2203 Builder &b = parser.getBuilder(); 2204 if (outputs.size() > 1) { 2205 parser.emitError(loc, "failed to construct function type: expected zero or " 2206 "one function result"); 2207 return {}; 2208 } 2209 2210 // Convert inputs to LLVM types, exit early on error. 2211 SmallVector<Type, 4> llvmInputs; 2212 for (auto t : inputs) { 2213 if (!isCompatibleType(t)) { 2214 parser.emitError(loc, "failed to construct function type: expected LLVM " 2215 "type for function arguments"); 2216 return {}; 2217 } 2218 llvmInputs.push_back(t); 2219 } 2220 2221 // No output is denoted as "void" in LLVM type system. 2222 Type llvmOutput = 2223 outputs.empty() ? LLVMVoidType::get(b.getContext()) : outputs.front(); 2224 if (!isCompatibleType(llvmOutput)) { 2225 parser.emitError(loc, "failed to construct function type: expected LLVM " 2226 "type for function results") 2227 << llvmOutput; 2228 return {}; 2229 } 2230 return LLVMFunctionType::get(llvmOutput, llvmInputs, 2231 variadicFlag.isVariadic()); 2232 } 2233 2234 // Parses an LLVM function. 2235 // 2236 // operation ::= `llvm.func` linkage? cconv? function-signature 2237 // function-attributes? 2238 // function-body 2239 // 2240 ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) { 2241 // Default to external linkage if no keyword is provided. 
2242 result.addAttribute( 2243 getLinkageAttrName(result.name), 2244 LinkageAttr::get(parser.getContext(), 2245 parseOptionalLLVMKeyword<Linkage>( 2246 parser, result, LLVM::Linkage::External))); 2247 2248 // Default to C Calling Convention if no keyword is provided. 2249 result.addAttribute( 2250 getCConvAttrName(result.name), 2251 CConvAttr::get(parser.getContext(), parseOptionalLLVMKeyword<CConv>( 2252 parser, result, LLVM::CConv::C))); 2253 2254 StringAttr nameAttr; 2255 SmallVector<OpAsmParser::Argument> entryArgs; 2256 SmallVector<DictionaryAttr> resultAttrs; 2257 SmallVector<Type> resultTypes; 2258 bool isVariadic; 2259 2260 auto signatureLocation = parser.getCurrentLocation(); 2261 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), 2262 result.attributes) || 2263 function_interface_impl::parseFunctionSignature( 2264 parser, /*allowVariadic=*/true, entryArgs, isVariadic, resultTypes, 2265 resultAttrs)) 2266 return failure(); 2267 2268 SmallVector<Type> argTypes; 2269 for (auto &arg : entryArgs) 2270 argTypes.push_back(arg.type); 2271 auto type = 2272 buildLLVMFunctionType(parser, signatureLocation, argTypes, resultTypes, 2273 function_interface_impl::VariadicFlag(isVariadic)); 2274 if (!type) 2275 return failure(); 2276 result.addAttribute(FunctionOpInterface::getTypeAttrName(), 2277 TypeAttr::get(type)); 2278 2279 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) 2280 return failure(); 2281 function_interface_impl::addArgAndResultAttrs(parser.getBuilder(), result, 2282 entryArgs, resultAttrs); 2283 2284 auto *body = result.addRegion(); 2285 OptionalParseResult parseResult = 2286 parser.parseOptionalRegion(*body, entryArgs); 2287 return failure(parseResult.hasValue() && failed(*parseResult)); 2288 } 2289 2290 // Print the LLVMFuncOp. Collects argument and result types and passes them to 2291 // helper functions. Drops "void" result since it cannot be parsed back. 
Skips 2292 // the external linkage since it is the default value. 2293 void LLVMFuncOp::print(OpAsmPrinter &p) { 2294 p << ' '; 2295 if (getLinkage() != LLVM::Linkage::External) 2296 p << stringifyLinkage(getLinkage()) << ' '; 2297 if (getCConv() != LLVM::CConv::C) 2298 p << stringifyCConv(getCConv()) << ' '; 2299 2300 p.printSymbolName(getName()); 2301 2302 LLVMFunctionType fnType = getFunctionType(); 2303 SmallVector<Type, 8> argTypes; 2304 SmallVector<Type, 1> resTypes; 2305 argTypes.reserve(fnType.getNumParams()); 2306 for (unsigned i = 0, e = fnType.getNumParams(); i < e; ++i) 2307 argTypes.push_back(fnType.getParamType(i)); 2308 2309 Type returnType = fnType.getReturnType(); 2310 if (!returnType.isa<LLVMVoidType>()) 2311 resTypes.push_back(returnType); 2312 2313 function_interface_impl::printFunctionSignature(p, *this, argTypes, 2314 isVarArg(), resTypes); 2315 function_interface_impl::printFunctionAttributes( 2316 p, *this, argTypes.size(), resTypes.size(), 2317 {getLinkageAttrName(), getCConvAttrName()}); 2318 2319 // Print the body if this is not an external function. 2320 Region &body = getBody(); 2321 if (!body.empty()) { 2322 p << ' '; 2323 p.printRegion(body, /*printEntryBlockArgs=*/false, 2324 /*printBlockTerminators=*/true); 2325 } 2326 } 2327 2328 // Verifies LLVM- and implementation-specific properties of the LLVM func Op: 2329 // - functions don't have 'common' linkage 2330 // - external functions have 'external' or 'extern_weak' linkage; 2331 // - vararg is (currently) only supported for external functions; 2332 LogicalResult LLVMFuncOp::verify() { 2333 if (getLinkage() == LLVM::Linkage::Common) 2334 return emitOpError() << "functions cannot have '" 2335 << stringifyLinkage(LLVM::Linkage::Common) 2336 << "' linkage"; 2337 2338 // Check to see if this function has a void return with a result attribute to 2339 // it. It isn't clear what semantics we would assign to that. 
2340 if (getFunctionType().getReturnType().isa<LLVMVoidType>() && 2341 !getResultAttrs(0).empty()) { 2342 return emitOpError() 2343 << "cannot attach result attributes to functions with a void return"; 2344 } 2345 2346 if (isExternal()) { 2347 if (getLinkage() != LLVM::Linkage::External && 2348 getLinkage() != LLVM::Linkage::ExternWeak) 2349 return emitOpError() << "external functions must have '" 2350 << stringifyLinkage(LLVM::Linkage::External) 2351 << "' or '" 2352 << stringifyLinkage(LLVM::Linkage::ExternWeak) 2353 << "' linkage"; 2354 return success(); 2355 } 2356 2357 return success(); 2358 } 2359 2360 /// Verifies LLVM- and implementation-specific properties of the LLVM func Op: 2361 /// - entry block arguments are of LLVM types. 2362 LogicalResult LLVMFuncOp::verifyRegions() { 2363 if (isExternal()) 2364 return success(); 2365 2366 unsigned numArguments = getFunctionType().getNumParams(); 2367 Block &entryBlock = front(); 2368 for (unsigned i = 0; i < numArguments; ++i) { 2369 Type argType = entryBlock.getArgument(i).getType(); 2370 if (!isCompatibleType(argType)) 2371 return emitOpError("entry block argument #") 2372 << i << " is not of LLVM type"; 2373 } 2374 2375 return success(); 2376 } 2377 2378 //===----------------------------------------------------------------------===// 2379 // Verification for LLVM::ConstantOp. 
2380 //===----------------------------------------------------------------------===// 2381 2382 LogicalResult LLVM::ConstantOp::verify() { 2383 if (StringAttr sAttr = getValue().dyn_cast<StringAttr>()) { 2384 auto arrayType = getType().dyn_cast<LLVMArrayType>(); 2385 if (!arrayType || arrayType.getNumElements() != sAttr.getValue().size() || 2386 !arrayType.getElementType().isInteger(8)) { 2387 return emitOpError() << "expected array type of " 2388 << sAttr.getValue().size() 2389 << " i8 elements for the string constant"; 2390 } 2391 return success(); 2392 } 2393 if (auto structType = getType().dyn_cast<LLVMStructType>()) { 2394 if (structType.getBody().size() != 2 || 2395 structType.getBody()[0] != structType.getBody()[1]) { 2396 return emitError() << "expected struct type with two elements of the " 2397 "same type, the type of a complex constant"; 2398 } 2399 2400 auto arrayAttr = getValue().dyn_cast<ArrayAttr>(); 2401 if (!arrayAttr || arrayAttr.size() != 2 || 2402 arrayAttr[0].getType() != arrayAttr[1].getType()) { 2403 return emitOpError() << "expected array attribute with two elements, " 2404 "representing a complex constant"; 2405 } 2406 2407 Type elementType = structType.getBody()[0]; 2408 if (!elementType 2409 .isa<IntegerType, Float16Type, Float32Type, Float64Type>()) { 2410 return emitError() 2411 << "expected struct element types to be floating point type or " 2412 "integer type"; 2413 } 2414 return success(); 2415 } 2416 if (!getValue().isa<IntegerAttr, ArrayAttr, FloatAttr, ElementsAttr>()) 2417 return emitOpError() 2418 << "only supports integer, float, string or elements attributes"; 2419 return success(); 2420 } 2421 2422 // Constant op constant-folds to its value. 
2423 OpFoldResult LLVM::ConstantOp::fold(ArrayRef<Attribute>) { return getValue(); } 2424 2425 //===----------------------------------------------------------------------===// 2426 // Utility functions for parsing atomic ops 2427 //===----------------------------------------------------------------------===// 2428 2429 // Helper function to parse a keyword into the specified attribute named by 2430 // `attrName`. The keyword must match one of the string values defined by the 2431 // AtomicBinOp enum. The resulting I64 attribute is added to the `result` 2432 // state. 2433 static ParseResult parseAtomicBinOp(OpAsmParser &parser, OperationState &result, 2434 StringRef attrName) { 2435 SMLoc loc; 2436 StringRef keyword; 2437 if (parser.getCurrentLocation(&loc) || parser.parseKeyword(&keyword)) 2438 return failure(); 2439 2440 // Replace the keyword `keyword` with an integer attribute. 2441 auto kind = symbolizeAtomicBinOp(keyword); 2442 if (!kind) { 2443 return parser.emitError(loc) 2444 << "'" << keyword << "' is an incorrect value of the '" << attrName 2445 << "' attribute"; 2446 } 2447 2448 auto value = static_cast<int64_t>(*kind); 2449 auto attr = parser.getBuilder().getI64IntegerAttr(value); 2450 result.addAttribute(attrName, attr); 2451 2452 return success(); 2453 } 2454 2455 // Helper function to parse a keyword into the specified attribute named by 2456 // `attrName`. The keyword must match one of the string values defined by the 2457 // AtomicOrdering enum. The resulting I64 attribute is added to the `result` 2458 // state. 2459 static ParseResult parseAtomicOrdering(OpAsmParser &parser, 2460 OperationState &result, 2461 StringRef attrName) { 2462 SMLoc loc; 2463 StringRef ordering; 2464 if (parser.getCurrentLocation(&loc) || parser.parseKeyword(&ordering)) 2465 return failure(); 2466 2467 // Replace the keyword `ordering` with an integer attribute. 
2468 auto kind = symbolizeAtomicOrdering(ordering); 2469 if (!kind) { 2470 return parser.emitError(loc) 2471 << "'" << ordering << "' is an incorrect value of the '" << attrName 2472 << "' attribute"; 2473 } 2474 2475 auto value = static_cast<int64_t>(*kind); 2476 auto attr = parser.getBuilder().getI64IntegerAttr(value); 2477 result.addAttribute(attrName, attr); 2478 2479 return success(); 2480 } 2481 2482 //===----------------------------------------------------------------------===// 2483 // Printer, parser and verifier for LLVM::AtomicRMWOp. 2484 //===----------------------------------------------------------------------===// 2485 2486 void AtomicRMWOp::print(OpAsmPrinter &p) { 2487 p << ' ' << stringifyAtomicBinOp(getBinOp()) << ' ' << getPtr() << ", " 2488 << getVal() << ' ' << stringifyAtomicOrdering(getOrdering()) << ' '; 2489 p.printOptionalAttrDict((*this)->getAttrs(), {"bin_op", "ordering"}); 2490 p << " : " << getRes().getType(); 2491 } 2492 2493 // <operation> ::= `llvm.atomicrmw` keyword ssa-use `,` ssa-use keyword 2494 // attribute-dict? 
`:` type 2495 ParseResult AtomicRMWOp::parse(OpAsmParser &parser, OperationState &result) { 2496 Type type; 2497 OpAsmParser::UnresolvedOperand ptr, val; 2498 if (parseAtomicBinOp(parser, result, "bin_op") || parser.parseOperand(ptr) || 2499 parser.parseComma() || parser.parseOperand(val) || 2500 parseAtomicOrdering(parser, result, "ordering") || 2501 parser.parseOptionalAttrDict(result.attributes) || 2502 parser.parseColonType(type) || 2503 parser.resolveOperand(ptr, LLVM::LLVMPointerType::get(type), 2504 result.operands) || 2505 parser.resolveOperand(val, type, result.operands)) 2506 return failure(); 2507 2508 result.addTypes(type); 2509 return success(); 2510 } 2511 2512 LogicalResult AtomicRMWOp::verify() { 2513 auto ptrType = getPtr().getType().cast<LLVM::LLVMPointerType>(); 2514 auto valType = getVal().getType(); 2515 if (valType != ptrType.getElementType()) 2516 return emitOpError("expected LLVM IR element type for operand #0 to " 2517 "match type for operand #1"); 2518 auto resType = getRes().getType(); 2519 if (resType != valType) 2520 return emitOpError( 2521 "expected LLVM IR result type to match type for operand #1"); 2522 if (getBinOp() == AtomicBinOp::fadd || getBinOp() == AtomicBinOp::fsub) { 2523 if (!mlir::LLVM::isCompatibleFloatingPointType(valType)) 2524 return emitOpError("expected LLVM IR floating point type"); 2525 } else if (getBinOp() == AtomicBinOp::xchg) { 2526 auto intType = valType.dyn_cast<IntegerType>(); 2527 unsigned intBitWidth = intType ? intType.getWidth() : 0; 2528 if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 && 2529 intBitWidth != 64 && !valType.isa<BFloat16Type>() && 2530 !valType.isa<Float16Type>() && !valType.isa<Float32Type>() && 2531 !valType.isa<Float64Type>()) 2532 return emitOpError("unexpected LLVM IR type for 'xchg' bin_op"); 2533 } else { 2534 auto intType = valType.dyn_cast<IntegerType>(); 2535 unsigned intBitWidth = intType ? 
intType.getWidth() : 0; 2536 if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 && 2537 intBitWidth != 64) 2538 return emitOpError("expected LLVM IR integer type"); 2539 } 2540 2541 if (static_cast<unsigned>(getOrdering()) < 2542 static_cast<unsigned>(AtomicOrdering::monotonic)) 2543 return emitOpError() << "expected at least '" 2544 << stringifyAtomicOrdering(AtomicOrdering::monotonic) 2545 << "' ordering"; 2546 2547 return success(); 2548 } 2549 2550 //===----------------------------------------------------------------------===// 2551 // Printer, parser and verifier for LLVM::AtomicCmpXchgOp. 2552 //===----------------------------------------------------------------------===// 2553 2554 void AtomicCmpXchgOp::print(OpAsmPrinter &p) { 2555 p << ' ' << getPtr() << ", " << getCmp() << ", " << getVal() << ' ' 2556 << stringifyAtomicOrdering(getSuccessOrdering()) << ' ' 2557 << stringifyAtomicOrdering(getFailureOrdering()); 2558 p.printOptionalAttrDict((*this)->getAttrs(), 2559 {"success_ordering", "failure_ordering"}); 2560 p << " : " << getVal().getType(); 2561 } 2562 2563 // <operation> ::= `llvm.cmpxchg` ssa-use `,` ssa-use `,` ssa-use 2564 // keyword keyword attribute-dict? 
`:` type 2565 ParseResult AtomicCmpXchgOp::parse(OpAsmParser &parser, 2566 OperationState &result) { 2567 auto &builder = parser.getBuilder(); 2568 Type type; 2569 OpAsmParser::UnresolvedOperand ptr, cmp, val; 2570 if (parser.parseOperand(ptr) || parser.parseComma() || 2571 parser.parseOperand(cmp) || parser.parseComma() || 2572 parser.parseOperand(val) || 2573 parseAtomicOrdering(parser, result, "success_ordering") || 2574 parseAtomicOrdering(parser, result, "failure_ordering") || 2575 parser.parseOptionalAttrDict(result.attributes) || 2576 parser.parseColonType(type) || 2577 parser.resolveOperand(ptr, LLVM::LLVMPointerType::get(type), 2578 result.operands) || 2579 parser.resolveOperand(cmp, type, result.operands) || 2580 parser.resolveOperand(val, type, result.operands)) 2581 return failure(); 2582 2583 auto boolType = IntegerType::get(builder.getContext(), 1); 2584 auto resultType = 2585 LLVMStructType::getLiteral(builder.getContext(), {type, boolType}); 2586 result.addTypes(resultType); 2587 2588 return success(); 2589 } 2590 2591 LogicalResult AtomicCmpXchgOp::verify() { 2592 auto ptrType = getPtr().getType().cast<LLVM::LLVMPointerType>(); 2593 if (!ptrType) 2594 return emitOpError("expected LLVM IR pointer type for operand #0"); 2595 auto cmpType = getCmp().getType(); 2596 auto valType = getVal().getType(); 2597 if (cmpType != ptrType.getElementType() || cmpType != valType) 2598 return emitOpError("expected LLVM IR element type for operand #0 to " 2599 "match type for all other operands"); 2600 auto intType = valType.dyn_cast<IntegerType>(); 2601 unsigned intBitWidth = intType ? 
intType.getWidth() : 0; 2602 if (!valType.isa<LLVMPointerType>() && intBitWidth != 8 && 2603 intBitWidth != 16 && intBitWidth != 32 && intBitWidth != 64 && 2604 !valType.isa<BFloat16Type>() && !valType.isa<Float16Type>() && 2605 !valType.isa<Float32Type>() && !valType.isa<Float64Type>()) 2606 return emitOpError("unexpected LLVM IR type"); 2607 if (getSuccessOrdering() < AtomicOrdering::monotonic || 2608 getFailureOrdering() < AtomicOrdering::monotonic) 2609 return emitOpError("ordering must be at least 'monotonic'"); 2610 if (getFailureOrdering() == AtomicOrdering::release || 2611 getFailureOrdering() == AtomicOrdering::acq_rel) 2612 return emitOpError("failure ordering cannot be 'release' or 'acq_rel'"); 2613 return success(); 2614 } 2615 2616 //===----------------------------------------------------------------------===// 2617 // Printer, parser and verifier for LLVM::FenceOp. 2618 //===----------------------------------------------------------------------===// 2619 2620 // <operation> ::= `llvm.fence` (`syncscope(`strAttr`)`)? keyword 2621 // attribute-dict? 
2622 ParseResult FenceOp::parse(OpAsmParser &parser, OperationState &result) { 2623 StringAttr sScope; 2624 StringRef syncscopeKeyword = "syncscope"; 2625 if (!failed(parser.parseOptionalKeyword(syncscopeKeyword))) { 2626 if (parser.parseLParen() || 2627 parser.parseAttribute(sScope, syncscopeKeyword, result.attributes) || 2628 parser.parseRParen()) 2629 return failure(); 2630 } else { 2631 result.addAttribute(syncscopeKeyword, 2632 parser.getBuilder().getStringAttr("")); 2633 } 2634 if (parseAtomicOrdering(parser, result, "ordering") || 2635 parser.parseOptionalAttrDict(result.attributes)) 2636 return failure(); 2637 return success(); 2638 } 2639 2640 void FenceOp::print(OpAsmPrinter &p) { 2641 StringRef syncscopeKeyword = "syncscope"; 2642 p << ' '; 2643 if (!(*this)->getAttr(syncscopeKeyword).cast<StringAttr>().getValue().empty()) 2644 p << "syncscope(" << (*this)->getAttr(syncscopeKeyword) << ") "; 2645 p << stringifyAtomicOrdering(getOrdering()); 2646 } 2647 2648 LogicalResult FenceOp::verify() { 2649 if (getOrdering() == AtomicOrdering::not_atomic || 2650 getOrdering() == AtomicOrdering::unordered || 2651 getOrdering() == AtomicOrdering::monotonic) 2652 return emitOpError("can be given only acquire, release, acq_rel, " 2653 "and seq_cst orderings"); 2654 return success(); 2655 } 2656 2657 //===----------------------------------------------------------------------===// 2658 // Folder for LLVM::BitcastOp 2659 //===----------------------------------------------------------------------===// 2660 2661 OpFoldResult LLVM::BitcastOp::fold(ArrayRef<Attribute> operands) { 2662 // bitcast(x : T0, T0) -> x 2663 if (getArg().getType() == getType()) 2664 return getArg(); 2665 // bitcast(bitcast(x : T0, T1), T0) -> x 2666 if (auto prev = getArg().getDefiningOp<BitcastOp>()) 2667 if (prev.getArg().getType() == getType()) 2668 return prev.getArg(); 2669 return {}; 2670 } 2671 2672 //===----------------------------------------------------------------------===// 2673 // Folder 
for LLVM::AddrSpaceCastOp 2674 //===----------------------------------------------------------------------===// 2675 2676 OpFoldResult LLVM::AddrSpaceCastOp::fold(ArrayRef<Attribute> operands) { 2677 // addrcast(x : T0, T0) -> x 2678 if (getArg().getType() == getType()) 2679 return getArg(); 2680 // addrcast(addrcast(x : T0, T1), T0) -> x 2681 if (auto prev = getArg().getDefiningOp<AddrSpaceCastOp>()) 2682 if (prev.getArg().getType() == getType()) 2683 return prev.getArg(); 2684 return {}; 2685 } 2686 2687 //===----------------------------------------------------------------------===// 2688 // Folder for LLVM::GEPOp 2689 //===----------------------------------------------------------------------===// 2690 2691 OpFoldResult LLVM::GEPOp::fold(ArrayRef<Attribute> operands) { 2692 // gep %x:T, 0 -> %x 2693 if (getBase().getType() == getType() && getIndices().size() == 1 && 2694 matchPattern(getIndices()[0], m_Zero())) 2695 return getBase(); 2696 return {}; 2697 } 2698 2699 //===----------------------------------------------------------------------===// 2700 // LLVMDialect initialization, type parsing, and registration. 2701 //===----------------------------------------------------------------------===// 2702 2703 void LLVMDialect::initialize() { 2704 addAttributes<FMFAttr, LinkageAttr, CConvAttr, LoopOptionsAttr>(); 2705 2706 // clang-format off 2707 addTypes<LLVMVoidType, 2708 LLVMPPCFP128Type, 2709 LLVMX86MMXType, 2710 LLVMTokenType, 2711 LLVMLabelType, 2712 LLVMMetadataType, 2713 LLVMFunctionType, 2714 LLVMPointerType, 2715 LLVMFixedVectorType, 2716 LLVMScalableVectorType, 2717 LLVMArrayType, 2718 LLVMStructType>(); 2719 // clang-format on 2720 addOperations< 2721 #define GET_OP_LIST 2722 #include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc" 2723 , 2724 #define GET_OP_LIST 2725 #include "mlir/Dialect/LLVMIR/LLVMIntrinsicOps.cpp.inc" 2726 >(); 2727 2728 // Support unknown operations because not all LLVM operations are registered. 
2729 allowUnknownOperations(); 2730 } 2731 2732 #define GET_OP_CLASSES 2733 #include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc" 2734 2735 /// Parse a type registered to this dialect. 2736 Type LLVMDialect::parseType(DialectAsmParser &parser) const { 2737 return detail::parseType(parser); 2738 } 2739 2740 /// Print a type registered to this dialect. 2741 void LLVMDialect::printType(Type type, DialectAsmPrinter &os) const { 2742 return detail::printType(type, os); 2743 } 2744 2745 LogicalResult LLVMDialect::verifyDataLayoutString( 2746 StringRef descr, llvm::function_ref<void(const Twine &)> reportError) { 2747 llvm::Expected<llvm::DataLayout> maybeDataLayout = 2748 llvm::DataLayout::parse(descr); 2749 if (maybeDataLayout) 2750 return success(); 2751 2752 std::string message; 2753 llvm::raw_string_ostream messageStream(message); 2754 llvm::logAllUnhandledErrors(maybeDataLayout.takeError(), messageStream); 2755 reportError("invalid data layout descriptor: " + messageStream.str()); 2756 return failure(); 2757 } 2758 2759 /// Verify LLVM dialect attributes. 2760 LogicalResult LLVMDialect::verifyOperationAttribute(Operation *op, 2761 NamedAttribute attr) { 2762 // If the `llvm.loop` attribute is present, enforce the following structure, 2763 // which the module translation can assume. 
2764 if (attr.getName() == LLVMDialect::getLoopAttrName()) { 2765 auto loopAttr = attr.getValue().dyn_cast<DictionaryAttr>(); 2766 if (!loopAttr) 2767 return op->emitOpError() << "expected '" << LLVMDialect::getLoopAttrName() 2768 << "' to be a dictionary attribute"; 2769 Optional<NamedAttribute> parallelAccessGroup = 2770 loopAttr.getNamed(LLVMDialect::getParallelAccessAttrName()); 2771 if (parallelAccessGroup) { 2772 auto accessGroups = parallelAccessGroup->getValue().dyn_cast<ArrayAttr>(); 2773 if (!accessGroups) 2774 return op->emitOpError() 2775 << "expected '" << LLVMDialect::getParallelAccessAttrName() 2776 << "' to be an array attribute"; 2777 for (Attribute attr : accessGroups) { 2778 auto accessGroupRef = attr.dyn_cast<SymbolRefAttr>(); 2779 if (!accessGroupRef) 2780 return op->emitOpError() 2781 << "expected '" << attr << "' to be a symbol reference"; 2782 StringAttr metadataName = accessGroupRef.getRootReference(); 2783 auto metadataOp = 2784 SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>( 2785 op->getParentOp(), metadataName); 2786 if (!metadataOp) 2787 return op->emitOpError() 2788 << "expected '" << attr << "' to reference a metadata op"; 2789 StringAttr accessGroupName = accessGroupRef.getLeafReference(); 2790 Operation *accessGroupOp = 2791 SymbolTable::lookupNearestSymbolFrom(metadataOp, accessGroupName); 2792 if (!accessGroupOp) 2793 return op->emitOpError() 2794 << "expected '" << attr << "' to reference an access_group op"; 2795 } 2796 } 2797 2798 Optional<NamedAttribute> loopOptions = 2799 loopAttr.getNamed(LLVMDialect::getLoopOptionsAttrName()); 2800 if (loopOptions && !loopOptions->getValue().isa<LoopOptionsAttr>()) 2801 return op->emitOpError() 2802 << "expected '" << LLVMDialect::getLoopOptionsAttrName() 2803 << "' to be a `loopopts` attribute"; 2804 } 2805 2806 if (attr.getName() == LLVMDialect::getStructAttrsAttrName()) { 2807 return op->emitOpError() 2808 << "'" << LLVM::LLVMDialect::getStructAttrsAttrName() 2809 << "' is 
permitted only in argument or result attributes"; 2810 } 2811 2812 // If the data layout attribute is present, it must use the LLVM data layout 2813 // syntax. Try parsing it and report errors in case of failure. Users of this 2814 // attribute may assume it is well-formed and can pass it to the (asserting) 2815 // llvm::DataLayout constructor. 2816 if (attr.getName() != LLVM::LLVMDialect::getDataLayoutAttrName()) 2817 return success(); 2818 if (auto stringAttr = attr.getValue().dyn_cast<StringAttr>()) 2819 return verifyDataLayoutString( 2820 stringAttr.getValue(), 2821 [op](const Twine &message) { op->emitOpError() << message.str(); }); 2822 2823 return op->emitOpError() << "expected '" 2824 << LLVM::LLVMDialect::getDataLayoutAttrName() 2825 << "' to be a string attributes"; 2826 } 2827 2828 LogicalResult LLVMDialect::verifyStructAttr(Operation *op, Attribute attr, 2829 Type annotatedType) { 2830 auto structType = annotatedType.dyn_cast<LLVMStructType>(); 2831 if (!structType) { 2832 const auto emitIncorrectAnnotatedType = [&op]() { 2833 return op->emitError() 2834 << "expected '" << LLVMDialect::getStructAttrsAttrName() 2835 << "' to annotate '!llvm.struct' or '!llvm.ptr<struct<...>>'"; 2836 }; 2837 const auto ptrType = annotatedType.dyn_cast<LLVMPointerType>(); 2838 if (!ptrType) 2839 return emitIncorrectAnnotatedType(); 2840 structType = ptrType.getElementType().dyn_cast<LLVMStructType>(); 2841 if (!structType) 2842 return emitIncorrectAnnotatedType(); 2843 } 2844 2845 const auto arrAttrs = attr.dyn_cast<ArrayAttr>(); 2846 if (!arrAttrs) 2847 return op->emitError() << "expected '" 2848 << LLVMDialect::getStructAttrsAttrName() 2849 << "' to be an array attribute"; 2850 2851 if (structType.getBody().size() != arrAttrs.size()) 2852 return op->emitError() 2853 << "size of '" << LLVMDialect::getStructAttrsAttrName() 2854 << "' must match the size of the annotated '!llvm.struct'"; 2855 return success(); 2856 } 2857 2858 static LogicalResult 
verifyFuncOpInterfaceStructAttr( 2859 Operation *op, Attribute attr, 2860 const std::function<Type(FunctionOpInterface)> &getAnnotatedType) { 2861 if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) 2862 return LLVMDialect::verifyStructAttr(op, attr, getAnnotatedType(funcOp)); 2863 return op->emitError() << "expected '" 2864 << LLVMDialect::getStructAttrsAttrName() 2865 << "' to be used on function-like operations"; 2866 } 2867 2868 /// Verify LLVMIR function argument attributes. 2869 LogicalResult LLVMDialect::verifyRegionArgAttribute(Operation *op, 2870 unsigned regionIdx, 2871 unsigned argIdx, 2872 NamedAttribute argAttr) { 2873 // Check that llvm.noalias is a unit attribute. 2874 if (argAttr.getName() == LLVMDialect::getNoAliasAttrName() && 2875 !argAttr.getValue().isa<UnitAttr>()) 2876 return op->emitError() 2877 << "expected llvm.noalias argument attribute to be a unit attribute"; 2878 // Check that llvm.align is an integer attribute. 2879 if (argAttr.getName() == LLVMDialect::getAlignAttrName() && 2880 !argAttr.getValue().isa<IntegerAttr>()) 2881 return op->emitError() 2882 << "llvm.align argument attribute of non integer type"; 2883 if (argAttr.getName() == LLVMDialect::getStructAttrsAttrName()) { 2884 return verifyFuncOpInterfaceStructAttr( 2885 op, argAttr.getValue(), [argIdx](FunctionOpInterface funcOp) { 2886 return funcOp.getArgumentTypes()[argIdx]; 2887 }); 2888 } 2889 return success(); 2890 } 2891 2892 LogicalResult LLVMDialect::verifyRegionResultAttribute(Operation *op, 2893 unsigned regionIdx, 2894 unsigned resIdx, 2895 NamedAttribute resAttr) { 2896 if (resAttr.getName() == LLVMDialect::getStructAttrsAttrName()) { 2897 return verifyFuncOpInterfaceStructAttr( 2898 op, resAttr.getValue(), [resIdx](FunctionOpInterface funcOp) { 2899 return funcOp.getResultTypes()[resIdx]; 2900 }); 2901 } 2902 return success(); 2903 } 2904 2905 //===----------------------------------------------------------------------===// 2906 // Utility functions. 
2907 //===----------------------------------------------------------------------===// 2908 2909 Value mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder, 2910 StringRef name, StringRef value, 2911 LLVM::Linkage linkage) { 2912 assert(builder.getInsertionBlock() && 2913 builder.getInsertionBlock()->getParentOp() && 2914 "expected builder to point to a block constrained in an op"); 2915 auto module = 2916 builder.getInsertionBlock()->getParentOp()->getParentOfType<ModuleOp>(); 2917 assert(module && "builder points to an op outside of a module"); 2918 2919 // Create the global at the entry of the module. 2920 OpBuilder moduleBuilder(module.getBodyRegion(), builder.getListener()); 2921 MLIRContext *ctx = builder.getContext(); 2922 auto type = LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), value.size()); 2923 auto global = moduleBuilder.create<LLVM::GlobalOp>( 2924 loc, type, /*isConstant=*/true, linkage, name, 2925 builder.getStringAttr(value), /*alignment=*/0); 2926 2927 // Get the pointer to the first character in the global string. 
2928 Value globalPtr = builder.create<LLVM::AddressOfOp>(loc, global); 2929 Value cst0 = builder.create<LLVM::ConstantOp>( 2930 loc, IntegerType::get(ctx, 64), 2931 builder.getIntegerAttr(builder.getIndexType(), 0)); 2932 return builder.create<LLVM::GEPOp>( 2933 loc, LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)), globalPtr, 2934 ValueRange{cst0, cst0}); 2935 } 2936 2937 bool mlir::LLVM::satisfiesLLVMModule(Operation *op) { 2938 return op->hasTrait<OpTrait::SymbolTable>() && 2939 op->hasTrait<OpTrait::IsIsolatedFromAbove>(); 2940 } 2941 2942 void FMFAttr::print(AsmPrinter &printer) const { 2943 printer << "<"; 2944 printer << stringifyFastmathFlags(this->getFlags()); 2945 printer << ">"; 2946 } 2947 2948 Attribute FMFAttr::parse(AsmParser &parser, Type type) { 2949 if (failed(parser.parseLess())) 2950 return {}; 2951 2952 FastmathFlags flags = {}; 2953 if (failed(parser.parseOptionalGreater())) { 2954 auto parseFlags = [&]() -> ParseResult { 2955 StringRef elemName; 2956 if (failed(parser.parseKeyword(&elemName))) 2957 return failure(); 2958 2959 auto elem = symbolizeFastmathFlags(elemName); 2960 if (!elem) 2961 return parser.emitError(parser.getNameLoc(), "Unknown fastmath flag: ") 2962 << elemName; 2963 2964 flags = flags | *elem; 2965 return success(); 2966 }; 2967 if (failed(parser.parseCommaSeparatedList(parseFlags)) || 2968 parser.parseGreater()) 2969 return {}; 2970 } 2971 2972 return FMFAttr::get(parser.getContext(), flags); 2973 } 2974 2975 void LinkageAttr::print(AsmPrinter &printer) const { 2976 printer << "<"; 2977 if (static_cast<uint64_t>(getLinkage()) <= getMaxEnumValForLinkage()) 2978 printer << stringifyEnum(getLinkage()); 2979 else 2980 printer << static_cast<uint64_t>(getLinkage()); 2981 printer << ">"; 2982 } 2983 2984 Attribute LinkageAttr::parse(AsmParser &parser, Type type) { 2985 StringRef elemName; 2986 if (parser.parseLess() || parser.parseKeyword(&elemName) || 2987 parser.parseGreater()) 2988 return {}; 2989 auto elem = 
linkage::symbolizeLinkage(elemName); 2990 if (!elem) { 2991 parser.emitError(parser.getNameLoc(), "Unknown linkage: ") << elemName; 2992 return {}; 2993 } 2994 Linkage linkage = *elem; 2995 return LinkageAttr::get(parser.getContext(), linkage); 2996 } 2997 2998 void CConvAttr::print(AsmPrinter &printer) const { 2999 printer << "<"; 3000 if (static_cast<uint64_t>(getCallingConv()) <= cconv::getMaxEnumValForCConv()) 3001 printer << stringifyEnum(getCallingConv()); 3002 else 3003 printer << "INVALID_cc_" << static_cast<uint64_t>(getCallingConv()); 3004 printer << ">"; 3005 } 3006 3007 Attribute CConvAttr::parse(AsmParser &parser, Type type) { 3008 StringRef convName; 3009 3010 if (parser.parseLess() || parser.parseKeyword(&convName) || 3011 parser.parseGreater()) 3012 return {}; 3013 auto cconv = cconv::symbolizeCConv(convName); 3014 if (!cconv) { 3015 parser.emitError(parser.getNameLoc(), "unknown calling convention: ") 3016 << convName; 3017 return {}; 3018 } 3019 CConv cconvVal = *cconv; 3020 return CConvAttr::get(parser.getContext(), cconvVal); 3021 } 3022 3023 LoopOptionsAttrBuilder::LoopOptionsAttrBuilder(LoopOptionsAttr attr) 3024 : options(attr.getOptions().begin(), attr.getOptions().end()) {} 3025 3026 template <typename T> 3027 LoopOptionsAttrBuilder &LoopOptionsAttrBuilder::setOption(LoopOptionCase tag, 3028 Optional<T> value) { 3029 auto option = llvm::find_if( 3030 options, [tag](auto option) { return option.first == tag; }); 3031 if (option != options.end()) { 3032 if (value) 3033 option->second = *value; 3034 else 3035 options.erase(option); 3036 } else { 3037 options.push_back(LoopOptionsAttr::OptionValuePair(tag, *value)); 3038 } 3039 return *this; 3040 } 3041 3042 LoopOptionsAttrBuilder & 3043 LoopOptionsAttrBuilder::setDisableLICM(Optional<bool> value) { 3044 return setOption(LoopOptionCase::disable_licm, value); 3045 } 3046 3047 /// Set the `interleave_count` option to the provided value. If no value 3048 /// is provided the option is deleted. 
3049 LoopOptionsAttrBuilder & 3050 LoopOptionsAttrBuilder::setInterleaveCount(Optional<uint64_t> count) { 3051 return setOption(LoopOptionCase::interleave_count, count); 3052 } 3053 3054 /// Set the `disable_unroll` option to the provided value. If no value 3055 /// is provided the option is deleted. 3056 LoopOptionsAttrBuilder & 3057 LoopOptionsAttrBuilder::setDisableUnroll(Optional<bool> value) { 3058 return setOption(LoopOptionCase::disable_unroll, value); 3059 } 3060 3061 /// Set the `disable_pipeline` option to the provided value. If no value 3062 /// is provided the option is deleted. 3063 LoopOptionsAttrBuilder & 3064 LoopOptionsAttrBuilder::setDisablePipeline(Optional<bool> value) { 3065 return setOption(LoopOptionCase::disable_pipeline, value); 3066 } 3067 3068 /// Set the `pipeline_initiation_interval` option to the provided value. 3069 /// If no value is provided the option is deleted. 3070 LoopOptionsAttrBuilder &LoopOptionsAttrBuilder::setPipelineInitiationInterval( 3071 Optional<uint64_t> count) { 3072 return setOption(LoopOptionCase::pipeline_initiation_interval, count); 3073 } 3074 3075 template <typename T> 3076 static Optional<T> 3077 getOption(ArrayRef<std::pair<LoopOptionCase, int64_t>> options, 3078 LoopOptionCase option) { 3079 auto it = 3080 lower_bound(options, option, [](auto optionPair, LoopOptionCase option) { 3081 return optionPair.first < option; 3082 }); 3083 if (it == options.end()) 3084 return {}; 3085 return static_cast<T>(it->second); 3086 } 3087 3088 Optional<bool> LoopOptionsAttr::disableUnroll() { 3089 return getOption<bool>(getOptions(), LoopOptionCase::disable_unroll); 3090 } 3091 3092 Optional<bool> LoopOptionsAttr::disableLICM() { 3093 return getOption<bool>(getOptions(), LoopOptionCase::disable_licm); 3094 } 3095 3096 Optional<int64_t> LoopOptionsAttr::interleaveCount() { 3097 return getOption<int64_t>(getOptions(), LoopOptionCase::interleave_count); 3098 } 3099 3100 /// Build the LoopOptions Attribute from a sorted array 
of individual options. 3101 LoopOptionsAttr LoopOptionsAttr::get( 3102 MLIRContext *context, 3103 ArrayRef<std::pair<LoopOptionCase, int64_t>> sortedOptions) { 3104 assert(llvm::is_sorted(sortedOptions, llvm::less_first()) && 3105 "LoopOptionsAttr ctor expects a sorted options array"); 3106 return Base::get(context, sortedOptions); 3107 } 3108 3109 /// Build the LoopOptions Attribute from a sorted array of individual options. 3110 LoopOptionsAttr LoopOptionsAttr::get(MLIRContext *context, 3111 LoopOptionsAttrBuilder &optionBuilders) { 3112 llvm::sort(optionBuilders.options, llvm::less_first()); 3113 return Base::get(context, optionBuilders.options); 3114 } 3115 3116 void LoopOptionsAttr::print(AsmPrinter &printer) const { 3117 printer << "<"; 3118 llvm::interleaveComma(getOptions(), printer, [&](auto option) { 3119 printer << stringifyEnum(option.first) << " = "; 3120 switch (option.first) { 3121 case LoopOptionCase::disable_licm: 3122 case LoopOptionCase::disable_unroll: 3123 case LoopOptionCase::disable_pipeline: 3124 printer << (option.second ? 
"true" : "false"); 3125 break; 3126 case LoopOptionCase::interleave_count: 3127 case LoopOptionCase::pipeline_initiation_interval: 3128 printer << option.second; 3129 break; 3130 } 3131 }); 3132 printer << ">"; 3133 } 3134 3135 Attribute LoopOptionsAttr::parse(AsmParser &parser, Type type) { 3136 if (failed(parser.parseLess())) 3137 return {}; 3138 3139 SmallVector<std::pair<LoopOptionCase, int64_t>> options; 3140 llvm::SmallDenseSet<LoopOptionCase> seenOptions; 3141 auto parseLoopOptions = [&]() -> ParseResult { 3142 StringRef optionName; 3143 if (parser.parseKeyword(&optionName)) 3144 return failure(); 3145 3146 auto option = symbolizeLoopOptionCase(optionName); 3147 if (!option) 3148 return parser.emitError(parser.getNameLoc(), "unknown loop option: ") 3149 << optionName; 3150 if (!seenOptions.insert(*option).second) 3151 return parser.emitError(parser.getNameLoc(), "loop option present twice"); 3152 if (failed(parser.parseEqual())) 3153 return failure(); 3154 3155 int64_t value; 3156 switch (*option) { 3157 case LoopOptionCase::disable_licm: 3158 case LoopOptionCase::disable_unroll: 3159 case LoopOptionCase::disable_pipeline: 3160 if (succeeded(parser.parseOptionalKeyword("true"))) 3161 value = 1; 3162 else if (succeeded(parser.parseOptionalKeyword("false"))) 3163 value = 0; 3164 else { 3165 return parser.emitError(parser.getNameLoc(), 3166 "expected boolean value 'true' or 'false'"); 3167 } 3168 break; 3169 case LoopOptionCase::interleave_count: 3170 case LoopOptionCase::pipeline_initiation_interval: 3171 if (failed(parser.parseInteger(value))) 3172 return parser.emitError(parser.getNameLoc(), "expected integer value"); 3173 break; 3174 } 3175 options.push_back(std::make_pair(*option, value)); 3176 return success(); 3177 }; 3178 if (parser.parseCommaSeparatedList(parseLoopOptions) || parser.parseGreater()) 3179 return {}; 3180 3181 llvm::sort(options, llvm::less_first()); 3182 return get(parser.getContext(), options); 3183 } 3184