//===- LLVMDialect.cpp - LLVM IR Ops and Dialect registration -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the types and operation details for the LLVM IR dialect in
// MLIR, and the LLVM IR dialect. It also registers the dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "TypeDetail.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"

#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/Bitcode/BitcodeReader.h"
#include "llvm/Bitcode/BitcodeWriter.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/Mutex.h"
#include "llvm/Support/SourceMgr.h"

#include <numeric>

using namespace mlir;
using namespace mlir::LLVM;
using mlir::LLVM::cconv::getMaxEnumValForCConv;
using mlir::LLVM::linkage::getMaxEnumValForLinkage;

#include "mlir/Dialect/LLVMIR/LLVMOpsDialect.cpp.inc"

// Attribute names that the hand-written builders, printers and parsers in this
// file read or write directly.

// Unit attribute attached by the load/store builders for volatile accesses.
// The load/store printers elide it from the attribute dictionary and print a
// `volatile` keyword instead.
static constexpr const char kVolatileAttrName[] = "volatile_";
// Unit attribute attached by the load/store builders for non-temporal
// accesses.
static constexpr const char kNonTemporalAttrName[] = "nontemporal";
// Type attribute recording the element type when an opaque pointer is used
// (set by the alloca parser and the GEP builder, checked by verifyOpaquePtr).
static constexpr const char kElemTypeAttrName[] = "elem_type";

#include "mlir/Dialect/LLVMIR/LLVMOpsEnums.cpp.inc"
#include "mlir/Dialect/LLVMIR/LLVMOpsInterfaces.cpp.inc"
#define GET_ATTRDEF_CLASSES 52 #include "mlir/Dialect/LLVMIR/LLVMOpsAttrDefs.cpp.inc" 53 54 static auto processFMFAttr(ArrayRef<NamedAttribute> attrs) { 55 SmallVector<NamedAttribute, 8> filteredAttrs( 56 llvm::make_filter_range(attrs, [&](NamedAttribute attr) { 57 if (attr.getName() == "fastmathFlags") { 58 auto defAttr = FMFAttr::get(attr.getValue().getContext(), {}); 59 return defAttr != attr.getValue(); 60 } 61 return true; 62 })); 63 return filteredAttrs; 64 } 65 66 static ParseResult parseLLVMOpAttrs(OpAsmParser &parser, 67 NamedAttrList &result) { 68 return parser.parseOptionalAttrDict(result); 69 } 70 71 static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op, 72 DictionaryAttr attrs) { 73 printer.printOptionalAttrDict(processFMFAttr(attrs.getValue())); 74 } 75 76 /// Verifies `symbol`'s use in `op` to ensure the symbol is a valid and 77 /// fully defined llvm.func. 78 static LogicalResult verifySymbolAttrUse(FlatSymbolRefAttr symbol, 79 Operation *op, 80 SymbolTableCollection &symbolTable) { 81 StringRef name = symbol.getValue(); 82 auto func = 83 symbolTable.lookupNearestSymbolFrom<LLVMFuncOp>(op, symbol.getAttr()); 84 if (!func) 85 return op->emitOpError("'") 86 << name << "' does not reference a valid LLVM function"; 87 if (func.isExternal()) 88 return op->emitOpError("'") << name << "' does not have a definition"; 89 return success(); 90 } 91 92 //===----------------------------------------------------------------------===// 93 // Printing, parsing and builder for LLVM::CmpOp. 
94 //===----------------------------------------------------------------------===// 95 96 void ICmpOp::build(OpBuilder &builder, OperationState &result, 97 ICmpPredicate predicate, Value lhs, Value rhs) { 98 auto boolType = IntegerType::get(lhs.getType().getContext(), 1); 99 if (LLVM::isCompatibleVectorType(lhs.getType()) || 100 LLVM::isCompatibleVectorType(rhs.getType())) { 101 int64_t numLHSElements = 1, numRHSElements = 1; 102 if (LLVM::isCompatibleVectorType(lhs.getType())) 103 numLHSElements = 104 LLVM::getVectorNumElements(lhs.getType()).getFixedValue(); 105 if (LLVM::isCompatibleVectorType(rhs.getType())) 106 numRHSElements = 107 LLVM::getVectorNumElements(rhs.getType()).getFixedValue(); 108 build(builder, result, 109 VectorType::get({std::max(numLHSElements, numRHSElements)}, boolType), 110 predicate, lhs, rhs); 111 } else { 112 build(builder, result, boolType, predicate, lhs, rhs); 113 } 114 } 115 116 void ICmpOp::print(OpAsmPrinter &p) { 117 p << " \"" << stringifyICmpPredicate(getPredicate()) << "\" " << getOperand(0) 118 << ", " << getOperand(1); 119 p.printOptionalAttrDict((*this)->getAttrs(), {"predicate"}); 120 p << " : " << getLhs().getType(); 121 } 122 123 void FCmpOp::print(OpAsmPrinter &p) { 124 p << " \"" << stringifyFCmpPredicate(getPredicate()) << "\" " << getOperand(0) 125 << ", " << getOperand(1); 126 p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()), {"predicate"}); 127 p << " : " << getLhs().getType(); 128 } 129 130 // <operation> ::= `llvm.icmp` string-literal ssa-use `,` ssa-use 131 // attribute-dict? `:` type 132 // <operation> ::= `llvm.fcmp` string-literal ssa-use `,` ssa-use 133 // attribute-dict? 
`:` type 134 template <typename CmpPredicateType> 135 static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) { 136 Builder &builder = parser.getBuilder(); 137 138 StringAttr predicateAttr; 139 OpAsmParser::UnresolvedOperand lhs, rhs; 140 Type type; 141 SMLoc predicateLoc, trailingTypeLoc; 142 if (parser.getCurrentLocation(&predicateLoc) || 143 parser.parseAttribute(predicateAttr, "predicate", result.attributes) || 144 parser.parseOperand(lhs) || parser.parseComma() || 145 parser.parseOperand(rhs) || 146 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 147 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) || 148 parser.resolveOperand(lhs, type, result.operands) || 149 parser.resolveOperand(rhs, type, result.operands)) 150 return failure(); 151 152 // Replace the string attribute `predicate` with an integer attribute. 153 int64_t predicateValue = 0; 154 if (std::is_same<CmpPredicateType, ICmpPredicate>()) { 155 Optional<ICmpPredicate> predicate = 156 symbolizeICmpPredicate(predicateAttr.getValue()); 157 if (!predicate) 158 return parser.emitError(predicateLoc) 159 << "'" << predicateAttr.getValue() 160 << "' is an incorrect value of the 'predicate' attribute"; 161 predicateValue = static_cast<int64_t>(predicate.getValue()); 162 } else { 163 Optional<FCmpPredicate> predicate = 164 symbolizeFCmpPredicate(predicateAttr.getValue()); 165 if (!predicate) 166 return parser.emitError(predicateLoc) 167 << "'" << predicateAttr.getValue() 168 << "' is an incorrect value of the 'predicate' attribute"; 169 predicateValue = static_cast<int64_t>(predicate.getValue()); 170 } 171 172 result.attributes.set("predicate", 173 parser.getBuilder().getI64IntegerAttr(predicateValue)); 174 175 // The result type is either i1 or a vector type <? x i1> if the inputs are 176 // vectors. 
177 Type resultType = IntegerType::get(builder.getContext(), 1); 178 if (!isCompatibleType(type)) 179 return parser.emitError(trailingTypeLoc, 180 "expected LLVM dialect-compatible type"); 181 if (LLVM::isCompatibleVectorType(type)) { 182 if (LLVM::isScalableVectorType(type)) { 183 resultType = LLVM::getVectorType( 184 resultType, LLVM::getVectorNumElements(type).getKnownMinValue(), 185 /*isScalable=*/true); 186 } else { 187 resultType = LLVM::getVectorType( 188 resultType, LLVM::getVectorNumElements(type).getFixedValue(), 189 /*isScalable=*/false); 190 } 191 } 192 193 result.addTypes({resultType}); 194 return success(); 195 } 196 197 ParseResult ICmpOp::parse(OpAsmParser &parser, OperationState &result) { 198 return parseCmpOp<ICmpPredicate>(parser, result); 199 } 200 201 ParseResult FCmpOp::parse(OpAsmParser &parser, OperationState &result) { 202 return parseCmpOp<FCmpPredicate>(parser, result); 203 } 204 205 //===----------------------------------------------------------------------===// 206 // Printing, parsing and verification for LLVM::AllocaOp. 207 //===----------------------------------------------------------------------===// 208 209 void AllocaOp::print(OpAsmPrinter &p) { 210 Type elemTy = getType().cast<LLVM::LLVMPointerType>().getElementType(); 211 if (!elemTy) 212 elemTy = *getElemType(); 213 214 auto funcTy = 215 FunctionType::get(getContext(), {getArraySize().getType()}, {getType()}); 216 217 p << ' ' << getArraySize() << " x " << elemTy; 218 if (getAlignment().hasValue() && *getAlignment() != 0) 219 p.printOptionalAttrDict((*this)->getAttrs(), {kElemTypeAttrName}); 220 else 221 p.printOptionalAttrDict((*this)->getAttrs(), 222 {"alignment", kElemTypeAttrName}); 223 p << " : " << funcTy; 224 } 225 226 // <operation> ::= `llvm.alloca` ssa-use `x` type attribute-dict? 
227 // `:` type `,` type 228 ParseResult AllocaOp::parse(OpAsmParser &parser, OperationState &result) { 229 OpAsmParser::UnresolvedOperand arraySize; 230 Type type, elemType; 231 SMLoc trailingTypeLoc; 232 if (parser.parseOperand(arraySize) || parser.parseKeyword("x") || 233 parser.parseType(elemType) || 234 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 235 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type)) 236 return failure(); 237 238 Optional<NamedAttribute> alignmentAttr = 239 result.attributes.getNamed("alignment"); 240 if (alignmentAttr.hasValue()) { 241 auto alignmentInt = 242 alignmentAttr.getValue().getValue().dyn_cast<IntegerAttr>(); 243 if (!alignmentInt) 244 return parser.emitError(parser.getNameLoc(), 245 "expected integer alignment"); 246 if (alignmentInt.getValue().isNullValue()) 247 result.attributes.erase("alignment"); 248 } 249 250 // Extract the result type from the trailing function type. 251 auto funcType = type.dyn_cast<FunctionType>(); 252 if (!funcType || funcType.getNumInputs() != 1 || 253 funcType.getNumResults() != 1) 254 return parser.emitError( 255 trailingTypeLoc, 256 "expected trailing function type with one argument and one result"); 257 258 if (parser.resolveOperand(arraySize, funcType.getInput(0), result.operands)) 259 return failure(); 260 261 Type resultType = funcType.getResult(0); 262 if (auto ptrResultType = resultType.dyn_cast<LLVMPointerType>()) { 263 if (ptrResultType.isOpaque()) 264 result.addAttribute(kElemTypeAttrName, TypeAttr::get(elemType)); 265 } 266 267 result.addTypes({funcType.getResult(0)}); 268 return success(); 269 } 270 271 /// Checks that the elemental type is present in either the pointer type or 272 /// the attribute, but not both. 
273 static LogicalResult verifyOpaquePtr(Operation *op, LLVMPointerType ptrType, 274 Optional<Type> ptrElementType) { 275 if (ptrType.isOpaque() && !ptrElementType.hasValue()) { 276 return op->emitOpError() << "expected '" << kElemTypeAttrName 277 << "' attribute if opaque pointer type is used"; 278 } 279 if (!ptrType.isOpaque() && ptrElementType.hasValue()) { 280 return op->emitOpError() 281 << "unexpected '" << kElemTypeAttrName 282 << "' attribute when non-opaque pointer type is used"; 283 } 284 return success(); 285 } 286 287 LogicalResult AllocaOp::verify() { 288 return verifyOpaquePtr(getOperation(), getType().cast<LLVMPointerType>(), 289 getElemType()); 290 } 291 292 //===----------------------------------------------------------------------===// 293 // LLVM::BrOp 294 //===----------------------------------------------------------------------===// 295 296 SuccessorOperands BrOp::getSuccessorOperands(unsigned index) { 297 assert(index == 0 && "invalid successor index"); 298 return SuccessorOperands(getDestOperandsMutable()); 299 } 300 301 //===----------------------------------------------------------------------===// 302 // LLVM::CondBrOp 303 //===----------------------------------------------------------------------===// 304 305 SuccessorOperands CondBrOp::getSuccessorOperands(unsigned index) { 306 assert(index < getNumSuccessors() && "invalid successor index"); 307 return SuccessorOperands(index == 0 ? 
getTrueDestOperandsMutable() 308 : getFalseDestOperandsMutable()); 309 } 310 311 //===----------------------------------------------------------------------===// 312 // LLVM::SwitchOp 313 //===----------------------------------------------------------------------===// 314 315 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value, 316 Block *defaultDestination, ValueRange defaultOperands, 317 ArrayRef<int32_t> caseValues, BlockRange caseDestinations, 318 ArrayRef<ValueRange> caseOperands, 319 ArrayRef<int32_t> branchWeights) { 320 ElementsAttr caseValuesAttr; 321 if (!caseValues.empty()) 322 caseValuesAttr = builder.getI32VectorAttr(caseValues); 323 324 ElementsAttr weightsAttr; 325 if (!branchWeights.empty()) 326 weightsAttr = builder.getI32VectorAttr(llvm::to_vector<4>(branchWeights)); 327 328 build(builder, result, value, defaultOperands, caseOperands, caseValuesAttr, 329 weightsAttr, defaultDestination, caseDestinations); 330 } 331 332 /// <cases> ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)? 333 /// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )? 
334 static ParseResult parseSwitchOpCases( 335 OpAsmParser &parser, Type flagType, ElementsAttr &caseValues, 336 SmallVectorImpl<Block *> &caseDestinations, 337 SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> &caseOperands, 338 SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) { 339 SmallVector<APInt> values; 340 unsigned bitWidth = flagType.getIntOrFloatBitWidth(); 341 do { 342 int64_t value = 0; 343 OptionalParseResult integerParseResult = parser.parseOptionalInteger(value); 344 if (values.empty() && !integerParseResult.hasValue()) 345 return success(); 346 347 if (!integerParseResult.hasValue() || integerParseResult.getValue()) 348 return failure(); 349 values.push_back(APInt(bitWidth, value)); 350 351 Block *destination; 352 SmallVector<OpAsmParser::UnresolvedOperand> operands; 353 SmallVector<Type> operandTypes; 354 if (parser.parseColon() || parser.parseSuccessor(destination)) 355 return failure(); 356 if (!parser.parseOptionalLParen()) { 357 if (parser.parseOperandList(operands, OpAsmParser::Delimiter::None, 358 /*allowResultNumber=*/false) || 359 parser.parseColonTypeList(operandTypes) || parser.parseRParen()) 360 return failure(); 361 } 362 caseDestinations.push_back(destination); 363 caseOperands.emplace_back(operands); 364 caseOperandTypes.emplace_back(operandTypes); 365 } while (!parser.parseOptionalComma()); 366 367 ShapedType caseValueType = 368 VectorType::get(static_cast<int64_t>(values.size()), flagType); 369 caseValues = DenseIntElementsAttr::get(caseValueType, values); 370 return success(); 371 } 372 373 static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type flagType, 374 ElementsAttr caseValues, 375 SuccessorRange caseDestinations, 376 OperandRangeRange caseOperands, 377 const TypeRangeRange &caseOperandTypes) { 378 if (!caseValues) 379 return; 380 381 size_t index = 0; 382 llvm::interleave( 383 llvm::zip(caseValues.cast<DenseIntElementsAttr>(), caseDestinations), 384 [&](auto i) { 385 p << " "; 386 p << 
std::get<0>(i).getLimitedValue(); 387 p << ": "; 388 p.printSuccessorAndUseList(std::get<1>(i), caseOperands[index++]); 389 }, 390 [&] { 391 p << ','; 392 p.printNewline(); 393 }); 394 p.printNewline(); 395 } 396 397 LogicalResult SwitchOp::verify() { 398 if ((!getCaseValues() && !getCaseDestinations().empty()) || 399 (getCaseValues() && 400 getCaseValues()->size() != 401 static_cast<int64_t>(getCaseDestinations().size()))) 402 return emitOpError("expects number of case values to match number of " 403 "case destinations"); 404 if (getBranchWeights() && getBranchWeights()->size() != getNumSuccessors()) 405 return emitError("expects number of branch weights to match number of " 406 "successors: ") 407 << getBranchWeights()->size() << " vs " << getNumSuccessors(); 408 return success(); 409 } 410 411 SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) { 412 assert(index < getNumSuccessors() && "invalid successor index"); 413 return SuccessorOperands(index == 0 ? getDefaultOperandsMutable() 414 : getCaseOperandsMutable(index - 1)); 415 } 416 417 //===----------------------------------------------------------------------===// 418 // Code for LLVM::GEPOp. 419 //===----------------------------------------------------------------------===// 420 421 constexpr int GEPOp::kDynamicIndex; 422 423 namespace { 424 /// Base class for llvm::Error related to GEP index. 425 class GEPIndexError : public llvm::ErrorInfo<GEPIndexError> { 426 protected: 427 unsigned indexPos; 428 429 public: 430 static char ID; 431 432 std::error_code convertToErrorCode() const override { 433 return llvm::inconvertibleErrorCode(); 434 } 435 436 explicit GEPIndexError(unsigned pos) : indexPos(pos) {} 437 }; 438 439 /// llvm::Error for out-of-bound GEP index. 
440 struct GEPIndexOutOfBoundError 441 : public llvm::ErrorInfo<GEPIndexOutOfBoundError, GEPIndexError> { 442 static char ID; 443 444 using ErrorInfo::ErrorInfo; 445 446 void log(llvm::raw_ostream &os) const override { 447 os << "index " << indexPos << " indexing a struct is out of bounds"; 448 } 449 }; 450 451 /// llvm::Error for non-static GEP index indexing a struct. 452 struct GEPStaticIndexError 453 : public llvm::ErrorInfo<GEPStaticIndexError, GEPIndexError> { 454 static char ID; 455 456 using ErrorInfo::ErrorInfo; 457 458 void log(llvm::raw_ostream &os) const override { 459 os << "expected index " << indexPos << " indexing a struct " 460 << "to be constant"; 461 } 462 }; 463 } // end anonymous namespace 464 465 char GEPIndexError::ID = 0; 466 char GEPIndexOutOfBoundError::ID = 0; 467 char GEPStaticIndexError::ID = 0; 468 469 /// For the given `structIndices` and `indices`, check if they're complied 470 /// with `baseGEPType`, especially check against LLVMStructTypes nested within, 471 /// and refine/promote struct index from `indices` to `updatedStructIndices` 472 /// if the latter argument is not null. 473 static llvm::Error 474 recordStructIndices(Type baseGEPType, unsigned indexPos, 475 ArrayRef<int32_t> structIndices, ValueRange indices, 476 SmallVectorImpl<int32_t> *updatedStructIndices, 477 SmallVectorImpl<Value> *remainingIndices) { 478 if (indexPos >= structIndices.size()) 479 // Stop searching 480 return llvm::Error::success(); 481 482 int32_t gepIndex = structIndices[indexPos]; 483 bool isStaticIndex = gepIndex != GEPOp::kDynamicIndex; 484 485 unsigned dynamicIndexPos = indexPos; 486 if (!isStaticIndex) 487 dynamicIndexPos = llvm::count(structIndices.take_front(indexPos + 1), 488 LLVM::GEPOp::kDynamicIndex) - 1; 489 490 return llvm::TypeSwitch<Type, llvm::Error>(baseGEPType) 491 .Case<LLVMStructType>([&](LLVMStructType structType) -> llvm::Error { 492 // We don't always want to refine the index (e.g. 
when performing 493 // verification), so we only refine when updatedStructIndices is not 494 // null. 495 if (!isStaticIndex && updatedStructIndices) { 496 // Try to refine. 497 APInt staticIndexValue; 498 isStaticIndex = matchPattern(indices[dynamicIndexPos], 499 m_ConstantInt(&staticIndexValue)); 500 if (isStaticIndex) { 501 assert(staticIndexValue.getBitWidth() <= 64 && 502 llvm::isInt<32>(staticIndexValue.getLimitedValue()) && 503 "struct index can't fit within int32_t"); 504 gepIndex = static_cast<int32_t>(staticIndexValue.getSExtValue()); 505 } 506 } 507 if (!isStaticIndex) 508 return llvm::make_error<GEPStaticIndexError>(indexPos); 509 510 ArrayRef<Type> elementTypes = structType.getBody(); 511 if (gepIndex < 0 || 512 static_cast<size_t>(gepIndex) >= elementTypes.size()) 513 return llvm::make_error<GEPIndexOutOfBoundError>(indexPos); 514 515 if (updatedStructIndices) 516 (*updatedStructIndices)[indexPos] = gepIndex; 517 518 // Instead of recusively going into every children types, we only 519 // dive into the one indexed by gepIndex. 520 return recordStructIndices(elementTypes[gepIndex], indexPos + 1, 521 structIndices, indices, updatedStructIndices, 522 remainingIndices); 523 }) 524 .Case<VectorType, LLVMScalableVectorType, LLVMFixedVectorType, 525 LLVMArrayType>([&](auto containerType) -> llvm::Error { 526 // Currently we don't refine non-struct index even if it's static. 527 if (remainingIndices) 528 remainingIndices->push_back(indices[dynamicIndexPos]); 529 return recordStructIndices(containerType.getElementType(), indexPos + 1, 530 structIndices, indices, updatedStructIndices, 531 remainingIndices); 532 }) 533 .Default( 534 [](auto otherType) -> llvm::Error { return llvm::Error::success(); }); 535 } 536 537 /// Driver function around `recordStructIndices`. Note that we always check 538 /// from the second GEP index since the first one is always dynamic. 
539 static llvm::Error 540 findStructIndices(Type baseGEPType, ArrayRef<int32_t> structIndices, 541 ValueRange indices, 542 SmallVectorImpl<int32_t> *updatedStructIndices = nullptr, 543 SmallVectorImpl<Value> *remainingIndices = nullptr) { 544 if (remainingIndices) 545 // The first GEP index is always dynamic. 546 remainingIndices->push_back(indices[0]); 547 return recordStructIndices(baseGEPType, /*indexPos=*/1, structIndices, 548 indices, updatedStructIndices, remainingIndices); 549 } 550 551 void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType, 552 Value basePtr, ValueRange operands, 553 ArrayRef<NamedAttribute> attributes) { 554 build(builder, result, resultType, basePtr, operands, 555 SmallVector<int32_t>(operands.size(), kDynamicIndex), attributes); 556 } 557 558 /// Returns the elemental type of any LLVM-compatible vector type or self. 559 static Type extractVectorElementType(Type type) { 560 if (auto vectorType = type.dyn_cast<VectorType>()) 561 return vectorType.getElementType(); 562 if (auto scalableVectorType = type.dyn_cast<LLVMScalableVectorType>()) 563 return scalableVectorType.getElementType(); 564 if (auto fixedVectorType = type.dyn_cast<LLVMFixedVectorType>()) 565 return fixedVectorType.getElementType(); 566 return type; 567 } 568 569 void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType, 570 Value basePtr, ValueRange indices, 571 ArrayRef<int32_t> structIndices, 572 ArrayRef<NamedAttribute> attributes) { 573 auto ptrType = 574 extractVectorElementType(basePtr.getType()).cast<LLVMPointerType>(); 575 assert(!ptrType.isOpaque() && 576 "expected non-opaque pointer, provide elementType explicitly when " 577 "opaque pointers are used"); 578 build(builder, result, resultType, ptrType.getElementType(), basePtr, indices, 579 structIndices, attributes); 580 } 581 582 void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType, 583 Type elementType, Value basePtr, ValueRange indices, 584 
ArrayRef<int32_t> structIndices, 585 ArrayRef<NamedAttribute> attributes) { 586 SmallVector<Value> remainingIndices; 587 SmallVector<int32_t> updatedStructIndices(structIndices.begin(), 588 structIndices.end()); 589 if (llvm::Error err = 590 findStructIndices(elementType, structIndices, indices, 591 &updatedStructIndices, &remainingIndices)) 592 llvm::report_fatal_error(StringRef(llvm::toString(std::move(err)))); 593 594 assert(remainingIndices.size() == static_cast<size_t>(llvm::count( 595 updatedStructIndices, kDynamicIndex)) && 596 "expected as many index operands as dynamic index attr elements"); 597 598 result.addTypes(resultType); 599 result.addAttributes(attributes); 600 result.addAttribute("structIndices", 601 builder.getI32TensorAttr(updatedStructIndices)); 602 if (extractVectorElementType(basePtr.getType()) 603 .cast<LLVMPointerType>() 604 .isOpaque()) 605 result.addAttribute(kElemTypeAttrName, TypeAttr::get(elementType)); 606 result.addOperands(basePtr); 607 result.addOperands(remainingIndices); 608 } 609 610 static ParseResult 611 parseGEPIndices(OpAsmParser &parser, 612 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &indices, 613 DenseIntElementsAttr &structIndices) { 614 SmallVector<int32_t> constantIndices; 615 616 auto idxParser = [&]() -> ParseResult { 617 int32_t constantIndex; 618 OptionalParseResult parsedInteger = 619 parser.parseOptionalInteger(constantIndex); 620 if (parsedInteger.hasValue()) { 621 if (failed(parsedInteger.getValue())) 622 return failure(); 623 constantIndices.push_back(constantIndex); 624 return success(); 625 } 626 627 constantIndices.push_back(LLVM::GEPOp::kDynamicIndex); 628 return parser.parseOperand(indices.emplace_back()); 629 }; 630 if (parser.parseCommaSeparatedList(idxParser)) 631 return failure(); 632 633 structIndices = parser.getBuilder().getI32TensorAttr(constantIndices); 634 return success(); 635 } 636 637 static void printGEPIndices(OpAsmPrinter &printer, LLVM::GEPOp gepOp, 638 OperandRange indices, 639 
DenseIntElementsAttr structIndices) { 640 unsigned operandIdx = 0; 641 llvm::interleaveComma(structIndices.getValues<int32_t>(), printer, 642 [&](int32_t cst) { 643 if (cst == LLVM::GEPOp::kDynamicIndex) 644 printer.printOperand(indices[operandIdx++]); 645 else 646 printer << cst; 647 }); 648 } 649 650 LogicalResult LLVM::GEPOp::verify() { 651 if (failed(verifyOpaquePtr( 652 getOperation(), 653 extractVectorElementType(getType()).cast<LLVMPointerType>(), 654 getElemType()))) 655 return failure(); 656 657 auto structIndexRange = getStructIndices().getValues<int32_t>(); 658 // structIndexRange is a kind of iterator, which cannot be converted 659 // to ArrayRef directly. 660 SmallVector<int32_t> structIndices(structIndexRange.size()); 661 for (unsigned i : llvm::seq<unsigned>(0, structIndexRange.size())) 662 structIndices[i] = structIndexRange[i]; 663 if (llvm::Error err = findStructIndices(getSourceElementType(), structIndices, 664 getIndices())) 665 return emitOpError() << llvm::toString(std::move(err)); 666 667 return success(); 668 } 669 670 Type LLVM::GEPOp::getSourceElementType() { 671 if (Optional<Type> elemType = getElemType()) 672 return *elemType; 673 674 return extractVectorElementType(getBase().getType()) 675 .cast<LLVMPointerType>() 676 .getElementType(); 677 } 678 679 //===----------------------------------------------------------------------===// 680 // Builder, printer and parser for for LLVM::LoadOp. 681 //===----------------------------------------------------------------------===// 682 683 LogicalResult verifySymbolAttribute( 684 Operation *op, StringRef attributeName, 685 llvm::function_ref<LogicalResult(Operation *, SymbolRefAttr)> 686 verifySymbolType) { 687 if (Attribute attribute = op->getAttr(attributeName)) { 688 // The attribute is already verified to be a symbol ref array attribute via 689 // a constraint in the operation definition. 
690 for (SymbolRefAttr symbolRef : 691 attribute.cast<ArrayAttr>().getAsRange<SymbolRefAttr>()) { 692 StringAttr metadataName = symbolRef.getRootReference(); 693 StringAttr symbolName = symbolRef.getLeafReference(); 694 // We want @metadata::@symbol, not just @symbol 695 if (metadataName == symbolName) { 696 return op->emitOpError() << "expected '" << symbolRef 697 << "' to specify a fully qualified reference"; 698 } 699 auto metadataOp = SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>( 700 op->getParentOp(), metadataName); 701 if (!metadataOp) 702 return op->emitOpError() 703 << "expected '" << symbolRef << "' to reference a metadata op"; 704 Operation *symbolOp = 705 SymbolTable::lookupNearestSymbolFrom(metadataOp, symbolName); 706 if (!symbolOp) 707 return op->emitOpError() 708 << "expected '" << symbolRef << "' to be a valid reference"; 709 if (failed(verifySymbolType(symbolOp, symbolRef))) { 710 return failure(); 711 } 712 } 713 } 714 return success(); 715 } 716 717 // Verifies that metadata ops are wired up properly. 
718 template <typename OpTy> 719 static LogicalResult verifyOpMetadata(Operation *op, StringRef attributeName) { 720 auto verifySymbolType = [op](Operation *symbolOp, 721 SymbolRefAttr symbolRef) -> LogicalResult { 722 if (!isa<OpTy>(symbolOp)) { 723 return op->emitOpError() 724 << "expected '" << symbolRef << "' to resolve to a " 725 << OpTy::getOperationName(); 726 } 727 return success(); 728 }; 729 730 return verifySymbolAttribute(op, attributeName, verifySymbolType); 731 } 732 733 static LogicalResult verifyMemoryOpMetadata(Operation *op) { 734 // access_groups 735 if (failed(verifyOpMetadata<LLVM::AccessGroupMetadataOp>( 736 op, LLVMDialect::getAccessGroupsAttrName()))) 737 return failure(); 738 739 // alias_scopes 740 if (failed(verifyOpMetadata<LLVM::AliasScopeMetadataOp>( 741 op, LLVMDialect::getAliasScopesAttrName()))) 742 return failure(); 743 744 // noalias_scopes 745 if (failed(verifyOpMetadata<LLVM::AliasScopeMetadataOp>( 746 op, LLVMDialect::getNoAliasScopesAttrName()))) 747 return failure(); 748 749 return success(); 750 } 751 752 LogicalResult LoadOp::verify() { return verifyMemoryOpMetadata(*this); } 753 754 void LoadOp::build(OpBuilder &builder, OperationState &result, Type t, 755 Value addr, unsigned alignment, bool isVolatile, 756 bool isNonTemporal) { 757 result.addOperands(addr); 758 result.addTypes(t); 759 if (isVolatile) 760 result.addAttribute(kVolatileAttrName, builder.getUnitAttr()); 761 if (isNonTemporal) 762 result.addAttribute(kNonTemporalAttrName, builder.getUnitAttr()); 763 if (alignment != 0) 764 result.addAttribute("alignment", builder.getI64IntegerAttr(alignment)); 765 } 766 767 void LoadOp::print(OpAsmPrinter &p) { 768 p << ' '; 769 if (getVolatile_()) 770 p << "volatile "; 771 p << getAddr(); 772 p.printOptionalAttrDict((*this)->getAttrs(), 773 {kVolatileAttrName, kElemTypeAttrName}); 774 p << " : " << getAddr().getType(); 775 if (getAddr().getType().cast<LLVMPointerType>().isOpaque()) 776 p << " -> " << getType(); 777 } 778 779 
// Extract the pointee type from the LLVM pointer type wrapped in MLIR. Return 780 // the resulting type if any, null type if opaque pointers are used, and None 781 // if the given type is not the pointer type. 782 static Optional<Type> getLoadStoreElementType(OpAsmParser &parser, Type type, 783 SMLoc trailingTypeLoc) { 784 auto llvmTy = type.dyn_cast<LLVM::LLVMPointerType>(); 785 if (!llvmTy) { 786 parser.emitError(trailingTypeLoc, "expected LLVM pointer type"); 787 return llvm::None; 788 } 789 return llvmTy.getElementType(); 790 } 791 792 // <operation> ::= `llvm.load` `volatile` ssa-use attribute-dict? `:` type 793 // (`->` type)? 794 ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) { 795 OpAsmParser::UnresolvedOperand addr; 796 Type type; 797 SMLoc trailingTypeLoc; 798 799 if (succeeded(parser.parseOptionalKeyword("volatile"))) 800 result.addAttribute(kVolatileAttrName, parser.getBuilder().getUnitAttr()); 801 802 if (parser.parseOperand(addr) || 803 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 804 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) || 805 parser.resolveOperand(addr, type, result.operands)) 806 return failure(); 807 808 Optional<Type> elemTy = 809 getLoadStoreElementType(parser, type, trailingTypeLoc); 810 if (!elemTy) 811 return failure(); 812 if (*elemTy) { 813 result.addTypes(*elemTy); 814 return success(); 815 } 816 817 Type trailingType; 818 if (parser.parseArrow() || parser.parseType(trailingType)) 819 return failure(); 820 result.addTypes(trailingType); 821 return success(); 822 } 823 824 //===----------------------------------------------------------------------===// 825 // Builder, printer and parser for LLVM::StoreOp. 
826 //===----------------------------------------------------------------------===// 827 828 LogicalResult StoreOp::verify() { return verifyMemoryOpMetadata(*this); } 829 830 void StoreOp::build(OpBuilder &builder, OperationState &result, Value value, 831 Value addr, unsigned alignment, bool isVolatile, 832 bool isNonTemporal) { 833 result.addOperands({value, addr}); 834 result.addTypes({}); 835 if (isVolatile) 836 result.addAttribute(kVolatileAttrName, builder.getUnitAttr()); 837 if (isNonTemporal) 838 result.addAttribute(kNonTemporalAttrName, builder.getUnitAttr()); 839 if (alignment != 0) 840 result.addAttribute("alignment", builder.getI64IntegerAttr(alignment)); 841 } 842 843 void StoreOp::print(OpAsmPrinter &p) { 844 p << ' '; 845 if (getVolatile_()) 846 p << "volatile "; 847 p << getValue() << ", " << getAddr(); 848 p.printOptionalAttrDict((*this)->getAttrs(), {kVolatileAttrName}); 849 p << " : "; 850 if (getAddr().getType().cast<LLVMPointerType>().isOpaque()) 851 p << getValue().getType() << ", "; 852 p << getAddr().getType(); 853 } 854 855 // <operation> ::= `llvm.store` `volatile` ssa-use `,` ssa-use 856 // attribute-dict? `:` type (`,` type)? 
857 ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) { 858 OpAsmParser::UnresolvedOperand addr, value; 859 Type type; 860 SMLoc trailingTypeLoc; 861 862 if (succeeded(parser.parseOptionalKeyword("volatile"))) 863 result.addAttribute(kVolatileAttrName, parser.getBuilder().getUnitAttr()); 864 865 if (parser.parseOperand(value) || parser.parseComma() || 866 parser.parseOperand(addr) || 867 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 868 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type)) 869 return failure(); 870 871 Type operandType; 872 if (succeeded(parser.parseOptionalComma())) { 873 operandType = type; 874 if (parser.parseType(type)) 875 return failure(); 876 } else { 877 Optional<Type> maybeOperandType = 878 getLoadStoreElementType(parser, type, trailingTypeLoc); 879 if (!maybeOperandType) 880 return failure(); 881 operandType = *maybeOperandType; 882 } 883 884 if (parser.resolveOperand(value, operandType, result.operands) || 885 parser.resolveOperand(addr, type, result.operands)) 886 return failure(); 887 888 return success(); 889 } 890 891 ///===---------------------------------------------------------------------===// 892 /// LLVM::InvokeOp 893 ///===---------------------------------------------------------------------===// 894 895 SuccessorOperands InvokeOp::getSuccessorOperands(unsigned index) { 896 assert(index < getNumSuccessors() && "invalid successor index"); 897 return SuccessorOperands(index == 0 ? 
getNormalDestOperandsMutable() 898 : getUnwindDestOperandsMutable()); 899 } 900 901 LogicalResult InvokeOp::verify() { 902 if (getNumResults() > 1) 903 return emitOpError("must have 0 or 1 result"); 904 905 Block *unwindDest = getUnwindDest(); 906 if (unwindDest->empty()) 907 return emitError("must have at least one operation in unwind destination"); 908 909 // In unwind destination, first operation must be LandingpadOp 910 if (!isa<LandingpadOp>(unwindDest->front())) 911 return emitError("first operation in unwind destination should be a " 912 "llvm.landingpad operation"); 913 914 return success(); 915 } 916 917 void InvokeOp::print(OpAsmPrinter &p) { 918 auto callee = getCallee(); 919 bool isDirect = callee.hasValue(); 920 921 p << ' '; 922 923 // Either function name or pointer 924 if (isDirect) 925 p.printSymbolName(callee.getValue()); 926 else 927 p << getOperand(0); 928 929 p << '(' << getOperands().drop_front(isDirect ? 0 : 1) << ')'; 930 p << " to "; 931 p.printSuccessorAndUseList(getNormalDest(), getNormalDestOperands()); 932 p << " unwind "; 933 p.printSuccessorAndUseList(getUnwindDest(), getUnwindDestOperands()); 934 935 p.printOptionalAttrDict((*this)->getAttrs(), 936 {InvokeOp::getOperandSegmentSizeAttr(), "callee"}); 937 p << " : "; 938 p.printFunctionalType(llvm::drop_begin(getOperandTypes(), isDirect ? 0 : 1), 939 getResultTypes()); 940 } 941 942 /// <operation> ::= `llvm.invoke` (function-id | ssa-use) `(` ssa-use-list `)` 943 /// `to` bb-id (`[` ssa-use-and-type-list `]`)? 944 /// `unwind` bb-id (`[` ssa-use-and-type-list `]`)? 945 /// attribute-dict? 
`:` function-type 946 ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) { 947 SmallVector<OpAsmParser::UnresolvedOperand, 8> operands; 948 FunctionType funcType; 949 SymbolRefAttr funcAttr; 950 SMLoc trailingTypeLoc; 951 Block *normalDest, *unwindDest; 952 SmallVector<Value, 4> normalOperands, unwindOperands; 953 Builder &builder = parser.getBuilder(); 954 955 // Parse an operand list that will, in practice, contain 0 or 1 operand. In 956 // case of an indirect call, there will be 1 operand before `(`. In case of a 957 // direct call, there will be no operands and the parser will stop at the 958 // function identifier without complaining. 959 if (parser.parseOperandList(operands)) 960 return failure(); 961 bool isDirect = operands.empty(); 962 963 // Optionally parse a function identifier. 964 if (isDirect && parser.parseAttribute(funcAttr, "callee", result.attributes)) 965 return failure(); 966 967 if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || 968 parser.parseKeyword("to") || 969 parser.parseSuccessorAndUseList(normalDest, normalOperands) || 970 parser.parseKeyword("unwind") || 971 parser.parseSuccessorAndUseList(unwindDest, unwindOperands) || 972 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 973 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(funcType)) 974 return failure(); 975 976 if (isDirect) { 977 // Make sure types match. 978 if (parser.resolveOperands(operands, funcType.getInputs(), 979 parser.getNameLoc(), result.operands)) 980 return failure(); 981 result.addTypes(funcType.getResults()); 982 } else { 983 // Construct the LLVM IR Dialect function type that the first operand 984 // should match. 
985 if (funcType.getNumResults() > 1) 986 return parser.emitError(trailingTypeLoc, 987 "expected function with 0 or 1 result"); 988 989 Type llvmResultType; 990 if (funcType.getNumResults() == 0) { 991 llvmResultType = LLVM::LLVMVoidType::get(builder.getContext()); 992 } else { 993 llvmResultType = funcType.getResult(0); 994 if (!isCompatibleType(llvmResultType)) 995 return parser.emitError(trailingTypeLoc, 996 "expected result to have LLVM type"); 997 } 998 999 SmallVector<Type, 8> argTypes; 1000 argTypes.reserve(funcType.getNumInputs()); 1001 for (Type ty : funcType.getInputs()) { 1002 if (isCompatibleType(ty)) 1003 argTypes.push_back(ty); 1004 else 1005 return parser.emitError(trailingTypeLoc, 1006 "expected LLVM types as inputs"); 1007 } 1008 1009 auto llvmFuncType = LLVM::LLVMFunctionType::get(llvmResultType, argTypes); 1010 auto wrappedFuncType = LLVM::LLVMPointerType::get(llvmFuncType); 1011 1012 auto funcArguments = llvm::makeArrayRef(operands).drop_front(); 1013 1014 // Make sure that the first operand (indirect callee) matches the wrapped 1015 // LLVM IR function type, and that the types of the other call operands 1016 // match the types of the function arguments. 
1017 if (parser.resolveOperand(operands[0], wrappedFuncType, result.operands) || 1018 parser.resolveOperands(funcArguments, funcType.getInputs(), 1019 parser.getNameLoc(), result.operands)) 1020 return failure(); 1021 1022 result.addTypes(llvmResultType); 1023 } 1024 result.addSuccessors({normalDest, unwindDest}); 1025 result.addOperands(normalOperands); 1026 result.addOperands(unwindOperands); 1027 1028 result.addAttribute( 1029 InvokeOp::getOperandSegmentSizeAttr(), 1030 builder.getI32VectorAttr({static_cast<int32_t>(operands.size()), 1031 static_cast<int32_t>(normalOperands.size()), 1032 static_cast<int32_t>(unwindOperands.size())})); 1033 return success(); 1034 } 1035 1036 ///===----------------------------------------------------------------------===// 1037 /// Verifying/Printing/Parsing for LLVM::LandingpadOp. 1038 ///===----------------------------------------------------------------------===// 1039 1040 LogicalResult LandingpadOp::verify() { 1041 Value value; 1042 if (LLVMFuncOp func = (*this)->getParentOfType<LLVMFuncOp>()) { 1043 if (!func.getPersonality().hasValue()) 1044 return emitError( 1045 "llvm.landingpad needs to be in a function with a personality"); 1046 } 1047 1048 if (!getCleanup() && getOperands().empty()) 1049 return emitError("landingpad instruction expects at least one clause or " 1050 "cleanup attribute"); 1051 1052 for (unsigned idx = 0, ie = getNumOperands(); idx < ie; idx++) { 1053 value = getOperand(idx); 1054 bool isFilter = value.getType().isa<LLVMArrayType>(); 1055 if (isFilter) { 1056 // FIXME: Verify filter clauses when arrays are appropriately handled 1057 } else { 1058 // catch - global addresses only. 1059 // Bitcast ops should have global addresses as their args. 
1060 if (auto bcOp = value.getDefiningOp<BitcastOp>()) { 1061 if (auto addrOp = bcOp.getArg().getDefiningOp<AddressOfOp>()) 1062 continue; 1063 return emitError("constant clauses expected").attachNote(bcOp.getLoc()) 1064 << "global addresses expected as operand to " 1065 "bitcast used in clauses for landingpad"; 1066 } 1067 // NullOp and AddressOfOp allowed 1068 if (value.getDefiningOp<NullOp>()) 1069 continue; 1070 if (value.getDefiningOp<AddressOfOp>()) 1071 continue; 1072 return emitError("clause #") 1073 << idx << " is not a known constant - null, addressof, bitcast"; 1074 } 1075 } 1076 return success(); 1077 } 1078 1079 void LandingpadOp::print(OpAsmPrinter &p) { 1080 p << (getCleanup() ? " cleanup " : " "); 1081 1082 // Clauses 1083 for (auto value : getOperands()) { 1084 // Similar to llvm - if clause is an array type then it is filter 1085 // clause else catch clause 1086 bool isArrayTy = value.getType().isa<LLVMArrayType>(); 1087 p << '(' << (isArrayTy ? "filter " : "catch ") << value << " : " 1088 << value.getType() << ") "; 1089 } 1090 1091 p.printOptionalAttrDict((*this)->getAttrs(), {"cleanup"}); 1092 1093 p << ": " << getType(); 1094 } 1095 1096 /// <operation> ::= `llvm.landingpad` `cleanup`? 1097 /// ((`catch` | `filter`) operand-type ssa-use)* attribute-dict? 
1098 ParseResult LandingpadOp::parse(OpAsmParser &parser, OperationState &result) { 1099 // Check for cleanup 1100 if (succeeded(parser.parseOptionalKeyword("cleanup"))) 1101 result.addAttribute("cleanup", parser.getBuilder().getUnitAttr()); 1102 1103 // Parse clauses with types 1104 while (succeeded(parser.parseOptionalLParen()) && 1105 (succeeded(parser.parseOptionalKeyword("filter")) || 1106 succeeded(parser.parseOptionalKeyword("catch")))) { 1107 OpAsmParser::UnresolvedOperand operand; 1108 Type ty; 1109 if (parser.parseOperand(operand) || parser.parseColon() || 1110 parser.parseType(ty) || 1111 parser.resolveOperand(operand, ty, result.operands) || 1112 parser.parseRParen()) 1113 return failure(); 1114 } 1115 1116 Type type; 1117 if (parser.parseColon() || parser.parseType(type)) 1118 return failure(); 1119 1120 result.addTypes(type); 1121 return success(); 1122 } 1123 1124 //===----------------------------------------------------------------------===// 1125 // Verifying/Printing/parsing for LLVM::CallOp. 1126 //===----------------------------------------------------------------------===// 1127 1128 LogicalResult CallOp::verify() { 1129 if (getNumResults() > 1) 1130 return emitOpError("must have 0 or 1 result"); 1131 1132 // Type for the callee, we'll get it differently depending if it is a direct 1133 // or indirect call. 1134 Type fnType; 1135 1136 bool isIndirect = false; 1137 1138 // If this is an indirect call, the callee attribute is missing. 
1139 FlatSymbolRefAttr calleeName = getCalleeAttr(); 1140 if (!calleeName) { 1141 isIndirect = true; 1142 if (!getNumOperands()) 1143 return emitOpError( 1144 "must have either a `callee` attribute or at least an operand"); 1145 auto ptrType = getOperand(0).getType().dyn_cast<LLVMPointerType>(); 1146 if (!ptrType) 1147 return emitOpError("indirect call expects a pointer as callee: ") 1148 << ptrType; 1149 fnType = ptrType.getElementType(); 1150 } else { 1151 Operation *callee = 1152 SymbolTable::lookupNearestSymbolFrom(*this, calleeName.getAttr()); 1153 if (!callee) 1154 return emitOpError() 1155 << "'" << calleeName.getValue() 1156 << "' does not reference a symbol in the current scope"; 1157 auto fn = dyn_cast<LLVMFuncOp>(callee); 1158 if (!fn) 1159 return emitOpError() << "'" << calleeName.getValue() 1160 << "' does not reference a valid LLVM function"; 1161 1162 fnType = fn.getFunctionType(); 1163 } 1164 1165 LLVMFunctionType funcType = fnType.dyn_cast<LLVMFunctionType>(); 1166 if (!funcType) 1167 return emitOpError("callee does not have a functional type: ") << fnType; 1168 1169 // Verify that the operand and result types match the callee. 
1170 1171 if (!funcType.isVarArg() && 1172 funcType.getNumParams() != (getNumOperands() - isIndirect)) 1173 return emitOpError() << "incorrect number of operands (" 1174 << (getNumOperands() - isIndirect) 1175 << ") for callee (expecting: " 1176 << funcType.getNumParams() << ")"; 1177 1178 if (funcType.getNumParams() > (getNumOperands() - isIndirect)) 1179 return emitOpError() << "incorrect number of operands (" 1180 << (getNumOperands() - isIndirect) 1181 << ") for varargs callee (expecting at least: " 1182 << funcType.getNumParams() << ")"; 1183 1184 for (unsigned i = 0, e = funcType.getNumParams(); i != e; ++i) 1185 if (getOperand(i + isIndirect).getType() != funcType.getParamType(i)) 1186 return emitOpError() << "operand type mismatch for operand " << i << ": " 1187 << getOperand(i + isIndirect).getType() 1188 << " != " << funcType.getParamType(i); 1189 1190 if (getNumResults() == 0 && 1191 !funcType.getReturnType().isa<LLVM::LLVMVoidType>()) 1192 return emitOpError() << "expected function call to produce a value"; 1193 1194 if (getNumResults() != 0 && 1195 funcType.getReturnType().isa<LLVM::LLVMVoidType>()) 1196 return emitOpError() 1197 << "calling function with void result must not produce values"; 1198 1199 if (getNumResults() > 1) 1200 return emitOpError() 1201 << "expected LLVM function call to produce 0 or 1 result"; 1202 1203 if (getNumResults() && getResult(0).getType() != funcType.getReturnType()) 1204 return emitOpError() << "result type mismatch: " << getResult(0).getType() 1205 << " != " << funcType.getReturnType(); 1206 1207 return success(); 1208 } 1209 1210 void CallOp::print(OpAsmPrinter &p) { 1211 auto callee = getCallee(); 1212 bool isDirect = callee.hasValue(); 1213 1214 // Print the direct callee if present as a function attribute, or an indirect 1215 // callee (first operand) otherwise. 
1216 p << ' '; 1217 if (isDirect) 1218 p.printSymbolName(callee.getValue()); 1219 else 1220 p << getOperand(0); 1221 1222 auto args = getOperands().drop_front(isDirect ? 0 : 1); 1223 p << '(' << args << ')'; 1224 p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()), {"callee"}); 1225 1226 // Reconstruct the function MLIR function type from operand and result types. 1227 p << " : "; 1228 p.printFunctionalType(args.getTypes(), getResultTypes()); 1229 } 1230 1231 // <operation> ::= `llvm.call` (function-id | ssa-use) `(` ssa-use-list `)` 1232 // attribute-dict? `:` function-type 1233 ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) { 1234 SmallVector<OpAsmParser::UnresolvedOperand, 8> operands; 1235 Type type; 1236 SymbolRefAttr funcAttr; 1237 SMLoc trailingTypeLoc; 1238 1239 // Parse an operand list that will, in practice, contain 0 or 1 operand. In 1240 // case of an indirect call, there will be 1 operand before `(`. In case of a 1241 // direct call, there will be no operands and the parser will stop at the 1242 // function identifier without complaining. 1243 if (parser.parseOperandList(operands)) 1244 return failure(); 1245 bool isDirect = operands.empty(); 1246 1247 // Optionally parse a function identifier. 1248 if (isDirect) 1249 if (parser.parseAttribute(funcAttr, "callee", result.attributes)) 1250 return failure(); 1251 1252 if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || 1253 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 1254 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type)) 1255 return failure(); 1256 1257 auto funcType = type.dyn_cast<FunctionType>(); 1258 if (!funcType) 1259 return parser.emitError(trailingTypeLoc, "expected function type"); 1260 if (funcType.getNumResults() > 1) 1261 return parser.emitError(trailingTypeLoc, 1262 "expected function with 0 or 1 result"); 1263 if (isDirect) { 1264 // Make sure types match. 
1265 if (parser.resolveOperands(operands, funcType.getInputs(), 1266 parser.getNameLoc(), result.operands)) 1267 return failure(); 1268 if (funcType.getNumResults() != 0 && 1269 !funcType.getResult(0).isa<LLVM::LLVMVoidType>()) 1270 result.addTypes(funcType.getResults()); 1271 } else { 1272 Builder &builder = parser.getBuilder(); 1273 Type llvmResultType; 1274 if (funcType.getNumResults() == 0) { 1275 llvmResultType = LLVM::LLVMVoidType::get(builder.getContext()); 1276 } else { 1277 llvmResultType = funcType.getResult(0); 1278 if (!isCompatibleType(llvmResultType)) 1279 return parser.emitError(trailingTypeLoc, 1280 "expected result to have LLVM type"); 1281 } 1282 1283 SmallVector<Type, 8> argTypes; 1284 argTypes.reserve(funcType.getNumInputs()); 1285 for (int i = 0, e = funcType.getNumInputs(); i < e; ++i) { 1286 auto argType = funcType.getInput(i); 1287 if (!isCompatibleType(argType)) 1288 return parser.emitError(trailingTypeLoc, 1289 "expected LLVM types as inputs"); 1290 argTypes.push_back(argType); 1291 } 1292 auto llvmFuncType = LLVM::LLVMFunctionType::get(llvmResultType, argTypes); 1293 auto wrappedFuncType = LLVM::LLVMPointerType::get(llvmFuncType); 1294 1295 auto funcArguments = 1296 ArrayRef<OpAsmParser::UnresolvedOperand>(operands).drop_front(); 1297 1298 // Make sure that the first operand (indirect callee) matches the wrapped 1299 // LLVM IR function type, and that the types of the other call operands 1300 // match the types of the function arguments. 1301 if (parser.resolveOperand(operands[0], wrappedFuncType, result.operands) || 1302 parser.resolveOperands(funcArguments, funcType.getInputs(), 1303 parser.getNameLoc(), result.operands)) 1304 return failure(); 1305 1306 if (!llvmResultType.isa<LLVM::LLVMVoidType>()) 1307 result.addTypes(llvmResultType); 1308 } 1309 1310 return success(); 1311 } 1312 1313 //===----------------------------------------------------------------------===// 1314 // Printing/parsing for LLVM::ExtractElementOp. 
1315 //===----------------------------------------------------------------------===// 1316 // Expects vector to be of wrapped LLVM vector type and position to be of 1317 // wrapped LLVM i32 type. 1318 void LLVM::ExtractElementOp::build(OpBuilder &b, OperationState &result, 1319 Value vector, Value position, 1320 ArrayRef<NamedAttribute> attrs) { 1321 auto vectorType = vector.getType(); 1322 auto llvmType = LLVM::getVectorElementType(vectorType); 1323 build(b, result, llvmType, vector, position); 1324 result.addAttributes(attrs); 1325 } 1326 1327 void ExtractElementOp::print(OpAsmPrinter &p) { 1328 p << ' ' << getVector() << "[" << getPosition() << " : " 1329 << getPosition().getType() << "]"; 1330 p.printOptionalAttrDict((*this)->getAttrs()); 1331 p << " : " << getVector().getType(); 1332 } 1333 1334 // <operation> ::= `llvm.extractelement` ssa-use `, ` ssa-use 1335 // attribute-dict? `:` type 1336 ParseResult ExtractElementOp::parse(OpAsmParser &parser, 1337 OperationState &result) { 1338 SMLoc loc; 1339 OpAsmParser::UnresolvedOperand vector, position; 1340 Type type, positionType; 1341 if (parser.getCurrentLocation(&loc) || parser.parseOperand(vector) || 1342 parser.parseLSquare() || parser.parseOperand(position) || 1343 parser.parseColonType(positionType) || parser.parseRSquare() || 1344 parser.parseOptionalAttrDict(result.attributes) || 1345 parser.parseColonType(type) || 1346 parser.resolveOperand(vector, type, result.operands) || 1347 parser.resolveOperand(position, positionType, result.operands)) 1348 return failure(); 1349 if (!LLVM::isCompatibleVectorType(type)) 1350 return parser.emitError( 1351 loc, "expected LLVM dialect-compatible vector type for operand #1"); 1352 result.addTypes(LLVM::getVectorElementType(type)); 1353 return success(); 1354 } 1355 1356 LogicalResult ExtractElementOp::verify() { 1357 Type vectorType = getVector().getType(); 1358 if (!LLVM::isCompatibleVectorType(vectorType)) 1359 return emitOpError("expected LLVM dialect-compatible 
vector type for " 1360 "operand #1, got") 1361 << vectorType; 1362 Type valueType = LLVM::getVectorElementType(vectorType); 1363 if (valueType != getRes().getType()) 1364 return emitOpError() << "Type mismatch: extracting from " << vectorType 1365 << " should produce " << valueType 1366 << " but this op returns " << getRes().getType(); 1367 return success(); 1368 } 1369 1370 //===----------------------------------------------------------------------===// 1371 // Printing/parsing for LLVM::ExtractValueOp. 1372 //===----------------------------------------------------------------------===// 1373 1374 void ExtractValueOp::print(OpAsmPrinter &p) { 1375 p << ' ' << getContainer() << getPosition(); 1376 p.printOptionalAttrDict((*this)->getAttrs(), {"position"}); 1377 p << " : " << getContainer().getType(); 1378 } 1379 1380 // Extract the type at `position` in the wrapped LLVM IR aggregate type 1381 // `containerType`. Position is an integer array attribute where each value 1382 // is a zero-based position of the element in the aggregate type. Return the 1383 // resulting type wrapped in MLIR, or nullptr on error. 1384 static Type getInsertExtractValueElementType(OpAsmParser &parser, 1385 Type containerType, 1386 ArrayAttr positionAttr, 1387 SMLoc attributeLoc, 1388 SMLoc typeLoc) { 1389 Type llvmType = containerType; 1390 if (!isCompatibleType(containerType)) 1391 return parser.emitError(typeLoc, "expected LLVM IR Dialect type"), nullptr; 1392 1393 // Infer the element type from the structure type: iteratively step inside the 1394 // type by taking the element type, indexed by the position attribute for 1395 // structures. Check the position index before accessing, it is supposed to 1396 // be in bounds. 
1397 for (Attribute subAttr : positionAttr) { 1398 auto positionElementAttr = subAttr.dyn_cast<IntegerAttr>(); 1399 if (!positionElementAttr) 1400 return parser.emitError(attributeLoc, 1401 "expected an array of integer literals"), 1402 nullptr; 1403 int position = positionElementAttr.getInt(); 1404 if (auto arrayType = llvmType.dyn_cast<LLVMArrayType>()) { 1405 if (position < 0 || 1406 static_cast<unsigned>(position) >= arrayType.getNumElements()) 1407 return parser.emitError(attributeLoc, "position out of bounds"), 1408 nullptr; 1409 llvmType = arrayType.getElementType(); 1410 } else if (auto structType = llvmType.dyn_cast<LLVMStructType>()) { 1411 if (position < 0 || 1412 static_cast<unsigned>(position) >= structType.getBody().size()) 1413 return parser.emitError(attributeLoc, "position out of bounds"), 1414 nullptr; 1415 llvmType = structType.getBody()[position]; 1416 } else { 1417 return parser.emitError(typeLoc, "expected LLVM IR structure/array type"), 1418 nullptr; 1419 } 1420 } 1421 return llvmType; 1422 } 1423 1424 // Extract the type at `position` in the wrapped LLVM IR aggregate type 1425 // `containerType`. Returns null on failure. 1426 static Type getInsertExtractValueElementType(Type containerType, 1427 ArrayAttr positionAttr, 1428 Operation *op) { 1429 Type llvmType = containerType; 1430 if (!isCompatibleType(containerType)) { 1431 op->emitError("expected LLVM IR Dialect type, got ") << containerType; 1432 return {}; 1433 } 1434 1435 // Infer the element type from the structure type: iteratively step inside the 1436 // type by taking the element type, indexed by the position attribute for 1437 // structures. Check the position index before accessing, it is supposed to 1438 // be in bounds. 
1439 for (Attribute subAttr : positionAttr) { 1440 auto positionElementAttr = subAttr.dyn_cast<IntegerAttr>(); 1441 if (!positionElementAttr) { 1442 op->emitOpError("expected an array of integer literals, got: ") 1443 << subAttr; 1444 return {}; 1445 } 1446 int position = positionElementAttr.getInt(); 1447 if (auto arrayType = llvmType.dyn_cast<LLVMArrayType>()) { 1448 if (position < 0 || 1449 static_cast<unsigned>(position) >= arrayType.getNumElements()) { 1450 op->emitOpError("position out of bounds: ") << position; 1451 return {}; 1452 } 1453 llvmType = arrayType.getElementType(); 1454 } else if (auto structType = llvmType.dyn_cast<LLVMStructType>()) { 1455 if (position < 0 || 1456 static_cast<unsigned>(position) >= structType.getBody().size()) { 1457 op->emitOpError("position out of bounds") << position; 1458 return {}; 1459 } 1460 llvmType = structType.getBody()[position]; 1461 } else { 1462 op->emitOpError("expected LLVM IR structure/array type, got: ") 1463 << llvmType; 1464 return {}; 1465 } 1466 } 1467 return llvmType; 1468 } 1469 1470 // <operation> ::= `llvm.extractvalue` ssa-use 1471 // `[` integer-literal (`,` integer-literal)* `]` 1472 // attribute-dict? 
`:` type 1473 ParseResult ExtractValueOp::parse(OpAsmParser &parser, OperationState &result) { 1474 OpAsmParser::UnresolvedOperand container; 1475 Type containerType; 1476 ArrayAttr positionAttr; 1477 SMLoc attributeLoc, trailingTypeLoc; 1478 1479 if (parser.parseOperand(container) || 1480 parser.getCurrentLocation(&attributeLoc) || 1481 parser.parseAttribute(positionAttr, "position", result.attributes) || 1482 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 1483 parser.getCurrentLocation(&trailingTypeLoc) || 1484 parser.parseType(containerType) || 1485 parser.resolveOperand(container, containerType, result.operands)) 1486 return failure(); 1487 1488 auto elementType = getInsertExtractValueElementType( 1489 parser, containerType, positionAttr, attributeLoc, trailingTypeLoc); 1490 if (!elementType) 1491 return failure(); 1492 1493 result.addTypes(elementType); 1494 return success(); 1495 } 1496 1497 OpFoldResult LLVM::ExtractValueOp::fold(ArrayRef<Attribute> operands) { 1498 auto insertValueOp = getContainer().getDefiningOp<InsertValueOp>(); 1499 OpFoldResult result = {}; 1500 while (insertValueOp) { 1501 if (getPosition() == insertValueOp.getPosition()) 1502 return insertValueOp.getValue(); 1503 unsigned min = 1504 std::min(getPosition().size(), insertValueOp.getPosition().size()); 1505 // If one is fully prefix of the other, stop propagating back as it will 1506 // miss dependencies. 
For instance, %3 should not fold to %f0 in the 1507 // following example: 1508 // ``` 1509 // %1 = llvm.insertvalue %f0, %0[0, 0] : 1510 // !llvm.array<4 x !llvm.array<4xf32>> 1511 // %2 = llvm.insertvalue %arr, %1[0] : 1512 // !llvm.array<4 x !llvm.array<4xf32>> 1513 // %3 = llvm.extractvalue %2[0, 0] : !llvm.array<4 x !llvm.array<4xf32>> 1514 // ``` 1515 if (getPosition().getValue().take_front(min) == 1516 insertValueOp.getPosition().getValue().take_front(min)) 1517 return result; 1518 1519 // If neither a prefix, nor the exact position, we can extract out of the 1520 // value being inserted into. Moreover, we can try again if that operand 1521 // is itself an insertvalue expression. 1522 getContainerMutable().assign(insertValueOp.getContainer()); 1523 result = getResult(); 1524 insertValueOp = insertValueOp.getContainer().getDefiningOp<InsertValueOp>(); 1525 } 1526 return result; 1527 } 1528 1529 LogicalResult ExtractValueOp::verify() { 1530 Type valueType = getInsertExtractValueElementType(getContainer().getType(), 1531 getPositionAttr(), *this); 1532 if (!valueType) 1533 return failure(); 1534 1535 if (getRes().getType() != valueType) 1536 return emitOpError() << "Type mismatch: extracting from " 1537 << getContainer().getType() << " should produce " 1538 << valueType << " but this op returns " 1539 << getRes().getType(); 1540 return success(); 1541 } 1542 1543 //===----------------------------------------------------------------------===// 1544 // Printing/parsing for LLVM::InsertElementOp. 1545 //===----------------------------------------------------------------------===// 1546 1547 void InsertElementOp::print(OpAsmPrinter &p) { 1548 p << ' ' << getValue() << ", " << getVector() << "[" << getPosition() << " : " 1549 << getPosition().getType() << "]"; 1550 p.printOptionalAttrDict((*this)->getAttrs()); 1551 p << " : " << getVector().getType(); 1552 } 1553 1554 // <operation> ::= `llvm.insertelement` ssa-use `,` ssa-use `,` ssa-use 1555 // attribute-dict? 
`:` type 1556 ParseResult InsertElementOp::parse(OpAsmParser &parser, 1557 OperationState &result) { 1558 SMLoc loc; 1559 OpAsmParser::UnresolvedOperand vector, value, position; 1560 Type vectorType, positionType; 1561 if (parser.getCurrentLocation(&loc) || parser.parseOperand(value) || 1562 parser.parseComma() || parser.parseOperand(vector) || 1563 parser.parseLSquare() || parser.parseOperand(position) || 1564 parser.parseColonType(positionType) || parser.parseRSquare() || 1565 parser.parseOptionalAttrDict(result.attributes) || 1566 parser.parseColonType(vectorType)) 1567 return failure(); 1568 1569 if (!LLVM::isCompatibleVectorType(vectorType)) 1570 return parser.emitError( 1571 loc, "expected LLVM dialect-compatible vector type for operand #1"); 1572 Type valueType = LLVM::getVectorElementType(vectorType); 1573 if (!valueType) 1574 return failure(); 1575 1576 if (parser.resolveOperand(vector, vectorType, result.operands) || 1577 parser.resolveOperand(value, valueType, result.operands) || 1578 parser.resolveOperand(position, positionType, result.operands)) 1579 return failure(); 1580 1581 result.addTypes(vectorType); 1582 return success(); 1583 } 1584 1585 LogicalResult InsertElementOp::verify() { 1586 Type valueType = LLVM::getVectorElementType(getVector().getType()); 1587 if (valueType != getValue().getType()) 1588 return emitOpError() << "Type mismatch: cannot insert " 1589 << getValue().getType() << " into " 1590 << getVector().getType(); 1591 return success(); 1592 } 1593 1594 //===----------------------------------------------------------------------===// 1595 // Printing/parsing for LLVM::InsertValueOp. 
1596 //===----------------------------------------------------------------------===// 1597 1598 void InsertValueOp::print(OpAsmPrinter &p) { 1599 p << ' ' << getValue() << ", " << getContainer() << getPosition(); 1600 p.printOptionalAttrDict((*this)->getAttrs(), {"position"}); 1601 p << " : " << getContainer().getType(); 1602 } 1603 1604 // <operation> ::= `llvm.insertvaluevalue` ssa-use `,` ssa-use 1605 // `[` integer-literal (`,` integer-literal)* `]` 1606 // attribute-dict? `:` type 1607 ParseResult InsertValueOp::parse(OpAsmParser &parser, OperationState &result) { 1608 OpAsmParser::UnresolvedOperand container, value; 1609 Type containerType; 1610 ArrayAttr positionAttr; 1611 SMLoc attributeLoc, trailingTypeLoc; 1612 1613 if (parser.parseOperand(value) || parser.parseComma() || 1614 parser.parseOperand(container) || 1615 parser.getCurrentLocation(&attributeLoc) || 1616 parser.parseAttribute(positionAttr, "position", result.attributes) || 1617 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 1618 parser.getCurrentLocation(&trailingTypeLoc) || 1619 parser.parseType(containerType)) 1620 return failure(); 1621 1622 auto valueType = getInsertExtractValueElementType( 1623 parser, containerType, positionAttr, attributeLoc, trailingTypeLoc); 1624 if (!valueType) 1625 return failure(); 1626 1627 if (parser.resolveOperand(container, containerType, result.operands) || 1628 parser.resolveOperand(value, valueType, result.operands)) 1629 return failure(); 1630 1631 result.addTypes(containerType); 1632 return success(); 1633 } 1634 1635 LogicalResult InsertValueOp::verify() { 1636 Type valueType = getInsertExtractValueElementType(getContainer().getType(), 1637 getPositionAttr(), *this); 1638 if (!valueType) 1639 return failure(); 1640 1641 if (getValue().getType() != valueType) 1642 return emitOpError() << "Type mismatch: cannot insert " 1643 << getValue().getType() << " into " 1644 << getContainer().getType(); 1645 1646 return success(); 1647 } 1648 
1649 //===----------------------------------------------------------------------===// 1650 // Printing, parsing and verification for LLVM::ReturnOp. 1651 //===----------------------------------------------------------------------===// 1652 1653 LogicalResult ReturnOp::verify() { 1654 if (getNumOperands() > 1) 1655 return emitOpError("expected at most 1 operand"); 1656 1657 if (auto parent = (*this)->getParentOfType<LLVMFuncOp>()) { 1658 Type expectedType = parent.getFunctionType().getReturnType(); 1659 if (expectedType.isa<LLVMVoidType>()) { 1660 if (getNumOperands() == 0) 1661 return success(); 1662 InFlightDiagnostic diag = emitOpError("expected no operands"); 1663 diag.attachNote(parent->getLoc()) << "when returning from function"; 1664 return diag; 1665 } 1666 if (getNumOperands() == 0) { 1667 if (expectedType.isa<LLVMVoidType>()) 1668 return success(); 1669 InFlightDiagnostic diag = emitOpError("expected 1 operand"); 1670 diag.attachNote(parent->getLoc()) << "when returning from function"; 1671 return diag; 1672 } 1673 if (expectedType != getOperand(0).getType()) { 1674 InFlightDiagnostic diag = emitOpError("mismatching result types"); 1675 diag.attachNote(parent->getLoc()) << "when returning from function"; 1676 return diag; 1677 } 1678 } 1679 return success(); 1680 } 1681 1682 //===----------------------------------------------------------------------===// 1683 // ResumeOp 1684 //===----------------------------------------------------------------------===// 1685 1686 LogicalResult ResumeOp::verify() { 1687 if (!getValue().getDefiningOp<LandingpadOp>()) 1688 return emitOpError("expects landingpad value as operand"); 1689 // No check for personality of function - landingpad op verifies it. 1690 return success(); 1691 } 1692 1693 //===----------------------------------------------------------------------===// 1694 // Verifier for LLVM::AddressOfOp. 
1695 //===----------------------------------------------------------------------===// 1696 1697 template <typename OpTy> 1698 static OpTy lookupSymbolInModule(Operation *parent, StringRef name) { 1699 Operation *module = parent; 1700 while (module && !satisfiesLLVMModule(module)) 1701 module = module->getParentOp(); 1702 assert(module && "unexpected operation outside of a module"); 1703 return dyn_cast_or_null<OpTy>( 1704 mlir::SymbolTable::lookupSymbolIn(module, name)); 1705 } 1706 1707 GlobalOp AddressOfOp::getGlobal() { 1708 return lookupSymbolInModule<LLVM::GlobalOp>((*this)->getParentOp(), 1709 getGlobalName()); 1710 } 1711 1712 LLVMFuncOp AddressOfOp::getFunction() { 1713 return lookupSymbolInModule<LLVM::LLVMFuncOp>((*this)->getParentOp(), 1714 getGlobalName()); 1715 } 1716 1717 LogicalResult AddressOfOp::verify() { 1718 auto global = getGlobal(); 1719 auto function = getFunction(); 1720 if (!global && !function) 1721 return emitOpError( 1722 "must reference a global defined by 'llvm.mlir.global' or 'llvm.func'"); 1723 1724 LLVMPointerType type = getType(); 1725 if (global && global.getAddrSpace() != type.getAddressSpace()) 1726 return emitOpError("pointer address space must match address space of the " 1727 "referenced global"); 1728 1729 if (type.isOpaque()) 1730 return success(); 1731 1732 if (global && type.getElementType() != global.getType()) 1733 return emitOpError( 1734 "the type must be a pointer to the type of the referenced global"); 1735 1736 if (function && type.getElementType() != function.getFunctionType()) 1737 return emitOpError( 1738 "the type must be a pointer to the type of the referenced function"); 1739 1740 return success(); 1741 } 1742 1743 //===----------------------------------------------------------------------===// 1744 // Builder, printer and verifier for LLVM::GlobalOp. 
1745 //===----------------------------------------------------------------------===// 1746 1747 void GlobalOp::build(OpBuilder &builder, OperationState &result, Type type, 1748 bool isConstant, Linkage linkage, StringRef name, 1749 Attribute value, uint64_t alignment, unsigned addrSpace, 1750 bool dsoLocal, bool threadLocal, 1751 ArrayRef<NamedAttribute> attrs) { 1752 result.addAttribute(getSymNameAttrName(result.name), 1753 builder.getStringAttr(name)); 1754 result.addAttribute(getGlobalTypeAttrName(result.name), TypeAttr::get(type)); 1755 if (isConstant) 1756 result.addAttribute(getConstantAttrName(result.name), 1757 builder.getUnitAttr()); 1758 if (value) 1759 result.addAttribute(getValueAttrName(result.name), value); 1760 if (dsoLocal) 1761 result.addAttribute(getDsoLocalAttrName(result.name), 1762 builder.getUnitAttr()); 1763 if (threadLocal) 1764 result.addAttribute(getThreadLocal_AttrName(result.name), 1765 builder.getUnitAttr()); 1766 1767 // Only add an alignment attribute if the "alignment" input 1768 // is different from 0. The value must also be a power of two, but 1769 // this is tested in GlobalOp::verify, not here. 
1770 if (alignment != 0) 1771 result.addAttribute(getAlignmentAttrName(result.name), 1772 builder.getI64IntegerAttr(alignment)); 1773 1774 result.addAttribute(getLinkageAttrName(result.name), 1775 LinkageAttr::get(builder.getContext(), linkage)); 1776 if (addrSpace != 0) 1777 result.addAttribute(getAddrSpaceAttrName(result.name), 1778 builder.getI32IntegerAttr(addrSpace)); 1779 result.attributes.append(attrs.begin(), attrs.end()); 1780 result.addRegion(); 1781 } 1782 1783 void GlobalOp::print(OpAsmPrinter &p) { 1784 p << ' ' << stringifyLinkage(getLinkage()) << ' '; 1785 if (auto unnamedAddr = getUnnamedAddr()) { 1786 StringRef str = stringifyUnnamedAddr(*unnamedAddr); 1787 if (!str.empty()) 1788 p << str << ' '; 1789 } 1790 if (getThreadLocal_()) 1791 p << "thread_local "; 1792 if (getConstant()) 1793 p << "constant "; 1794 p.printSymbolName(getSymName()); 1795 p << '('; 1796 if (auto value = getValueOrNull()) 1797 p.printAttribute(value); 1798 p << ')'; 1799 // Note that the alignment attribute is printed using the 1800 // default syntax here, even though it is an inherent attribute 1801 // (as defined in https://mlir.llvm.org/docs/LangRef/#attributes) 1802 p.printOptionalAttrDict( 1803 (*this)->getAttrs(), 1804 {SymbolTable::getSymbolAttrName(), getGlobalTypeAttrName(), 1805 getConstantAttrName(), getValueAttrName(), getLinkageAttrName(), 1806 getUnnamedAddrAttrName(), getThreadLocal_AttrName()}); 1807 1808 // Print the trailing type unless it's a string global. 1809 if (getValueOrNull().dyn_cast_or_null<StringAttr>()) 1810 return; 1811 p << " : " << getType(); 1812 1813 Region &initializer = getInitializerRegion(); 1814 if (!initializer.empty()) { 1815 p << ' '; 1816 p.printRegion(initializer, /*printEntryBlockArgs=*/false); 1817 } 1818 } 1819 1820 // Parses one of the keywords provided in the list `keywords` and returns the 1821 // position of the parsed keyword in the list. If none of the keywords from the 1822 // list is parsed, returns -1. 
1823 static int parseOptionalKeywordAlternative(OpAsmParser &parser, 1824 ArrayRef<StringRef> keywords) { 1825 for (const auto &en : llvm::enumerate(keywords)) { 1826 if (succeeded(parser.parseOptionalKeyword(en.value()))) 1827 return en.index(); 1828 } 1829 return -1; 1830 } 1831 1832 namespace { 1833 template <typename Ty> 1834 struct EnumTraits {}; 1835 1836 #define REGISTER_ENUM_TYPE(Ty) \ 1837 template <> \ 1838 struct EnumTraits<Ty> { \ 1839 static StringRef stringify(Ty value) { return stringify##Ty(value); } \ 1840 static unsigned getMaxEnumVal() { return getMaxEnumValFor##Ty(); } \ 1841 } 1842 1843 REGISTER_ENUM_TYPE(Linkage); 1844 REGISTER_ENUM_TYPE(UnnamedAddr); 1845 REGISTER_ENUM_TYPE(CConv); 1846 } // namespace 1847 1848 /// Parse an enum from the keyword, or default to the provided default value. 1849 /// The return type is the enum type by default, unless overriden with the 1850 /// second template argument. 1851 template <typename EnumTy, typename RetTy = EnumTy> 1852 static RetTy parseOptionalLLVMKeyword(OpAsmParser &parser, 1853 OperationState &result, 1854 EnumTy defaultValue) { 1855 SmallVector<StringRef, 10> names; 1856 for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i) 1857 names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i))); 1858 1859 int index = parseOptionalKeywordAlternative(parser, names); 1860 if (index == -1) 1861 return static_cast<RetTy>(defaultValue); 1862 return static_cast<RetTy>(index); 1863 } 1864 1865 // operation ::= `llvm.mlir.global` linkage? `constant`? `@` identifier 1866 // `(` attribute? `)` align? attribute-list? (`:` type)? region? 1867 // align ::= `align` `=` UINT64 1868 // 1869 // The type can be omitted for string attributes, in which case it will be 1870 // inferred from the value of the string as [strlen(value) x i8]. 
1871 ParseResult GlobalOp::parse(OpAsmParser &parser, OperationState &result) { 1872 MLIRContext *ctx = parser.getContext(); 1873 // Parse optional linkage, default to External. 1874 result.addAttribute(getLinkageAttrName(result.name), 1875 LLVM::LinkageAttr::get( 1876 ctx, parseOptionalLLVMKeyword<Linkage>( 1877 parser, result, LLVM::Linkage::External))); 1878 1879 if (succeeded(parser.parseOptionalKeyword("thread_local"))) 1880 result.addAttribute(getThreadLocal_AttrName(result.name), 1881 parser.getBuilder().getUnitAttr()); 1882 1883 // Parse optional UnnamedAddr, default to None. 1884 result.addAttribute(getUnnamedAddrAttrName(result.name), 1885 parser.getBuilder().getI64IntegerAttr( 1886 parseOptionalLLVMKeyword<UnnamedAddr, int64_t>( 1887 parser, result, LLVM::UnnamedAddr::None))); 1888 1889 if (succeeded(parser.parseOptionalKeyword("constant"))) 1890 result.addAttribute(getConstantAttrName(result.name), 1891 parser.getBuilder().getUnitAttr()); 1892 1893 StringAttr name; 1894 if (parser.parseSymbolName(name, getSymNameAttrName(result.name), 1895 result.attributes) || 1896 parser.parseLParen()) 1897 return failure(); 1898 1899 Attribute value; 1900 if (parser.parseOptionalRParen()) { 1901 if (parser.parseAttribute(value, getValueAttrName(result.name), 1902 result.attributes) || 1903 parser.parseRParen()) 1904 return failure(); 1905 } 1906 1907 SmallVector<Type, 1> types; 1908 if (parser.parseOptionalAttrDict(result.attributes) || 1909 parser.parseOptionalColonTypeList(types)) 1910 return failure(); 1911 1912 if (types.size() > 1) 1913 return parser.emitError(parser.getNameLoc(), "expected zero or one type"); 1914 1915 Region &initRegion = *result.addRegion(); 1916 if (types.empty()) { 1917 if (auto strAttr = value.dyn_cast_or_null<StringAttr>()) { 1918 MLIRContext *context = parser.getContext(); 1919 auto arrayType = LLVM::LLVMArrayType::get(IntegerType::get(context, 8), 1920 strAttr.getValue().size()); 1921 types.push_back(arrayType); 1922 } else { 1923 
return parser.emitError(parser.getNameLoc(), 1924 "type can only be omitted for string globals"); 1925 } 1926 } else { 1927 OptionalParseResult parseResult = 1928 parser.parseOptionalRegion(initRegion, /*arguments=*/{}, 1929 /*argTypes=*/{}); 1930 if (parseResult.hasValue() && failed(*parseResult)) 1931 return failure(); 1932 } 1933 1934 result.addAttribute(getGlobalTypeAttrName(result.name), 1935 TypeAttr::get(types[0])); 1936 return success(); 1937 } 1938 1939 static bool isZeroAttribute(Attribute value) { 1940 if (auto intValue = value.dyn_cast<IntegerAttr>()) 1941 return intValue.getValue().isNullValue(); 1942 if (auto fpValue = value.dyn_cast<FloatAttr>()) 1943 return fpValue.getValue().isZero(); 1944 if (auto splatValue = value.dyn_cast<SplatElementsAttr>()) 1945 return isZeroAttribute(splatValue.getSplatValue<Attribute>()); 1946 if (auto elementsValue = value.dyn_cast<ElementsAttr>()) 1947 return llvm::all_of(elementsValue.getValues<Attribute>(), isZeroAttribute); 1948 if (auto arrayValue = value.dyn_cast<ArrayAttr>()) 1949 return llvm::all_of(arrayValue.getValue(), isZeroAttribute); 1950 return false; 1951 } 1952 1953 LogicalResult GlobalOp::verify() { 1954 if (!LLVMPointerType::isValidElementType(getType())) 1955 return emitOpError( 1956 "expects type to be a valid element type for an LLVM pointer"); 1957 if ((*this)->getParentOp() && !satisfiesLLVMModule((*this)->getParentOp())) 1958 return emitOpError("must appear at the module level"); 1959 1960 if (auto strAttr = getValueOrNull().dyn_cast_or_null<StringAttr>()) { 1961 auto type = getType().dyn_cast<LLVMArrayType>(); 1962 IntegerType elementType = 1963 type ? 
type.getElementType().dyn_cast<IntegerType>() : nullptr; 1964 if (!elementType || elementType.getWidth() != 8 || 1965 type.getNumElements() != strAttr.getValue().size()) 1966 return emitOpError( 1967 "requires an i8 array type of the length equal to that of the string " 1968 "attribute"); 1969 } 1970 1971 if (getLinkage() == Linkage::Common) { 1972 if (Attribute value = getValueOrNull()) { 1973 if (!isZeroAttribute(value)) { 1974 return emitOpError() 1975 << "expected zero value for '" 1976 << stringifyLinkage(Linkage::Common) << "' linkage"; 1977 } 1978 } 1979 } 1980 1981 if (getLinkage() == Linkage::Appending) { 1982 if (!getType().isa<LLVMArrayType>()) { 1983 return emitOpError() << "expected array type for '" 1984 << stringifyLinkage(Linkage::Appending) 1985 << "' linkage"; 1986 } 1987 } 1988 1989 Optional<uint64_t> alignAttr = getAlignment(); 1990 if (alignAttr.hasValue()) { 1991 uint64_t value = alignAttr.getValue(); 1992 if (!llvm::isPowerOf2_64(value)) 1993 return emitError() << "alignment attribute is not a power of 2"; 1994 } 1995 1996 return success(); 1997 } 1998 1999 LogicalResult GlobalOp::verifyRegions() { 2000 if (Block *b = getInitializerBlock()) { 2001 ReturnOp ret = cast<ReturnOp>(b->getTerminator()); 2002 if (ret.operand_type_begin() == ret.operand_type_end()) 2003 return emitOpError("initializer region cannot return void"); 2004 if (*ret.operand_type_begin() != getType()) 2005 return emitOpError("initializer region type ") 2006 << *ret.operand_type_begin() << " does not match global type " 2007 << getType(); 2008 2009 for (Operation &op : *b) { 2010 auto iface = dyn_cast<MemoryEffectOpInterface>(op); 2011 if (!iface || !iface.hasNoEffect()) 2012 return op.emitError() 2013 << "ops with side effects not allowed in global initializers"; 2014 } 2015 2016 if (getValueOrNull()) 2017 return emitOpError("cannot have both initializer value and region"); 2018 } 2019 2020 return success(); 2021 } 2022 2023 
//===----------------------------------------------------------------------===// 2024 // LLVM::GlobalCtorsOp 2025 //===----------------------------------------------------------------------===// 2026 2027 LogicalResult 2028 GlobalCtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 2029 for (Attribute ctor : getCtors()) { 2030 if (failed(verifySymbolAttrUse(ctor.cast<FlatSymbolRefAttr>(), *this, 2031 symbolTable))) 2032 return failure(); 2033 } 2034 return success(); 2035 } 2036 2037 LogicalResult GlobalCtorsOp::verify() { 2038 if (getCtors().size() != getPriorities().size()) 2039 return emitError( 2040 "mismatch between the number of ctors and the number of priorities"); 2041 return success(); 2042 } 2043 2044 //===----------------------------------------------------------------------===// 2045 // LLVM::GlobalDtorsOp 2046 //===----------------------------------------------------------------------===// 2047 2048 LogicalResult 2049 GlobalDtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 2050 for (Attribute dtor : getDtors()) { 2051 if (failed(verifySymbolAttrUse(dtor.cast<FlatSymbolRefAttr>(), *this, 2052 symbolTable))) 2053 return failure(); 2054 } 2055 return success(); 2056 } 2057 2058 LogicalResult GlobalDtorsOp::verify() { 2059 if (getDtors().size() != getPriorities().size()) 2060 return emitError( 2061 "mismatch between the number of dtors and the number of priorities"); 2062 return success(); 2063 } 2064 2065 //===----------------------------------------------------------------------===// 2066 // Printing/parsing for LLVM::ShuffleVectorOp. 2067 //===----------------------------------------------------------------------===// 2068 // Expects vector to be of wrapped LLVM vector type and position to be of 2069 // wrapped LLVM i32 type. 
2070 void LLVM::ShuffleVectorOp::build(OpBuilder &b, OperationState &result, 2071 Value v1, Value v2, ArrayAttr mask, 2072 ArrayRef<NamedAttribute> attrs) { 2073 auto containerType = v1.getType(); 2074 auto vType = LLVM::getVectorType(LLVM::getVectorElementType(containerType), 2075 mask.size(), 2076 LLVM::isScalableVectorType(containerType)); 2077 build(b, result, vType, v1, v2, mask); 2078 result.addAttributes(attrs); 2079 } 2080 2081 void ShuffleVectorOp::print(OpAsmPrinter &p) { 2082 p << ' ' << getV1() << ", " << getV2() << " " << getMask(); 2083 p.printOptionalAttrDict((*this)->getAttrs(), {"mask"}); 2084 p << " : " << getV1().getType() << ", " << getV2().getType(); 2085 } 2086 2087 // <operation> ::= `llvm.shufflevector` ssa-use `, ` ssa-use 2088 // `[` integer-literal (`,` integer-literal)* `]` 2089 // attribute-dict? `:` type 2090 ParseResult ShuffleVectorOp::parse(OpAsmParser &parser, 2091 OperationState &result) { 2092 SMLoc loc; 2093 OpAsmParser::UnresolvedOperand v1, v2; 2094 ArrayAttr maskAttr; 2095 Type typeV1, typeV2; 2096 if (parser.getCurrentLocation(&loc) || parser.parseOperand(v1) || 2097 parser.parseComma() || parser.parseOperand(v2) || 2098 parser.parseAttribute(maskAttr, "mask", result.attributes) || 2099 parser.parseOptionalAttrDict(result.attributes) || 2100 parser.parseColonType(typeV1) || parser.parseComma() || 2101 parser.parseType(typeV2) || 2102 parser.resolveOperand(v1, typeV1, result.operands) || 2103 parser.resolveOperand(v2, typeV2, result.operands)) 2104 return failure(); 2105 if (!LLVM::isCompatibleVectorType(typeV1)) 2106 return parser.emitError( 2107 loc, "expected LLVM IR dialect vector type for operand #1"); 2108 auto vType = 2109 LLVM::getVectorType(LLVM::getVectorElementType(typeV1), maskAttr.size(), 2110 LLVM::isScalableVectorType(typeV1)); 2111 result.addTypes(vType); 2112 return success(); 2113 } 2114 2115 LogicalResult ShuffleVectorOp::verify() { 2116 Type type1 = getV1().getType(); 2117 Type type2 = getV2().getType(); 
2118 if (LLVM::getVectorElementType(type1) != LLVM::getVectorElementType(type2)) 2119 return emitOpError("expected matching LLVM IR Dialect element types"); 2120 if (LLVM::isScalableVectorType(type1)) 2121 if (llvm::any_of(getMask(), [](Attribute attr) { 2122 return attr.cast<IntegerAttr>().getInt() != 0; 2123 })) 2124 return emitOpError("expected a splat operation for scalable vectors"); 2125 return success(); 2126 } 2127 2128 //===----------------------------------------------------------------------===// 2129 // Implementations for LLVM::LLVMFuncOp. 2130 //===----------------------------------------------------------------------===// 2131 2132 // Add the entry block to the function. 2133 Block *LLVMFuncOp::addEntryBlock() { 2134 assert(empty() && "function already has an entry block"); 2135 assert(!isVarArg() && "unimplemented: non-external variadic functions"); 2136 2137 auto *entry = new Block; 2138 push_back(entry); 2139 2140 // FIXME: Allow passing in proper locations for the entry arguments. 
2141 LLVMFunctionType type = getFunctionType(); 2142 for (unsigned i = 0, e = type.getNumParams(); i < e; ++i) 2143 entry->addArgument(type.getParamType(i), getLoc()); 2144 return entry; 2145 } 2146 2147 void LLVMFuncOp::build(OpBuilder &builder, OperationState &result, 2148 StringRef name, Type type, LLVM::Linkage linkage, 2149 bool dsoLocal, CConv cconv, 2150 ArrayRef<NamedAttribute> attrs, 2151 ArrayRef<DictionaryAttr> argAttrs) { 2152 result.addRegion(); 2153 result.addAttribute(SymbolTable::getSymbolAttrName(), 2154 builder.getStringAttr(name)); 2155 result.addAttribute(getFunctionTypeAttrName(result.name), 2156 TypeAttr::get(type)); 2157 result.addAttribute(getLinkageAttrName(result.name), 2158 LinkageAttr::get(builder.getContext(), linkage)); 2159 result.addAttribute(getCConvAttrName(result.name), 2160 CConvAttr::get(builder.getContext(), cconv)); 2161 result.attributes.append(attrs.begin(), attrs.end()); 2162 if (dsoLocal) 2163 result.addAttribute("dso_local", builder.getUnitAttr()); 2164 if (argAttrs.empty()) 2165 return; 2166 2167 assert(type.cast<LLVMFunctionType>().getNumParams() == argAttrs.size() && 2168 "expected as many argument attribute lists as arguments"); 2169 function_interface_impl::addArgAndResultAttrs(builder, result, argAttrs, 2170 /*resultAttrs=*/llvm::None); 2171 } 2172 2173 // Builds an LLVM function type from the given lists of input and output types. 2174 // Returns a null type if any of the types provided are non-LLVM types, or if 2175 // there is more than one output type. 2176 static Type 2177 buildLLVMFunctionType(OpAsmParser &parser, SMLoc loc, ArrayRef<Type> inputs, 2178 ArrayRef<Type> outputs, 2179 function_interface_impl::VariadicFlag variadicFlag) { 2180 Builder &b = parser.getBuilder(); 2181 if (outputs.size() > 1) { 2182 parser.emitError(loc, "failed to construct function type: expected zero or " 2183 "one function result"); 2184 return {}; 2185 } 2186 2187 // Convert inputs to LLVM types, exit early on error. 
2188 SmallVector<Type, 4> llvmInputs; 2189 for (auto t : inputs) { 2190 if (!isCompatibleType(t)) { 2191 parser.emitError(loc, "failed to construct function type: expected LLVM " 2192 "type for function arguments"); 2193 return {}; 2194 } 2195 llvmInputs.push_back(t); 2196 } 2197 2198 // No output is denoted as "void" in LLVM type system. 2199 Type llvmOutput = 2200 outputs.empty() ? LLVMVoidType::get(b.getContext()) : outputs.front(); 2201 if (!isCompatibleType(llvmOutput)) { 2202 parser.emitError(loc, "failed to construct function type: expected LLVM " 2203 "type for function results") 2204 << llvmOutput; 2205 return {}; 2206 } 2207 return LLVMFunctionType::get(llvmOutput, llvmInputs, 2208 variadicFlag.isVariadic()); 2209 } 2210 2211 // Parses an LLVM function. 2212 // 2213 // operation ::= `llvm.func` linkage? cconv? function-signature 2214 // function-attributes? 2215 // function-body 2216 // 2217 ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) { 2218 // Default to external linkage if no keyword is provided. 2219 result.addAttribute( 2220 getLinkageAttrName(result.name), 2221 LinkageAttr::get(parser.getContext(), 2222 parseOptionalLLVMKeyword<Linkage>( 2223 parser, result, LLVM::Linkage::External))); 2224 2225 // Default to C Calling Convention if no keyword is provided. 
2226 result.addAttribute( 2227 getCConvAttrName(result.name), 2228 CConvAttr::get(parser.getContext(), parseOptionalLLVMKeyword<CConv>( 2229 parser, result, LLVM::CConv::C))); 2230 2231 StringAttr nameAttr; 2232 SmallVector<OpAsmParser::Argument> entryArgs; 2233 SmallVector<DictionaryAttr> resultAttrs; 2234 SmallVector<Type> resultTypes; 2235 bool isVariadic; 2236 2237 auto signatureLocation = parser.getCurrentLocation(); 2238 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), 2239 result.attributes) || 2240 function_interface_impl::parseFunctionSignature( 2241 parser, /*allowVariadic=*/true, entryArgs, isVariadic, resultTypes, 2242 resultAttrs)) 2243 return failure(); 2244 2245 SmallVector<Type> argTypes; 2246 for (auto &arg : entryArgs) 2247 argTypes.push_back(arg.type); 2248 auto type = 2249 buildLLVMFunctionType(parser, signatureLocation, argTypes, resultTypes, 2250 function_interface_impl::VariadicFlag(isVariadic)); 2251 if (!type) 2252 return failure(); 2253 result.addAttribute(FunctionOpInterface::getTypeAttrName(), 2254 TypeAttr::get(type)); 2255 2256 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) 2257 return failure(); 2258 function_interface_impl::addArgAndResultAttrs(parser.getBuilder(), result, 2259 entryArgs, resultAttrs); 2260 2261 auto *body = result.addRegion(); 2262 OptionalParseResult parseResult = 2263 parser.parseOptionalRegion(*body, entryArgs); 2264 return failure(parseResult.hasValue() && failed(*parseResult)); 2265 } 2266 2267 // Print the LLVMFuncOp. Collects argument and result types and passes them to 2268 // helper functions. Drops "void" result since it cannot be parsed back. Skips 2269 // the external linkage since it is the default value. 
2270 void LLVMFuncOp::print(OpAsmPrinter &p) { 2271 p << ' '; 2272 if (getLinkage() != LLVM::Linkage::External) 2273 p << stringifyLinkage(getLinkage()) << ' '; 2274 if (getCConv() != LLVM::CConv::C) 2275 p << stringifyCConv(getCConv()) << ' '; 2276 2277 p.printSymbolName(getName()); 2278 2279 LLVMFunctionType fnType = getFunctionType(); 2280 SmallVector<Type, 8> argTypes; 2281 SmallVector<Type, 1> resTypes; 2282 argTypes.reserve(fnType.getNumParams()); 2283 for (unsigned i = 0, e = fnType.getNumParams(); i < e; ++i) 2284 argTypes.push_back(fnType.getParamType(i)); 2285 2286 Type returnType = fnType.getReturnType(); 2287 if (!returnType.isa<LLVMVoidType>()) 2288 resTypes.push_back(returnType); 2289 2290 function_interface_impl::printFunctionSignature(p, *this, argTypes, 2291 isVarArg(), resTypes); 2292 function_interface_impl::printFunctionAttributes( 2293 p, *this, argTypes.size(), resTypes.size(), 2294 {getLinkageAttrName(), getCConvAttrName()}); 2295 2296 // Print the body if this is not an external function. 2297 Region &body = getBody(); 2298 if (!body.empty()) { 2299 p << ' '; 2300 p.printRegion(body, /*printEntryBlockArgs=*/false, 2301 /*printBlockTerminators=*/true); 2302 } 2303 } 2304 2305 // Verifies LLVM- and implementation-specific properties of the LLVM func Op: 2306 // - functions don't have 'common' linkage 2307 // - external functions have 'external' or 'extern_weak' linkage; 2308 // - vararg is (currently) only supported for external functions; 2309 LogicalResult LLVMFuncOp::verify() { 2310 if (getLinkage() == LLVM::Linkage::Common) 2311 return emitOpError() << "functions cannot have '" 2312 << stringifyLinkage(LLVM::Linkage::Common) 2313 << "' linkage"; 2314 2315 // Check to see if this function has a void return with a result attribute to 2316 // it. It isn't clear what semantics we would assign to that. 
2317 if (getFunctionType().getReturnType().isa<LLVMVoidType>() && 2318 !getResultAttrs(0).empty()) { 2319 return emitOpError() 2320 << "cannot attach result attributes to functions with a void return"; 2321 } 2322 2323 if (isExternal()) { 2324 if (getLinkage() != LLVM::Linkage::External && 2325 getLinkage() != LLVM::Linkage::ExternWeak) 2326 return emitOpError() << "external functions must have '" 2327 << stringifyLinkage(LLVM::Linkage::External) 2328 << "' or '" 2329 << stringifyLinkage(LLVM::Linkage::ExternWeak) 2330 << "' linkage"; 2331 return success(); 2332 } 2333 2334 if (isVarArg()) 2335 return emitOpError("only external functions can be variadic"); 2336 2337 return success(); 2338 } 2339 2340 /// Verifies LLVM- and implementation-specific properties of the LLVM func Op: 2341 /// - entry block arguments are of LLVM types. 2342 LogicalResult LLVMFuncOp::verifyRegions() { 2343 if (isExternal()) 2344 return success(); 2345 2346 unsigned numArguments = getFunctionType().getNumParams(); 2347 Block &entryBlock = front(); 2348 for (unsigned i = 0; i < numArguments; ++i) { 2349 Type argType = entryBlock.getArgument(i).getType(); 2350 if (!isCompatibleType(argType)) 2351 return emitOpError("entry block argument #") 2352 << i << " is not of LLVM type"; 2353 } 2354 2355 return success(); 2356 } 2357 2358 //===----------------------------------------------------------------------===// 2359 // Verification for LLVM::ConstantOp. 
2360 //===----------------------------------------------------------------------===// 2361 2362 LogicalResult LLVM::ConstantOp::verify() { 2363 if (StringAttr sAttr = getValue().dyn_cast<StringAttr>()) { 2364 auto arrayType = getType().dyn_cast<LLVMArrayType>(); 2365 if (!arrayType || arrayType.getNumElements() != sAttr.getValue().size() || 2366 !arrayType.getElementType().isInteger(8)) { 2367 return emitOpError() << "expected array type of " 2368 << sAttr.getValue().size() 2369 << " i8 elements for the string constant"; 2370 } 2371 return success(); 2372 } 2373 if (auto structType = getType().dyn_cast<LLVMStructType>()) { 2374 if (structType.getBody().size() != 2 || 2375 structType.getBody()[0] != structType.getBody()[1]) { 2376 return emitError() << "expected struct type with two elements of the " 2377 "same type, the type of a complex constant"; 2378 } 2379 2380 auto arrayAttr = getValue().dyn_cast<ArrayAttr>(); 2381 if (!arrayAttr || arrayAttr.size() != 2 || 2382 arrayAttr[0].getType() != arrayAttr[1].getType()) { 2383 return emitOpError() << "expected array attribute with two elements, " 2384 "representing a complex constant"; 2385 } 2386 2387 Type elementType = structType.getBody()[0]; 2388 if (!elementType 2389 .isa<IntegerType, Float16Type, Float32Type, Float64Type>()) { 2390 return emitError() 2391 << "expected struct element types to be floating point type or " 2392 "integer type"; 2393 } 2394 return success(); 2395 } 2396 if (!getValue().isa<IntegerAttr, ArrayAttr, FloatAttr, ElementsAttr>()) 2397 return emitOpError() 2398 << "only supports integer, float, string or elements attributes"; 2399 return success(); 2400 } 2401 2402 // Constant op constant-folds to its value. 
2403 OpFoldResult LLVM::ConstantOp::fold(ArrayRef<Attribute>) { return getValue(); } 2404 2405 //===----------------------------------------------------------------------===// 2406 // Utility functions for parsing atomic ops 2407 //===----------------------------------------------------------------------===// 2408 2409 // Helper function to parse a keyword into the specified attribute named by 2410 // `attrName`. The keyword must match one of the string values defined by the 2411 // AtomicBinOp enum. The resulting I64 attribute is added to the `result` 2412 // state. 2413 static ParseResult parseAtomicBinOp(OpAsmParser &parser, OperationState &result, 2414 StringRef attrName) { 2415 SMLoc loc; 2416 StringRef keyword; 2417 if (parser.getCurrentLocation(&loc) || parser.parseKeyword(&keyword)) 2418 return failure(); 2419 2420 // Replace the keyword `keyword` with an integer attribute. 2421 auto kind = symbolizeAtomicBinOp(keyword); 2422 if (!kind) { 2423 return parser.emitError(loc) 2424 << "'" << keyword << "' is an incorrect value of the '" << attrName 2425 << "' attribute"; 2426 } 2427 2428 auto value = static_cast<int64_t>(kind.getValue()); 2429 auto attr = parser.getBuilder().getI64IntegerAttr(value); 2430 result.addAttribute(attrName, attr); 2431 2432 return success(); 2433 } 2434 2435 // Helper function to parse a keyword into the specified attribute named by 2436 // `attrName`. The keyword must match one of the string values defined by the 2437 // AtomicOrdering enum. The resulting I64 attribute is added to the `result` 2438 // state. 2439 static ParseResult parseAtomicOrdering(OpAsmParser &parser, 2440 OperationState &result, 2441 StringRef attrName) { 2442 SMLoc loc; 2443 StringRef ordering; 2444 if (parser.getCurrentLocation(&loc) || parser.parseKeyword(&ordering)) 2445 return failure(); 2446 2447 // Replace the keyword `ordering` with an integer attribute. 
2448 auto kind = symbolizeAtomicOrdering(ordering); 2449 if (!kind) { 2450 return parser.emitError(loc) 2451 << "'" << ordering << "' is an incorrect value of the '" << attrName 2452 << "' attribute"; 2453 } 2454 2455 auto value = static_cast<int64_t>(kind.getValue()); 2456 auto attr = parser.getBuilder().getI64IntegerAttr(value); 2457 result.addAttribute(attrName, attr); 2458 2459 return success(); 2460 } 2461 2462 //===----------------------------------------------------------------------===// 2463 // Printer, parser and verifier for LLVM::AtomicRMWOp. 2464 //===----------------------------------------------------------------------===// 2465 2466 void AtomicRMWOp::print(OpAsmPrinter &p) { 2467 p << ' ' << stringifyAtomicBinOp(getBinOp()) << ' ' << getPtr() << ", " 2468 << getVal() << ' ' << stringifyAtomicOrdering(getOrdering()) << ' '; 2469 p.printOptionalAttrDict((*this)->getAttrs(), {"bin_op", "ordering"}); 2470 p << " : " << getRes().getType(); 2471 } 2472 2473 // <operation> ::= `llvm.atomicrmw` keyword ssa-use `,` ssa-use keyword 2474 // attribute-dict? 
`:` type 2475 ParseResult AtomicRMWOp::parse(OpAsmParser &parser, OperationState &result) { 2476 Type type; 2477 OpAsmParser::UnresolvedOperand ptr, val; 2478 if (parseAtomicBinOp(parser, result, "bin_op") || parser.parseOperand(ptr) || 2479 parser.parseComma() || parser.parseOperand(val) || 2480 parseAtomicOrdering(parser, result, "ordering") || 2481 parser.parseOptionalAttrDict(result.attributes) || 2482 parser.parseColonType(type) || 2483 parser.resolveOperand(ptr, LLVM::LLVMPointerType::get(type), 2484 result.operands) || 2485 parser.resolveOperand(val, type, result.operands)) 2486 return failure(); 2487 2488 result.addTypes(type); 2489 return success(); 2490 } 2491 2492 LogicalResult AtomicRMWOp::verify() { 2493 auto ptrType = getPtr().getType().cast<LLVM::LLVMPointerType>(); 2494 auto valType = getVal().getType(); 2495 if (valType != ptrType.getElementType()) 2496 return emitOpError("expected LLVM IR element type for operand #0 to " 2497 "match type for operand #1"); 2498 auto resType = getRes().getType(); 2499 if (resType != valType) 2500 return emitOpError( 2501 "expected LLVM IR result type to match type for operand #1"); 2502 if (getBinOp() == AtomicBinOp::fadd || getBinOp() == AtomicBinOp::fsub) { 2503 if (!mlir::LLVM::isCompatibleFloatingPointType(valType)) 2504 return emitOpError("expected LLVM IR floating point type"); 2505 } else if (getBinOp() == AtomicBinOp::xchg) { 2506 auto intType = valType.dyn_cast<IntegerType>(); 2507 unsigned intBitWidth = intType ? intType.getWidth() : 0; 2508 if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 && 2509 intBitWidth != 64 && !valType.isa<BFloat16Type>() && 2510 !valType.isa<Float16Type>() && !valType.isa<Float32Type>() && 2511 !valType.isa<Float64Type>()) 2512 return emitOpError("unexpected LLVM IR type for 'xchg' bin_op"); 2513 } else { 2514 auto intType = valType.dyn_cast<IntegerType>(); 2515 unsigned intBitWidth = intType ? 
intType.getWidth() : 0; 2516 if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 && 2517 intBitWidth != 64) 2518 return emitOpError("expected LLVM IR integer type"); 2519 } 2520 2521 if (static_cast<unsigned>(getOrdering()) < 2522 static_cast<unsigned>(AtomicOrdering::monotonic)) 2523 return emitOpError() << "expected at least '" 2524 << stringifyAtomicOrdering(AtomicOrdering::monotonic) 2525 << "' ordering"; 2526 2527 return success(); 2528 } 2529 2530 //===----------------------------------------------------------------------===// 2531 // Printer, parser and verifier for LLVM::AtomicCmpXchgOp. 2532 //===----------------------------------------------------------------------===// 2533 2534 void AtomicCmpXchgOp::print(OpAsmPrinter &p) { 2535 p << ' ' << getPtr() << ", " << getCmp() << ", " << getVal() << ' ' 2536 << stringifyAtomicOrdering(getSuccessOrdering()) << ' ' 2537 << stringifyAtomicOrdering(getFailureOrdering()); 2538 p.printOptionalAttrDict((*this)->getAttrs(), 2539 {"success_ordering", "failure_ordering"}); 2540 p << " : " << getVal().getType(); 2541 } 2542 2543 // <operation> ::= `llvm.cmpxchg` ssa-use `,` ssa-use `,` ssa-use 2544 // keyword keyword attribute-dict? 
`:` type 2545 ParseResult AtomicCmpXchgOp::parse(OpAsmParser &parser, 2546 OperationState &result) { 2547 auto &builder = parser.getBuilder(); 2548 Type type; 2549 OpAsmParser::UnresolvedOperand ptr, cmp, val; 2550 if (parser.parseOperand(ptr) || parser.parseComma() || 2551 parser.parseOperand(cmp) || parser.parseComma() || 2552 parser.parseOperand(val) || 2553 parseAtomicOrdering(parser, result, "success_ordering") || 2554 parseAtomicOrdering(parser, result, "failure_ordering") || 2555 parser.parseOptionalAttrDict(result.attributes) || 2556 parser.parseColonType(type) || 2557 parser.resolveOperand(ptr, LLVM::LLVMPointerType::get(type), 2558 result.operands) || 2559 parser.resolveOperand(cmp, type, result.operands) || 2560 parser.resolveOperand(val, type, result.operands)) 2561 return failure(); 2562 2563 auto boolType = IntegerType::get(builder.getContext(), 1); 2564 auto resultType = 2565 LLVMStructType::getLiteral(builder.getContext(), {type, boolType}); 2566 result.addTypes(resultType); 2567 2568 return success(); 2569 } 2570 2571 LogicalResult AtomicCmpXchgOp::verify() { 2572 auto ptrType = getPtr().getType().cast<LLVM::LLVMPointerType>(); 2573 if (!ptrType) 2574 return emitOpError("expected LLVM IR pointer type for operand #0"); 2575 auto cmpType = getCmp().getType(); 2576 auto valType = getVal().getType(); 2577 if (cmpType != ptrType.getElementType() || cmpType != valType) 2578 return emitOpError("expected LLVM IR element type for operand #0 to " 2579 "match type for all other operands"); 2580 auto intType = valType.dyn_cast<IntegerType>(); 2581 unsigned intBitWidth = intType ? 
intType.getWidth() : 0; 2582 if (!valType.isa<LLVMPointerType>() && intBitWidth != 8 && 2583 intBitWidth != 16 && intBitWidth != 32 && intBitWidth != 64 && 2584 !valType.isa<BFloat16Type>() && !valType.isa<Float16Type>() && 2585 !valType.isa<Float32Type>() && !valType.isa<Float64Type>()) 2586 return emitOpError("unexpected LLVM IR type"); 2587 if (getSuccessOrdering() < AtomicOrdering::monotonic || 2588 getFailureOrdering() < AtomicOrdering::monotonic) 2589 return emitOpError("ordering must be at least 'monotonic'"); 2590 if (getFailureOrdering() == AtomicOrdering::release || 2591 getFailureOrdering() == AtomicOrdering::acq_rel) 2592 return emitOpError("failure ordering cannot be 'release' or 'acq_rel'"); 2593 return success(); 2594 } 2595 2596 //===----------------------------------------------------------------------===// 2597 // Printer, parser and verifier for LLVM::FenceOp. 2598 //===----------------------------------------------------------------------===// 2599 2600 // <operation> ::= `llvm.fence` (`syncscope(`strAttr`)`)? keyword 2601 // attribute-dict? 
2602 ParseResult FenceOp::parse(OpAsmParser &parser, OperationState &result) { 2603 StringAttr sScope; 2604 StringRef syncscopeKeyword = "syncscope"; 2605 if (!failed(parser.parseOptionalKeyword(syncscopeKeyword))) { 2606 if (parser.parseLParen() || 2607 parser.parseAttribute(sScope, syncscopeKeyword, result.attributes) || 2608 parser.parseRParen()) 2609 return failure(); 2610 } else { 2611 result.addAttribute(syncscopeKeyword, 2612 parser.getBuilder().getStringAttr("")); 2613 } 2614 if (parseAtomicOrdering(parser, result, "ordering") || 2615 parser.parseOptionalAttrDict(result.attributes)) 2616 return failure(); 2617 return success(); 2618 } 2619 2620 void FenceOp::print(OpAsmPrinter &p) { 2621 StringRef syncscopeKeyword = "syncscope"; 2622 p << ' '; 2623 if (!(*this)->getAttr(syncscopeKeyword).cast<StringAttr>().getValue().empty()) 2624 p << "syncscope(" << (*this)->getAttr(syncscopeKeyword) << ") "; 2625 p << stringifyAtomicOrdering(getOrdering()); 2626 } 2627 2628 LogicalResult FenceOp::verify() { 2629 if (getOrdering() == AtomicOrdering::not_atomic || 2630 getOrdering() == AtomicOrdering::unordered || 2631 getOrdering() == AtomicOrdering::monotonic) 2632 return emitOpError("can be given only acquire, release, acq_rel, " 2633 "and seq_cst orderings"); 2634 return success(); 2635 } 2636 2637 //===----------------------------------------------------------------------===// 2638 // Folder for LLVM::BitcastOp 2639 //===----------------------------------------------------------------------===// 2640 2641 OpFoldResult LLVM::BitcastOp::fold(ArrayRef<Attribute> operands) { 2642 // bitcast(x : T0, T0) -> x 2643 if (getArg().getType() == getType()) 2644 return getArg(); 2645 // bitcast(bitcast(x : T0, T1), T0) -> x 2646 if (auto prev = getArg().getDefiningOp<BitcastOp>()) 2647 if (prev.getArg().getType() == getType()) 2648 return prev.getArg(); 2649 return {}; 2650 } 2651 2652 //===----------------------------------------------------------------------===// 2653 // Folder 
for LLVM::AddrSpaceCastOp 2654 //===----------------------------------------------------------------------===// 2655 2656 OpFoldResult LLVM::AddrSpaceCastOp::fold(ArrayRef<Attribute> operands) { 2657 // addrcast(x : T0, T0) -> x 2658 if (getArg().getType() == getType()) 2659 return getArg(); 2660 // addrcast(addrcast(x : T0, T1), T0) -> x 2661 if (auto prev = getArg().getDefiningOp<AddrSpaceCastOp>()) 2662 if (prev.getArg().getType() == getType()) 2663 return prev.getArg(); 2664 return {}; 2665 } 2666 2667 //===----------------------------------------------------------------------===// 2668 // Folder for LLVM::GEPOp 2669 //===----------------------------------------------------------------------===// 2670 2671 OpFoldResult LLVM::GEPOp::fold(ArrayRef<Attribute> operands) { 2672 // gep %x:T, 0 -> %x 2673 if (getBase().getType() == getType() && getIndices().size() == 1 && 2674 matchPattern(getIndices()[0], m_Zero())) 2675 return getBase(); 2676 return {}; 2677 } 2678 2679 //===----------------------------------------------------------------------===// 2680 // LLVMDialect initialization, type parsing, and registration. 2681 //===----------------------------------------------------------------------===// 2682 2683 void LLVMDialect::initialize() { 2684 addAttributes<FMFAttr, LinkageAttr, CConvAttr, LoopOptionsAttr>(); 2685 2686 // clang-format off 2687 addTypes<LLVMVoidType, 2688 LLVMPPCFP128Type, 2689 LLVMX86MMXType, 2690 LLVMTokenType, 2691 LLVMLabelType, 2692 LLVMMetadataType, 2693 LLVMFunctionType, 2694 LLVMPointerType, 2695 LLVMFixedVectorType, 2696 LLVMScalableVectorType, 2697 LLVMArrayType, 2698 LLVMStructType>(); 2699 // clang-format on 2700 addOperations< 2701 #define GET_OP_LIST 2702 #include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc" 2703 , 2704 #define GET_OP_LIST 2705 #include "mlir/Dialect/LLVMIR/LLVMIntrinsicOps.cpp.inc" 2706 >(); 2707 2708 // Support unknown operations because not all LLVM operations are registered. 
2709 allowUnknownOperations(); 2710 } 2711 2712 #define GET_OP_CLASSES 2713 #include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc" 2714 2715 /// Parse a type registered to this dialect. 2716 Type LLVMDialect::parseType(DialectAsmParser &parser) const { 2717 return detail::parseType(parser); 2718 } 2719 2720 /// Print a type registered to this dialect. 2721 void LLVMDialect::printType(Type type, DialectAsmPrinter &os) const { 2722 return detail::printType(type, os); 2723 } 2724 2725 LogicalResult LLVMDialect::verifyDataLayoutString( 2726 StringRef descr, llvm::function_ref<void(const Twine &)> reportError) { 2727 llvm::Expected<llvm::DataLayout> maybeDataLayout = 2728 llvm::DataLayout::parse(descr); 2729 if (maybeDataLayout) 2730 return success(); 2731 2732 std::string message; 2733 llvm::raw_string_ostream messageStream(message); 2734 llvm::logAllUnhandledErrors(maybeDataLayout.takeError(), messageStream); 2735 reportError("invalid data layout descriptor: " + messageStream.str()); 2736 return failure(); 2737 } 2738 2739 /// Verify LLVM dialect attributes. 2740 LogicalResult LLVMDialect::verifyOperationAttribute(Operation *op, 2741 NamedAttribute attr) { 2742 // If the `llvm.loop` attribute is present, enforce the following structure, 2743 // which the module translation can assume. 
2744 if (attr.getName() == LLVMDialect::getLoopAttrName()) { 2745 auto loopAttr = attr.getValue().dyn_cast<DictionaryAttr>(); 2746 if (!loopAttr) 2747 return op->emitOpError() << "expected '" << LLVMDialect::getLoopAttrName() 2748 << "' to be a dictionary attribute"; 2749 Optional<NamedAttribute> parallelAccessGroup = 2750 loopAttr.getNamed(LLVMDialect::getParallelAccessAttrName()); 2751 if (parallelAccessGroup.hasValue()) { 2752 auto accessGroups = parallelAccessGroup->getValue().dyn_cast<ArrayAttr>(); 2753 if (!accessGroups) 2754 return op->emitOpError() 2755 << "expected '" << LLVMDialect::getParallelAccessAttrName() 2756 << "' to be an array attribute"; 2757 for (Attribute attr : accessGroups) { 2758 auto accessGroupRef = attr.dyn_cast<SymbolRefAttr>(); 2759 if (!accessGroupRef) 2760 return op->emitOpError() 2761 << "expected '" << attr << "' to be a symbol reference"; 2762 StringAttr metadataName = accessGroupRef.getRootReference(); 2763 auto metadataOp = 2764 SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>( 2765 op->getParentOp(), metadataName); 2766 if (!metadataOp) 2767 return op->emitOpError() 2768 << "expected '" << attr << "' to reference a metadata op"; 2769 StringAttr accessGroupName = accessGroupRef.getLeafReference(); 2770 Operation *accessGroupOp = 2771 SymbolTable::lookupNearestSymbolFrom(metadataOp, accessGroupName); 2772 if (!accessGroupOp) 2773 return op->emitOpError() 2774 << "expected '" << attr << "' to reference an access_group op"; 2775 } 2776 } 2777 2778 Optional<NamedAttribute> loopOptions = 2779 loopAttr.getNamed(LLVMDialect::getLoopOptionsAttrName()); 2780 if (loopOptions.hasValue() && 2781 !loopOptions->getValue().isa<LoopOptionsAttr>()) 2782 return op->emitOpError() 2783 << "expected '" << LLVMDialect::getLoopOptionsAttrName() 2784 << "' to be a `loopopts` attribute"; 2785 } 2786 2787 if (attr.getName() == LLVMDialect::getStructAttrsAttrName()) { 2788 return op->emitOpError() 2789 << "'" << 
LLVM::LLVMDialect::getStructAttrsAttrName() 2790 << "' is permitted only in argument or result attributes"; 2791 } 2792 2793 // If the data layout attribute is present, it must use the LLVM data layout 2794 // syntax. Try parsing it and report errors in case of failure. Users of this 2795 // attribute may assume it is well-formed and can pass it to the (asserting) 2796 // llvm::DataLayout constructor. 2797 if (attr.getName() != LLVM::LLVMDialect::getDataLayoutAttrName()) 2798 return success(); 2799 if (auto stringAttr = attr.getValue().dyn_cast<StringAttr>()) 2800 return verifyDataLayoutString( 2801 stringAttr.getValue(), 2802 [op](const Twine &message) { op->emitOpError() << message.str(); }); 2803 2804 return op->emitOpError() << "expected '" 2805 << LLVM::LLVMDialect::getDataLayoutAttrName() 2806 << "' to be a string attributes"; 2807 } 2808 2809 LogicalResult LLVMDialect::verifyStructAttr(Operation *op, Attribute attr, 2810 Type annotatedType) { 2811 auto structType = annotatedType.dyn_cast<LLVMStructType>(); 2812 if (!structType) { 2813 const auto emitIncorrectAnnotatedType = [&op]() { 2814 return op->emitError() 2815 << "expected '" << LLVMDialect::getStructAttrsAttrName() 2816 << "' to annotate '!llvm.struct' or '!llvm.ptr<struct<...>>'"; 2817 }; 2818 const auto ptrType = annotatedType.dyn_cast<LLVMPointerType>(); 2819 if (!ptrType) 2820 return emitIncorrectAnnotatedType(); 2821 structType = ptrType.getElementType().dyn_cast<LLVMStructType>(); 2822 if (!structType) 2823 return emitIncorrectAnnotatedType(); 2824 } 2825 2826 const auto arrAttrs = attr.dyn_cast<ArrayAttr>(); 2827 if (!arrAttrs) 2828 return op->emitError() << "expected '" 2829 << LLVMDialect::getStructAttrsAttrName() 2830 << "' to be an array attribute"; 2831 2832 if (structType.getBody().size() != arrAttrs.size()) 2833 return op->emitError() 2834 << "size of '" << LLVMDialect::getStructAttrsAttrName() 2835 << "' must match the size of the annotated '!llvm.struct'"; 2836 return success(); 2837 } 
2838 2839 static LogicalResult verifyFuncOpInterfaceStructAttr( 2840 Operation *op, Attribute attr, 2841 const std::function<Type(FunctionOpInterface)> &getAnnotatedType) { 2842 if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) 2843 return LLVMDialect::verifyStructAttr(op, attr, getAnnotatedType(funcOp)); 2844 return op->emitError() << "expected '" 2845 << LLVMDialect::getStructAttrsAttrName() 2846 << "' to be used on function-like operations"; 2847 } 2848 2849 /// Verify LLVMIR function argument attributes. 2850 LogicalResult LLVMDialect::verifyRegionArgAttribute(Operation *op, 2851 unsigned regionIdx, 2852 unsigned argIdx, 2853 NamedAttribute argAttr) { 2854 // Check that llvm.noalias is a unit attribute. 2855 if (argAttr.getName() == LLVMDialect::getNoAliasAttrName() && 2856 !argAttr.getValue().isa<UnitAttr>()) 2857 return op->emitError() 2858 << "expected llvm.noalias argument attribute to be a unit attribute"; 2859 // Check that llvm.align is an integer attribute. 2860 if (argAttr.getName() == LLVMDialect::getAlignAttrName() && 2861 !argAttr.getValue().isa<IntegerAttr>()) 2862 return op->emitError() 2863 << "llvm.align argument attribute of non integer type"; 2864 if (argAttr.getName() == LLVMDialect::getStructAttrsAttrName()) { 2865 return verifyFuncOpInterfaceStructAttr( 2866 op, argAttr.getValue(), [argIdx](FunctionOpInterface funcOp) { 2867 return funcOp.getArgumentTypes()[argIdx]; 2868 }); 2869 } 2870 return success(); 2871 } 2872 2873 LogicalResult LLVMDialect::verifyRegionResultAttribute(Operation *op, 2874 unsigned regionIdx, 2875 unsigned resIdx, 2876 NamedAttribute resAttr) { 2877 if (resAttr.getName() == LLVMDialect::getStructAttrsAttrName()) { 2878 return verifyFuncOpInterfaceStructAttr( 2879 op, resAttr.getValue(), [resIdx](FunctionOpInterface funcOp) { 2880 return funcOp.getResultTypes()[resIdx]; 2881 }); 2882 } 2883 return success(); 2884 } 2885 2886 //===----------------------------------------------------------------------===// 2887 // 
Utility functions. 2888 //===----------------------------------------------------------------------===// 2889 2890 Value mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder, 2891 StringRef name, StringRef value, 2892 LLVM::Linkage linkage) { 2893 assert(builder.getInsertionBlock() && 2894 builder.getInsertionBlock()->getParentOp() && 2895 "expected builder to point to a block constrained in an op"); 2896 auto module = 2897 builder.getInsertionBlock()->getParentOp()->getParentOfType<ModuleOp>(); 2898 assert(module && "builder points to an op outside of a module"); 2899 2900 // Create the global at the entry of the module. 2901 OpBuilder moduleBuilder(module.getBodyRegion(), builder.getListener()); 2902 MLIRContext *ctx = builder.getContext(); 2903 auto type = LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), value.size()); 2904 auto global = moduleBuilder.create<LLVM::GlobalOp>( 2905 loc, type, /*isConstant=*/true, linkage, name, 2906 builder.getStringAttr(value), /*alignment=*/0); 2907 2908 // Get the pointer to the first character in the global string. 
2909 Value globalPtr = builder.create<LLVM::AddressOfOp>(loc, global); 2910 Value cst0 = builder.create<LLVM::ConstantOp>( 2911 loc, IntegerType::get(ctx, 64), 2912 builder.getIntegerAttr(builder.getIndexType(), 0)); 2913 return builder.create<LLVM::GEPOp>( 2914 loc, LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)), globalPtr, 2915 ValueRange{cst0, cst0}); 2916 } 2917 2918 bool mlir::LLVM::satisfiesLLVMModule(Operation *op) { 2919 return op->hasTrait<OpTrait::SymbolTable>() && 2920 op->hasTrait<OpTrait::IsIsolatedFromAbove>(); 2921 } 2922 2923 void FMFAttr::print(AsmPrinter &printer) const { 2924 printer << "<"; 2925 printer << stringifyFastmathFlags(this->getFlags()); 2926 printer << ">"; 2927 } 2928 2929 Attribute FMFAttr::parse(AsmParser &parser, Type type) { 2930 if (failed(parser.parseLess())) 2931 return {}; 2932 2933 FastmathFlags flags = {}; 2934 if (failed(parser.parseOptionalGreater())) { 2935 auto parseFlags = [&]() -> ParseResult { 2936 StringRef elemName; 2937 if (failed(parser.parseKeyword(&elemName))) 2938 return failure(); 2939 2940 auto elem = symbolizeFastmathFlags(elemName); 2941 if (!elem) 2942 return parser.emitError(parser.getNameLoc(), "Unknown fastmath flag: ") 2943 << elemName; 2944 2945 flags = flags | *elem; 2946 return success(); 2947 }; 2948 if (failed(parser.parseCommaSeparatedList(parseFlags)) || 2949 parser.parseGreater()) 2950 return {}; 2951 } 2952 2953 return FMFAttr::get(parser.getContext(), flags); 2954 } 2955 2956 void LinkageAttr::print(AsmPrinter &printer) const { 2957 printer << "<"; 2958 if (static_cast<uint64_t>(getLinkage()) <= getMaxEnumValForLinkage()) 2959 printer << stringifyEnum(getLinkage()); 2960 else 2961 printer << static_cast<uint64_t>(getLinkage()); 2962 printer << ">"; 2963 } 2964 2965 Attribute LinkageAttr::parse(AsmParser &parser, Type type) { 2966 StringRef elemName; 2967 if (parser.parseLess() || parser.parseKeyword(&elemName) || 2968 parser.parseGreater()) 2969 return {}; 2970 auto elem = 
linkage::symbolizeLinkage(elemName); 2971 if (!elem) { 2972 parser.emitError(parser.getNameLoc(), "Unknown linkage: ") << elemName; 2973 return {}; 2974 } 2975 Linkage linkage = *elem; 2976 return LinkageAttr::get(parser.getContext(), linkage); 2977 } 2978 2979 void CConvAttr::print(AsmPrinter &printer) const { 2980 printer << "<"; 2981 if (static_cast<uint64_t>(getCallingConv()) <= cconv::getMaxEnumValForCConv()) 2982 printer << stringifyEnum(getCallingConv()); 2983 else 2984 printer << "INVALID_cc_" << static_cast<uint64_t>(getCallingConv()); 2985 printer << ">"; 2986 } 2987 2988 Attribute CConvAttr::parse(AsmParser &parser, Type type) { 2989 StringRef convName; 2990 2991 if (parser.parseLess() || parser.parseKeyword(&convName) || 2992 parser.parseGreater()) 2993 return {}; 2994 auto cconv = cconv::symbolizeCConv(convName); 2995 if (!cconv) { 2996 parser.emitError(parser.getNameLoc(), "unknown calling convention: ") 2997 << convName; 2998 return {}; 2999 } 3000 CConv cconvVal = *cconv; 3001 return CConvAttr::get(parser.getContext(), cconvVal); 3002 } 3003 3004 LoopOptionsAttrBuilder::LoopOptionsAttrBuilder(LoopOptionsAttr attr) 3005 : options(attr.getOptions().begin(), attr.getOptions().end()) {} 3006 3007 template <typename T> 3008 LoopOptionsAttrBuilder &LoopOptionsAttrBuilder::setOption(LoopOptionCase tag, 3009 Optional<T> value) { 3010 auto option = llvm::find_if( 3011 options, [tag](auto option) { return option.first == tag; }); 3012 if (option != options.end()) { 3013 if (value.hasValue()) 3014 option->second = *value; 3015 else 3016 options.erase(option); 3017 } else { 3018 options.push_back(LoopOptionsAttr::OptionValuePair(tag, *value)); 3019 } 3020 return *this; 3021 } 3022 3023 LoopOptionsAttrBuilder & 3024 LoopOptionsAttrBuilder::setDisableLICM(Optional<bool> value) { 3025 return setOption(LoopOptionCase::disable_licm, value); 3026 } 3027 3028 /// Set the `interleave_count` option to the provided value. 
If no value 3029 /// is provided the option is deleted. 3030 LoopOptionsAttrBuilder & 3031 LoopOptionsAttrBuilder::setInterleaveCount(Optional<uint64_t> count) { 3032 return setOption(LoopOptionCase::interleave_count, count); 3033 } 3034 3035 /// Set the `disable_unroll` option to the provided value. If no value 3036 /// is provided the option is deleted. 3037 LoopOptionsAttrBuilder & 3038 LoopOptionsAttrBuilder::setDisableUnroll(Optional<bool> value) { 3039 return setOption(LoopOptionCase::disable_unroll, value); 3040 } 3041 3042 /// Set the `disable_pipeline` option to the provided value. If no value 3043 /// is provided the option is deleted. 3044 LoopOptionsAttrBuilder & 3045 LoopOptionsAttrBuilder::setDisablePipeline(Optional<bool> value) { 3046 return setOption(LoopOptionCase::disable_pipeline, value); 3047 } 3048 3049 /// Set the `pipeline_initiation_interval` option to the provided value. 3050 /// If no value is provided the option is deleted. 3051 LoopOptionsAttrBuilder &LoopOptionsAttrBuilder::setPipelineInitiationInterval( 3052 Optional<uint64_t> count) { 3053 return setOption(LoopOptionCase::pipeline_initiation_interval, count); 3054 } 3055 3056 template <typename T> 3057 static Optional<T> 3058 getOption(ArrayRef<std::pair<LoopOptionCase, int64_t>> options, 3059 LoopOptionCase option) { 3060 auto it = 3061 lower_bound(options, option, [](auto optionPair, LoopOptionCase option) { 3062 return optionPair.first < option; 3063 }); 3064 if (it == options.end()) 3065 return {}; 3066 return static_cast<T>(it->second); 3067 } 3068 3069 Optional<bool> LoopOptionsAttr::disableUnroll() { 3070 return getOption<bool>(getOptions(), LoopOptionCase::disable_unroll); 3071 } 3072 3073 Optional<bool> LoopOptionsAttr::disableLICM() { 3074 return getOption<bool>(getOptions(), LoopOptionCase::disable_licm); 3075 } 3076 3077 Optional<int64_t> LoopOptionsAttr::interleaveCount() { 3078 return getOption<int64_t>(getOptions(), LoopOptionCase::interleave_count); 3079 } 3080 3081 
/// Build the LoopOptions Attribute from a sorted array of individual options. 3082 LoopOptionsAttr LoopOptionsAttr::get( 3083 MLIRContext *context, 3084 ArrayRef<std::pair<LoopOptionCase, int64_t>> sortedOptions) { 3085 assert(llvm::is_sorted(sortedOptions, llvm::less_first()) && 3086 "LoopOptionsAttr ctor expects a sorted options array"); 3087 return Base::get(context, sortedOptions); 3088 } 3089 3090 /// Build the LoopOptions Attribute from a sorted array of individual options. 3091 LoopOptionsAttr LoopOptionsAttr::get(MLIRContext *context, 3092 LoopOptionsAttrBuilder &optionBuilders) { 3093 llvm::sort(optionBuilders.options, llvm::less_first()); 3094 return Base::get(context, optionBuilders.options); 3095 } 3096 3097 void LoopOptionsAttr::print(AsmPrinter &printer) const { 3098 printer << "<"; 3099 llvm::interleaveComma(getOptions(), printer, [&](auto option) { 3100 printer << stringifyEnum(option.first) << " = "; 3101 switch (option.first) { 3102 case LoopOptionCase::disable_licm: 3103 case LoopOptionCase::disable_unroll: 3104 case LoopOptionCase::disable_pipeline: 3105 printer << (option.second ? 
"true" : "false"); 3106 break; 3107 case LoopOptionCase::interleave_count: 3108 case LoopOptionCase::pipeline_initiation_interval: 3109 printer << option.second; 3110 break; 3111 } 3112 }); 3113 printer << ">"; 3114 } 3115 3116 Attribute LoopOptionsAttr::parse(AsmParser &parser, Type type) { 3117 if (failed(parser.parseLess())) 3118 return {}; 3119 3120 SmallVector<std::pair<LoopOptionCase, int64_t>> options; 3121 llvm::SmallDenseSet<LoopOptionCase> seenOptions; 3122 auto parseLoopOptions = [&]() -> ParseResult { 3123 StringRef optionName; 3124 if (parser.parseKeyword(&optionName)) 3125 return failure(); 3126 3127 auto option = symbolizeLoopOptionCase(optionName); 3128 if (!option) 3129 return parser.emitError(parser.getNameLoc(), "unknown loop option: ") 3130 << optionName; 3131 if (!seenOptions.insert(*option).second) 3132 return parser.emitError(parser.getNameLoc(), "loop option present twice"); 3133 if (failed(parser.parseEqual())) 3134 return failure(); 3135 3136 int64_t value; 3137 switch (*option) { 3138 case LoopOptionCase::disable_licm: 3139 case LoopOptionCase::disable_unroll: 3140 case LoopOptionCase::disable_pipeline: 3141 if (succeeded(parser.parseOptionalKeyword("true"))) 3142 value = 1; 3143 else if (succeeded(parser.parseOptionalKeyword("false"))) 3144 value = 0; 3145 else { 3146 return parser.emitError(parser.getNameLoc(), 3147 "expected boolean value 'true' or 'false'"); 3148 } 3149 break; 3150 case LoopOptionCase::interleave_count: 3151 case LoopOptionCase::pipeline_initiation_interval: 3152 if (failed(parser.parseInteger(value))) 3153 return parser.emitError(parser.getNameLoc(), "expected integer value"); 3154 break; 3155 } 3156 options.push_back(std::make_pair(*option, value)); 3157 return success(); 3158 }; 3159 if (parser.parseCommaSeparatedList(parseLoopOptions) || parser.parseGreater()) 3160 return {}; 3161 3162 llvm::sort(options, llvm::less_first()); 3163 return get(parser.getContext(), options); 3164 } 3165