//===- LLVMDialect.cpp - LLVM IR Ops and Dialect registration -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the types and operation details for the LLVM IR dialect in
// MLIR, and the LLVM IR dialect. It also registers the dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "TypeDetail.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"

#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/Bitcode/BitcodeReader.h"
#include "llvm/Bitcode/BitcodeWriter.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/Mutex.h"
#include "llvm/Support/SourceMgr.h"

#include <numeric>

using namespace mlir;
using namespace mlir::LLVM;
using mlir::LLVM::cconv::getMaxEnumValForCConv;
using mlir::LLVM::linkage::getMaxEnumValForLinkage;

#include "mlir/Dialect/LLVMIR/LLVMOpsDialect.cpp.inc"

// Attribute names that the hand-written builders, parsers and printers in this
// file set or elide by name.
// Unit attribute marking a volatile load/store.
static constexpr const char kVolatileAttrName[] = "volatile_";
// Unit attribute marking a non-temporal load/store.
static constexpr const char kNonTemporalAttrName[] = "nontemporal";
// Type attribute carrying the pointee/element type when an opaque pointer is
// used (see verifyOpaquePtr).
static constexpr const char kElemTypeAttrName[] = "elem_type";

#include "mlir/Dialect/LLVMIR/LLVMOpsEnums.cpp.inc"
#include "mlir/Dialect/LLVMIR/LLVMOpsInterfaces.cpp.inc"
#define GET_ATTRDEF_CLASSES 52 #include "mlir/Dialect/LLVMIR/LLVMOpsAttrDefs.cpp.inc" 53 54 static auto processFMFAttr(ArrayRef<NamedAttribute> attrs) { 55 SmallVector<NamedAttribute, 8> filteredAttrs( 56 llvm::make_filter_range(attrs, [&](NamedAttribute attr) { 57 if (attr.getName() == "fastmathFlags") { 58 auto defAttr = FMFAttr::get(attr.getValue().getContext(), {}); 59 return defAttr != attr.getValue(); 60 } 61 return true; 62 })); 63 return filteredAttrs; 64 } 65 66 static ParseResult parseLLVMOpAttrs(OpAsmParser &parser, 67 NamedAttrList &result) { 68 return parser.parseOptionalAttrDict(result); 69 } 70 71 static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op, 72 DictionaryAttr attrs) { 73 printer.printOptionalAttrDict(processFMFAttr(attrs.getValue())); 74 } 75 76 /// Verifies `symbol`'s use in `op` to ensure the symbol is a valid and 77 /// fully defined llvm.func. 78 static LogicalResult verifySymbolAttrUse(FlatSymbolRefAttr symbol, 79 Operation *op, 80 SymbolTableCollection &symbolTable) { 81 StringRef name = symbol.getValue(); 82 auto func = 83 symbolTable.lookupNearestSymbolFrom<LLVMFuncOp>(op, symbol.getAttr()); 84 if (!func) 85 return op->emitOpError("'") 86 << name << "' does not reference a valid LLVM function"; 87 if (func.isExternal()) 88 return op->emitOpError("'") << name << "' does not have a definition"; 89 return success(); 90 } 91 92 //===----------------------------------------------------------------------===// 93 // Printing, parsing and builder for LLVM::CmpOp. 
94 //===----------------------------------------------------------------------===// 95 96 void ICmpOp::build(OpBuilder &builder, OperationState &result, 97 ICmpPredicate predicate, Value lhs, Value rhs) { 98 auto boolType = IntegerType::get(lhs.getType().getContext(), 1); 99 if (LLVM::isCompatibleVectorType(lhs.getType()) || 100 LLVM::isCompatibleVectorType(rhs.getType())) { 101 int64_t numLHSElements = 1, numRHSElements = 1; 102 if (LLVM::isCompatibleVectorType(lhs.getType())) 103 numLHSElements = 104 LLVM::getVectorNumElements(lhs.getType()).getFixedValue(); 105 if (LLVM::isCompatibleVectorType(rhs.getType())) 106 numRHSElements = 107 LLVM::getVectorNumElements(rhs.getType()).getFixedValue(); 108 build(builder, result, 109 VectorType::get({std::max(numLHSElements, numRHSElements)}, boolType), 110 predicate, lhs, rhs); 111 } else { 112 build(builder, result, boolType, predicate, lhs, rhs); 113 } 114 } 115 116 void ICmpOp::print(OpAsmPrinter &p) { 117 p << " \"" << stringifyICmpPredicate(getPredicate()) << "\" " << getOperand(0) 118 << ", " << getOperand(1); 119 p.printOptionalAttrDict((*this)->getAttrs(), {"predicate"}); 120 p << " : " << getLhs().getType(); 121 } 122 123 void FCmpOp::print(OpAsmPrinter &p) { 124 p << " \"" << stringifyFCmpPredicate(getPredicate()) << "\" " << getOperand(0) 125 << ", " << getOperand(1); 126 p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()), {"predicate"}); 127 p << " : " << getLhs().getType(); 128 } 129 130 // <operation> ::= `llvm.icmp` string-literal ssa-use `,` ssa-use 131 // attribute-dict? `:` type 132 // <operation> ::= `llvm.fcmp` string-literal ssa-use `,` ssa-use 133 // attribute-dict? 
`:` type 134 template <typename CmpPredicateType> 135 static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) { 136 Builder &builder = parser.getBuilder(); 137 138 StringAttr predicateAttr; 139 OpAsmParser::UnresolvedOperand lhs, rhs; 140 Type type; 141 SMLoc predicateLoc, trailingTypeLoc; 142 if (parser.getCurrentLocation(&predicateLoc) || 143 parser.parseAttribute(predicateAttr, "predicate", result.attributes) || 144 parser.parseOperand(lhs) || parser.parseComma() || 145 parser.parseOperand(rhs) || 146 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 147 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) || 148 parser.resolveOperand(lhs, type, result.operands) || 149 parser.resolveOperand(rhs, type, result.operands)) 150 return failure(); 151 152 // Replace the string attribute `predicate` with an integer attribute. 153 int64_t predicateValue = 0; 154 if (std::is_same<CmpPredicateType, ICmpPredicate>()) { 155 Optional<ICmpPredicate> predicate = 156 symbolizeICmpPredicate(predicateAttr.getValue()); 157 if (!predicate) 158 return parser.emitError(predicateLoc) 159 << "'" << predicateAttr.getValue() 160 << "' is an incorrect value of the 'predicate' attribute"; 161 predicateValue = static_cast<int64_t>(*predicate); 162 } else { 163 Optional<FCmpPredicate> predicate = 164 symbolizeFCmpPredicate(predicateAttr.getValue()); 165 if (!predicate) 166 return parser.emitError(predicateLoc) 167 << "'" << predicateAttr.getValue() 168 << "' is an incorrect value of the 'predicate' attribute"; 169 predicateValue = static_cast<int64_t>(*predicate); 170 } 171 172 result.attributes.set("predicate", 173 parser.getBuilder().getI64IntegerAttr(predicateValue)); 174 175 // The result type is either i1 or a vector type <? x i1> if the inputs are 176 // vectors. 
177 Type resultType = IntegerType::get(builder.getContext(), 1); 178 if (!isCompatibleType(type)) 179 return parser.emitError(trailingTypeLoc, 180 "expected LLVM dialect-compatible type"); 181 if (LLVM::isCompatibleVectorType(type)) { 182 if (LLVM::isScalableVectorType(type)) { 183 resultType = LLVM::getVectorType( 184 resultType, LLVM::getVectorNumElements(type).getKnownMinValue(), 185 /*isScalable=*/true); 186 } else { 187 resultType = LLVM::getVectorType( 188 resultType, LLVM::getVectorNumElements(type).getFixedValue(), 189 /*isScalable=*/false); 190 } 191 } 192 193 result.addTypes({resultType}); 194 return success(); 195 } 196 197 ParseResult ICmpOp::parse(OpAsmParser &parser, OperationState &result) { 198 return parseCmpOp<ICmpPredicate>(parser, result); 199 } 200 201 ParseResult FCmpOp::parse(OpAsmParser &parser, OperationState &result) { 202 return parseCmpOp<FCmpPredicate>(parser, result); 203 } 204 205 //===----------------------------------------------------------------------===// 206 // Printing, parsing and verification for LLVM::AllocaOp. 207 //===----------------------------------------------------------------------===// 208 209 void AllocaOp::print(OpAsmPrinter &p) { 210 Type elemTy = getType().cast<LLVM::LLVMPointerType>().getElementType(); 211 if (!elemTy) 212 elemTy = *getElemType(); 213 214 auto funcTy = 215 FunctionType::get(getContext(), {getArraySize().getType()}, {getType()}); 216 217 p << ' ' << getArraySize() << " x " << elemTy; 218 if (getAlignment() && *getAlignment() != 0) 219 p.printOptionalAttrDict((*this)->getAttrs(), {kElemTypeAttrName}); 220 else 221 p.printOptionalAttrDict((*this)->getAttrs(), 222 {"alignment", kElemTypeAttrName}); 223 p << " : " << funcTy; 224 } 225 226 // <operation> ::= `llvm.alloca` ssa-use `x` type attribute-dict? 
227 // `:` type `,` type 228 ParseResult AllocaOp::parse(OpAsmParser &parser, OperationState &result) { 229 OpAsmParser::UnresolvedOperand arraySize; 230 Type type, elemType; 231 SMLoc trailingTypeLoc; 232 if (parser.parseOperand(arraySize) || parser.parseKeyword("x") || 233 parser.parseType(elemType) || 234 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 235 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type)) 236 return failure(); 237 238 Optional<NamedAttribute> alignmentAttr = 239 result.attributes.getNamed("alignment"); 240 if (alignmentAttr.has_value()) { 241 auto alignmentInt = 242 alignmentAttr.value().getValue().dyn_cast<IntegerAttr>(); 243 if (!alignmentInt) 244 return parser.emitError(parser.getNameLoc(), 245 "expected integer alignment"); 246 if (alignmentInt.getValue().isNullValue()) 247 result.attributes.erase("alignment"); 248 } 249 250 // Extract the result type from the trailing function type. 251 auto funcType = type.dyn_cast<FunctionType>(); 252 if (!funcType || funcType.getNumInputs() != 1 || 253 funcType.getNumResults() != 1) 254 return parser.emitError( 255 trailingTypeLoc, 256 "expected trailing function type with one argument and one result"); 257 258 if (parser.resolveOperand(arraySize, funcType.getInput(0), result.operands)) 259 return failure(); 260 261 Type resultType = funcType.getResult(0); 262 if (auto ptrResultType = resultType.dyn_cast<LLVMPointerType>()) { 263 if (ptrResultType.isOpaque()) 264 result.addAttribute(kElemTypeAttrName, TypeAttr::get(elemType)); 265 } 266 267 result.addTypes({funcType.getResult(0)}); 268 return success(); 269 } 270 271 /// Checks that the elemental type is present in either the pointer type or 272 /// the attribute, but not both. 
273 static LogicalResult verifyOpaquePtr(Operation *op, LLVMPointerType ptrType, 274 Optional<Type> ptrElementType) { 275 if (ptrType.isOpaque() && !ptrElementType.has_value()) { 276 return op->emitOpError() << "expected '" << kElemTypeAttrName 277 << "' attribute if opaque pointer type is used"; 278 } 279 if (!ptrType.isOpaque() && ptrElementType.has_value()) { 280 return op->emitOpError() 281 << "unexpected '" << kElemTypeAttrName 282 << "' attribute when non-opaque pointer type is used"; 283 } 284 return success(); 285 } 286 287 LogicalResult AllocaOp::verify() { 288 return verifyOpaquePtr(getOperation(), getType().cast<LLVMPointerType>(), 289 getElemType()); 290 } 291 292 //===----------------------------------------------------------------------===// 293 // LLVM::BrOp 294 //===----------------------------------------------------------------------===// 295 296 SuccessorOperands BrOp::getSuccessorOperands(unsigned index) { 297 assert(index == 0 && "invalid successor index"); 298 return SuccessorOperands(getDestOperandsMutable()); 299 } 300 301 //===----------------------------------------------------------------------===// 302 // LLVM::CondBrOp 303 //===----------------------------------------------------------------------===// 304 305 SuccessorOperands CondBrOp::getSuccessorOperands(unsigned index) { 306 assert(index < getNumSuccessors() && "invalid successor index"); 307 return SuccessorOperands(index == 0 ? 
getTrueDestOperandsMutable() 308 : getFalseDestOperandsMutable()); 309 } 310 311 //===----------------------------------------------------------------------===// 312 // LLVM::SwitchOp 313 //===----------------------------------------------------------------------===// 314 315 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value, 316 Block *defaultDestination, ValueRange defaultOperands, 317 ArrayRef<int32_t> caseValues, BlockRange caseDestinations, 318 ArrayRef<ValueRange> caseOperands, 319 ArrayRef<int32_t> branchWeights) { 320 ElementsAttr caseValuesAttr; 321 if (!caseValues.empty()) 322 caseValuesAttr = builder.getI32VectorAttr(caseValues); 323 324 ElementsAttr weightsAttr; 325 if (!branchWeights.empty()) 326 weightsAttr = builder.getI32VectorAttr(llvm::to_vector<4>(branchWeights)); 327 328 build(builder, result, value, defaultOperands, caseOperands, caseValuesAttr, 329 weightsAttr, defaultDestination, caseDestinations); 330 } 331 332 /// <cases> ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)? 333 /// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )? 
334 static ParseResult parseSwitchOpCases( 335 OpAsmParser &parser, Type flagType, ElementsAttr &caseValues, 336 SmallVectorImpl<Block *> &caseDestinations, 337 SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> &caseOperands, 338 SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) { 339 SmallVector<APInt> values; 340 unsigned bitWidth = flagType.getIntOrFloatBitWidth(); 341 do { 342 int64_t value = 0; 343 OptionalParseResult integerParseResult = parser.parseOptionalInteger(value); 344 if (values.empty() && !integerParseResult.hasValue()) 345 return success(); 346 347 if (!integerParseResult.hasValue() || integerParseResult.getValue()) 348 return failure(); 349 values.push_back(APInt(bitWidth, value)); 350 351 Block *destination; 352 SmallVector<OpAsmParser::UnresolvedOperand> operands; 353 SmallVector<Type> operandTypes; 354 if (parser.parseColon() || parser.parseSuccessor(destination)) 355 return failure(); 356 if (!parser.parseOptionalLParen()) { 357 if (parser.parseOperandList(operands, OpAsmParser::Delimiter::None, 358 /*allowResultNumber=*/false) || 359 parser.parseColonTypeList(operandTypes) || parser.parseRParen()) 360 return failure(); 361 } 362 caseDestinations.push_back(destination); 363 caseOperands.emplace_back(operands); 364 caseOperandTypes.emplace_back(operandTypes); 365 } while (!parser.parseOptionalComma()); 366 367 ShapedType caseValueType = 368 VectorType::get(static_cast<int64_t>(values.size()), flagType); 369 caseValues = DenseIntElementsAttr::get(caseValueType, values); 370 return success(); 371 } 372 373 static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type flagType, 374 ElementsAttr caseValues, 375 SuccessorRange caseDestinations, 376 OperandRangeRange caseOperands, 377 const TypeRangeRange &caseOperandTypes) { 378 if (!caseValues) 379 return; 380 381 size_t index = 0; 382 llvm::interleave( 383 llvm::zip(caseValues.cast<DenseIntElementsAttr>(), caseDestinations), 384 [&](auto i) { 385 p << " "; 386 p << 
std::get<0>(i).getLimitedValue(); 387 p << ": "; 388 p.printSuccessorAndUseList(std::get<1>(i), caseOperands[index++]); 389 }, 390 [&] { 391 p << ','; 392 p.printNewline(); 393 }); 394 p.printNewline(); 395 } 396 397 LogicalResult SwitchOp::verify() { 398 if ((!getCaseValues() && !getCaseDestinations().empty()) || 399 (getCaseValues() && 400 getCaseValues()->size() != 401 static_cast<int64_t>(getCaseDestinations().size()))) 402 return emitOpError("expects number of case values to match number of " 403 "case destinations"); 404 if (getBranchWeights() && getBranchWeights()->size() != getNumSuccessors()) 405 return emitError("expects number of branch weights to match number of " 406 "successors: ") 407 << getBranchWeights()->size() << " vs " << getNumSuccessors(); 408 return success(); 409 } 410 411 SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) { 412 assert(index < getNumSuccessors() && "invalid successor index"); 413 return SuccessorOperands(index == 0 ? getDefaultOperandsMutable() 414 : getCaseOperandsMutable(index - 1)); 415 } 416 417 //===----------------------------------------------------------------------===// 418 // Code for LLVM::GEPOp. 419 //===----------------------------------------------------------------------===// 420 421 constexpr int GEPOp::kDynamicIndex; 422 423 namespace { 424 /// Base class for llvm::Error related to GEP index. 425 class GEPIndexError : public llvm::ErrorInfo<GEPIndexError> { 426 protected: 427 unsigned indexPos; 428 429 public: 430 static char ID; 431 432 std::error_code convertToErrorCode() const override { 433 return llvm::inconvertibleErrorCode(); 434 } 435 436 explicit GEPIndexError(unsigned pos) : indexPos(pos) {} 437 }; 438 439 /// llvm::Error for out-of-bound GEP index. 
440 struct GEPIndexOutOfBoundError 441 : public llvm::ErrorInfo<GEPIndexOutOfBoundError, GEPIndexError> { 442 static char ID; 443 444 using ErrorInfo::ErrorInfo; 445 446 void log(llvm::raw_ostream &os) const override { 447 os << "index " << indexPos << " indexing a struct is out of bounds"; 448 } 449 }; 450 451 /// llvm::Error for non-static GEP index indexing a struct. 452 struct GEPStaticIndexError 453 : public llvm::ErrorInfo<GEPStaticIndexError, GEPIndexError> { 454 static char ID; 455 456 using ErrorInfo::ErrorInfo; 457 458 void log(llvm::raw_ostream &os) const override { 459 os << "expected index " << indexPos << " indexing a struct " 460 << "to be constant"; 461 } 462 }; 463 } // end anonymous namespace 464 465 char GEPIndexError::ID = 0; 466 char GEPIndexOutOfBoundError::ID = 0; 467 char GEPStaticIndexError::ID = 0; 468 469 /// For the given `structIndices` and `indices`, check if they're complied 470 /// with `baseGEPType`, especially check against LLVMStructTypes nested within, 471 /// and refine/promote struct index from `indices` to `updatedStructIndices` 472 /// if the latter argument is not null. 473 static llvm::Error 474 recordStructIndices(Type baseGEPType, unsigned indexPos, 475 ArrayRef<int32_t> structIndices, ValueRange indices, 476 SmallVectorImpl<int32_t> *updatedStructIndices, 477 SmallVectorImpl<Value> *remainingIndices) { 478 if (indexPos >= structIndices.size()) 479 // Stop searching 480 return llvm::Error::success(); 481 482 int32_t gepIndex = structIndices[indexPos]; 483 bool isStaticIndex = gepIndex != GEPOp::kDynamicIndex; 484 485 unsigned dynamicIndexPos = indexPos; 486 if (!isStaticIndex) 487 dynamicIndexPos = llvm::count(structIndices.take_front(indexPos + 1), 488 LLVM::GEPOp::kDynamicIndex) - 489 1; 490 491 return llvm::TypeSwitch<Type, llvm::Error>(baseGEPType) 492 .Case<LLVMStructType>([&](LLVMStructType structType) -> llvm::Error { 493 // We don't always want to refine the index (e.g. 
when performing 494 // verification), so we only refine when updatedStructIndices is not 495 // null. 496 if (!isStaticIndex && updatedStructIndices) { 497 // Try to refine. 498 APInt staticIndexValue; 499 isStaticIndex = matchPattern(indices[dynamicIndexPos], 500 m_ConstantInt(&staticIndexValue)); 501 if (isStaticIndex) { 502 assert(staticIndexValue.getBitWidth() <= 64 && 503 llvm::isInt<32>(staticIndexValue.getLimitedValue()) && 504 "struct index can't fit within int32_t"); 505 gepIndex = static_cast<int32_t>(staticIndexValue.getSExtValue()); 506 } 507 } 508 if (!isStaticIndex) 509 return llvm::make_error<GEPStaticIndexError>(indexPos); 510 511 ArrayRef<Type> elementTypes = structType.getBody(); 512 if (gepIndex < 0 || 513 static_cast<size_t>(gepIndex) >= elementTypes.size()) 514 return llvm::make_error<GEPIndexOutOfBoundError>(indexPos); 515 516 if (updatedStructIndices) 517 (*updatedStructIndices)[indexPos] = gepIndex; 518 519 // Instead of recusively going into every children types, we only 520 // dive into the one indexed by gepIndex. 521 return recordStructIndices(elementTypes[gepIndex], indexPos + 1, 522 structIndices, indices, updatedStructIndices, 523 remainingIndices); 524 }) 525 .Case<VectorType, LLVMScalableVectorType, LLVMFixedVectorType, 526 LLVMArrayType>([&](auto containerType) -> llvm::Error { 527 // Currently we don't refine non-struct index even if it's static. 528 if (remainingIndices) 529 remainingIndices->push_back(indices[dynamicIndexPos]); 530 return recordStructIndices(containerType.getElementType(), indexPos + 1, 531 structIndices, indices, updatedStructIndices, 532 remainingIndices); 533 }) 534 .Default( 535 [](auto otherType) -> llvm::Error { return llvm::Error::success(); }); 536 } 537 538 /// Driver function around `recordStructIndices`. Note that we always check 539 /// from the second GEP index since the first one is always dynamic. 
540 static llvm::Error 541 findStructIndices(Type baseGEPType, ArrayRef<int32_t> structIndices, 542 ValueRange indices, 543 SmallVectorImpl<int32_t> *updatedStructIndices = nullptr, 544 SmallVectorImpl<Value> *remainingIndices = nullptr) { 545 if (remainingIndices) 546 // The first GEP index is always dynamic. 547 remainingIndices->push_back(indices[0]); 548 return recordStructIndices(baseGEPType, /*indexPos=*/1, structIndices, 549 indices, updatedStructIndices, remainingIndices); 550 } 551 552 void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType, 553 Value basePtr, ValueRange operands, 554 ArrayRef<NamedAttribute> attributes) { 555 build(builder, result, resultType, basePtr, operands, 556 SmallVector<int32_t>(operands.size(), kDynamicIndex), attributes); 557 } 558 559 /// Returns the elemental type of any LLVM-compatible vector type or self. 560 static Type extractVectorElementType(Type type) { 561 if (auto vectorType = type.dyn_cast<VectorType>()) 562 return vectorType.getElementType(); 563 if (auto scalableVectorType = type.dyn_cast<LLVMScalableVectorType>()) 564 return scalableVectorType.getElementType(); 565 if (auto fixedVectorType = type.dyn_cast<LLVMFixedVectorType>()) 566 return fixedVectorType.getElementType(); 567 return type; 568 } 569 570 void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType, 571 Type elementType, Value basePtr, ValueRange indices, 572 ArrayRef<NamedAttribute> attributes) { 573 build(builder, result, resultType, elementType, basePtr, indices, 574 SmallVector<int32_t>(indices.size(), kDynamicIndex), attributes); 575 } 576 577 void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType, 578 Value basePtr, ValueRange indices, 579 ArrayRef<int32_t> structIndices, 580 ArrayRef<NamedAttribute> attributes) { 581 auto ptrType = 582 extractVectorElementType(basePtr.getType()).cast<LLVMPointerType>(); 583 assert(!ptrType.isOpaque() && 584 "expected non-opaque pointer, provide 
elementType explicitly when " 585 "opaque pointers are used"); 586 build(builder, result, resultType, ptrType.getElementType(), basePtr, indices, 587 structIndices, attributes); 588 } 589 590 void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType, 591 Type elementType, Value basePtr, ValueRange indices, 592 ArrayRef<int32_t> structIndices, 593 ArrayRef<NamedAttribute> attributes) { 594 SmallVector<Value> remainingIndices; 595 SmallVector<int32_t> updatedStructIndices(structIndices.begin(), 596 structIndices.end()); 597 if (llvm::Error err = 598 findStructIndices(elementType, structIndices, indices, 599 &updatedStructIndices, &remainingIndices)) 600 llvm::report_fatal_error(StringRef(llvm::toString(std::move(err)))); 601 602 assert(remainingIndices.size() == static_cast<size_t>(llvm::count( 603 updatedStructIndices, kDynamicIndex)) && 604 "expected as many index operands as dynamic index attr elements"); 605 606 result.addTypes(resultType); 607 result.addAttributes(attributes); 608 result.addAttribute("structIndices", 609 builder.getI32TensorAttr(updatedStructIndices)); 610 if (extractVectorElementType(basePtr.getType()) 611 .cast<LLVMPointerType>() 612 .isOpaque()) 613 result.addAttribute(kElemTypeAttrName, TypeAttr::get(elementType)); 614 result.addOperands(basePtr); 615 result.addOperands(remainingIndices); 616 } 617 618 static ParseResult 619 parseGEPIndices(OpAsmParser &parser, 620 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &indices, 621 DenseIntElementsAttr &structIndices) { 622 SmallVector<int32_t> constantIndices; 623 624 auto idxParser = [&]() -> ParseResult { 625 int32_t constantIndex; 626 OptionalParseResult parsedInteger = 627 parser.parseOptionalInteger(constantIndex); 628 if (parsedInteger.hasValue()) { 629 if (failed(parsedInteger.getValue())) 630 return failure(); 631 constantIndices.push_back(constantIndex); 632 return success(); 633 } 634 635 constantIndices.push_back(LLVM::GEPOp::kDynamicIndex); 636 return 
parser.parseOperand(indices.emplace_back()); 637 }; 638 if (parser.parseCommaSeparatedList(idxParser)) 639 return failure(); 640 641 structIndices = parser.getBuilder().getI32TensorAttr(constantIndices); 642 return success(); 643 } 644 645 static void printGEPIndices(OpAsmPrinter &printer, LLVM::GEPOp gepOp, 646 OperandRange indices, 647 DenseIntElementsAttr structIndices) { 648 unsigned operandIdx = 0; 649 llvm::interleaveComma(structIndices.getValues<int32_t>(), printer, 650 [&](int32_t cst) { 651 if (cst == LLVM::GEPOp::kDynamicIndex) 652 printer.printOperand(indices[operandIdx++]); 653 else 654 printer << cst; 655 }); 656 } 657 658 LogicalResult LLVM::GEPOp::verify() { 659 if (failed(verifyOpaquePtr( 660 getOperation(), 661 extractVectorElementType(getType()).cast<LLVMPointerType>(), 662 getElemType()))) 663 return failure(); 664 665 auto structIndexRange = getStructIndices().getValues<int32_t>(); 666 // structIndexRange is a kind of iterator, which cannot be converted 667 // to ArrayRef directly. 668 SmallVector<int32_t> structIndices(structIndexRange.size()); 669 for (unsigned i : llvm::seq<unsigned>(0, structIndexRange.size())) 670 structIndices[i] = structIndexRange[i]; 671 if (llvm::Error err = findStructIndices(getSourceElementType(), structIndices, 672 getIndices())) 673 return emitOpError() << llvm::toString(std::move(err)); 674 675 return success(); 676 } 677 678 Type LLVM::GEPOp::getSourceElementType() { 679 if (Optional<Type> elemType = getElemType()) 680 return *elemType; 681 682 return extractVectorElementType(getBase().getType()) 683 .cast<LLVMPointerType>() 684 .getElementType(); 685 } 686 687 //===----------------------------------------------------------------------===// 688 // Builder, printer and parser for for LLVM::LoadOp. 
689 //===----------------------------------------------------------------------===// 690 691 LogicalResult verifySymbolAttribute( 692 Operation *op, StringRef attributeName, 693 llvm::function_ref<LogicalResult(Operation *, SymbolRefAttr)> 694 verifySymbolType) { 695 if (Attribute attribute = op->getAttr(attributeName)) { 696 // The attribute is already verified to be a symbol ref array attribute via 697 // a constraint in the operation definition. 698 for (SymbolRefAttr symbolRef : 699 attribute.cast<ArrayAttr>().getAsRange<SymbolRefAttr>()) { 700 StringAttr metadataName = symbolRef.getRootReference(); 701 StringAttr symbolName = symbolRef.getLeafReference(); 702 // We want @metadata::@symbol, not just @symbol 703 if (metadataName == symbolName) { 704 return op->emitOpError() << "expected '" << symbolRef 705 << "' to specify a fully qualified reference"; 706 } 707 auto metadataOp = SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>( 708 op->getParentOp(), metadataName); 709 if (!metadataOp) 710 return op->emitOpError() 711 << "expected '" << symbolRef << "' to reference a metadata op"; 712 Operation *symbolOp = 713 SymbolTable::lookupNearestSymbolFrom(metadataOp, symbolName); 714 if (!symbolOp) 715 return op->emitOpError() 716 << "expected '" << symbolRef << "' to be a valid reference"; 717 if (failed(verifySymbolType(symbolOp, symbolRef))) { 718 return failure(); 719 } 720 } 721 } 722 return success(); 723 } 724 725 // Verifies that metadata ops are wired up properly. 
726 template <typename OpTy> 727 static LogicalResult verifyOpMetadata(Operation *op, StringRef attributeName) { 728 auto verifySymbolType = [op](Operation *symbolOp, 729 SymbolRefAttr symbolRef) -> LogicalResult { 730 if (!isa<OpTy>(symbolOp)) { 731 return op->emitOpError() 732 << "expected '" << symbolRef << "' to resolve to a " 733 << OpTy::getOperationName(); 734 } 735 return success(); 736 }; 737 738 return verifySymbolAttribute(op, attributeName, verifySymbolType); 739 } 740 741 static LogicalResult verifyMemoryOpMetadata(Operation *op) { 742 // access_groups 743 if (failed(verifyOpMetadata<LLVM::AccessGroupMetadataOp>( 744 op, LLVMDialect::getAccessGroupsAttrName()))) 745 return failure(); 746 747 // alias_scopes 748 if (failed(verifyOpMetadata<LLVM::AliasScopeMetadataOp>( 749 op, LLVMDialect::getAliasScopesAttrName()))) 750 return failure(); 751 752 // noalias_scopes 753 if (failed(verifyOpMetadata<LLVM::AliasScopeMetadataOp>( 754 op, LLVMDialect::getNoAliasScopesAttrName()))) 755 return failure(); 756 757 return success(); 758 } 759 760 LogicalResult LoadOp::verify() { return verifyMemoryOpMetadata(*this); } 761 762 void LoadOp::build(OpBuilder &builder, OperationState &result, Type t, 763 Value addr, unsigned alignment, bool isVolatile, 764 bool isNonTemporal) { 765 result.addOperands(addr); 766 result.addTypes(t); 767 if (isVolatile) 768 result.addAttribute(kVolatileAttrName, builder.getUnitAttr()); 769 if (isNonTemporal) 770 result.addAttribute(kNonTemporalAttrName, builder.getUnitAttr()); 771 if (alignment != 0) 772 result.addAttribute("alignment", builder.getI64IntegerAttr(alignment)); 773 } 774 775 void LoadOp::print(OpAsmPrinter &p) { 776 p << ' '; 777 if (getVolatile_()) 778 p << "volatile "; 779 p << getAddr(); 780 p.printOptionalAttrDict((*this)->getAttrs(), 781 {kVolatileAttrName, kElemTypeAttrName}); 782 p << " : " << getAddr().getType(); 783 if (getAddr().getType().cast<LLVMPointerType>().isOpaque()) 784 p << " -> " << getType(); 785 } 786 787 
// Extract the pointee type from the LLVM pointer type wrapped in MLIR. Return 788 // the resulting type if any, null type if opaque pointers are used, and None 789 // if the given type is not the pointer type. 790 static Optional<Type> getLoadStoreElementType(OpAsmParser &parser, Type type, 791 SMLoc trailingTypeLoc) { 792 auto llvmTy = type.dyn_cast<LLVM::LLVMPointerType>(); 793 if (!llvmTy) { 794 parser.emitError(trailingTypeLoc, "expected LLVM pointer type"); 795 return llvm::None; 796 } 797 return llvmTy.getElementType(); 798 } 799 800 // <operation> ::= `llvm.load` `volatile` ssa-use attribute-dict? `:` type 801 // (`->` type)? 802 ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) { 803 OpAsmParser::UnresolvedOperand addr; 804 Type type; 805 SMLoc trailingTypeLoc; 806 807 if (succeeded(parser.parseOptionalKeyword("volatile"))) 808 result.addAttribute(kVolatileAttrName, parser.getBuilder().getUnitAttr()); 809 810 if (parser.parseOperand(addr) || 811 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 812 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) || 813 parser.resolveOperand(addr, type, result.operands)) 814 return failure(); 815 816 Optional<Type> elemTy = 817 getLoadStoreElementType(parser, type, trailingTypeLoc); 818 if (!elemTy) 819 return failure(); 820 if (*elemTy) { 821 result.addTypes(*elemTy); 822 return success(); 823 } 824 825 Type trailingType; 826 if (parser.parseArrow() || parser.parseType(trailingType)) 827 return failure(); 828 result.addTypes(trailingType); 829 return success(); 830 } 831 832 //===----------------------------------------------------------------------===// 833 // Builder, printer and parser for LLVM::StoreOp. 
834 //===----------------------------------------------------------------------===// 835 836 LogicalResult StoreOp::verify() { return verifyMemoryOpMetadata(*this); } 837 838 void StoreOp::build(OpBuilder &builder, OperationState &result, Value value, 839 Value addr, unsigned alignment, bool isVolatile, 840 bool isNonTemporal) { 841 result.addOperands({value, addr}); 842 result.addTypes({}); 843 if (isVolatile) 844 result.addAttribute(kVolatileAttrName, builder.getUnitAttr()); 845 if (isNonTemporal) 846 result.addAttribute(kNonTemporalAttrName, builder.getUnitAttr()); 847 if (alignment != 0) 848 result.addAttribute("alignment", builder.getI64IntegerAttr(alignment)); 849 } 850 851 void StoreOp::print(OpAsmPrinter &p) { 852 p << ' '; 853 if (getVolatile_()) 854 p << "volatile "; 855 p << getValue() << ", " << getAddr(); 856 p.printOptionalAttrDict((*this)->getAttrs(), {kVolatileAttrName}); 857 p << " : "; 858 if (getAddr().getType().cast<LLVMPointerType>().isOpaque()) 859 p << getValue().getType() << ", "; 860 p << getAddr().getType(); 861 } 862 863 // <operation> ::= `llvm.store` `volatile` ssa-use `,` ssa-use 864 // attribute-dict? `:` type (`,` type)? 
865 ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) { 866 OpAsmParser::UnresolvedOperand addr, value; 867 Type type; 868 SMLoc trailingTypeLoc; 869 870 if (succeeded(parser.parseOptionalKeyword("volatile"))) 871 result.addAttribute(kVolatileAttrName, parser.getBuilder().getUnitAttr()); 872 873 if (parser.parseOperand(value) || parser.parseComma() || 874 parser.parseOperand(addr) || 875 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 876 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type)) 877 return failure(); 878 879 Type operandType; 880 if (succeeded(parser.parseOptionalComma())) { 881 operandType = type; 882 if (parser.parseType(type)) 883 return failure(); 884 } else { 885 Optional<Type> maybeOperandType = 886 getLoadStoreElementType(parser, type, trailingTypeLoc); 887 if (!maybeOperandType) 888 return failure(); 889 operandType = *maybeOperandType; 890 } 891 892 if (parser.resolveOperand(value, operandType, result.operands) || 893 parser.resolveOperand(addr, type, result.operands)) 894 return failure(); 895 896 return success(); 897 } 898 899 ///===---------------------------------------------------------------------===// 900 /// LLVM::InvokeOp 901 ///===---------------------------------------------------------------------===// 902 903 SuccessorOperands InvokeOp::getSuccessorOperands(unsigned index) { 904 assert(index < getNumSuccessors() && "invalid successor index"); 905 return SuccessorOperands(index == 0 ? getNormalDestOperandsMutable() 906 : getUnwindDestOperandsMutable()); 907 } 908 909 CallInterfaceCallable InvokeOp::getCallableForCallee() { 910 // Direct call. 911 if (FlatSymbolRefAttr calleeAttr = getCalleeAttr()) 912 return calleeAttr; 913 // Indirect call, callee Value is the first operand. 914 return getOperand(0); 915 } 916 917 Operation::operand_range InvokeOp::getArgOperands() { 918 return getOperands().drop_front(getCallee().has_value() ? 
0 : 1); 919 } 920 921 LogicalResult InvokeOp::verify() { 922 if (getNumResults() > 1) 923 return emitOpError("must have 0 or 1 result"); 924 925 Block *unwindDest = getUnwindDest(); 926 if (unwindDest->empty()) 927 return emitError("must have at least one operation in unwind destination"); 928 929 // In unwind destination, first operation must be LandingpadOp 930 if (!isa<LandingpadOp>(unwindDest->front())) 931 return emitError("first operation in unwind destination should be a " 932 "llvm.landingpad operation"); 933 934 return success(); 935 } 936 937 void InvokeOp::print(OpAsmPrinter &p) { 938 auto callee = getCallee(); 939 bool isDirect = callee.has_value(); 940 941 p << ' '; 942 943 // Either function name or pointer 944 if (isDirect) 945 p.printSymbolName(callee.value()); 946 else 947 p << getOperand(0); 948 949 p << '(' << getOperands().drop_front(isDirect ? 0 : 1) << ')'; 950 p << " to "; 951 p.printSuccessorAndUseList(getNormalDest(), getNormalDestOperands()); 952 p << " unwind "; 953 p.printSuccessorAndUseList(getUnwindDest(), getUnwindDestOperands()); 954 955 p.printOptionalAttrDict((*this)->getAttrs(), 956 {InvokeOp::getOperandSegmentSizeAttr(), "callee"}); 957 p << " : "; 958 p.printFunctionalType(llvm::drop_begin(getOperandTypes(), isDirect ? 0 : 1), 959 getResultTypes()); 960 } 961 962 /// <operation> ::= `llvm.invoke` (function-id | ssa-use) `(` ssa-use-list `)` 963 /// `to` bb-id (`[` ssa-use-and-type-list `]`)? 964 /// `unwind` bb-id (`[` ssa-use-and-type-list `]`)? 965 /// attribute-dict? `:` function-type 966 ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) { 967 SmallVector<OpAsmParser::UnresolvedOperand, 8> operands; 968 FunctionType funcType; 969 SymbolRefAttr funcAttr; 970 SMLoc trailingTypeLoc; 971 Block *normalDest, *unwindDest; 972 SmallVector<Value, 4> normalOperands, unwindOperands; 973 Builder &builder = parser.getBuilder(); 974 975 // Parse an operand list that will, in practice, contain 0 or 1 operand. 
In 976 // case of an indirect call, there will be 1 operand before `(`. In case of a 977 // direct call, there will be no operands and the parser will stop at the 978 // function identifier without complaining. 979 if (parser.parseOperandList(operands)) 980 return failure(); 981 bool isDirect = operands.empty(); 982 983 // Optionally parse a function identifier. 984 if (isDirect && parser.parseAttribute(funcAttr, "callee", result.attributes)) 985 return failure(); 986 987 if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || 988 parser.parseKeyword("to") || 989 parser.parseSuccessorAndUseList(normalDest, normalOperands) || 990 parser.parseKeyword("unwind") || 991 parser.parseSuccessorAndUseList(unwindDest, unwindOperands) || 992 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 993 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(funcType)) 994 return failure(); 995 996 if (isDirect) { 997 // Make sure types match. 998 if (parser.resolveOperands(operands, funcType.getInputs(), 999 parser.getNameLoc(), result.operands)) 1000 return failure(); 1001 result.addTypes(funcType.getResults()); 1002 } else { 1003 // Construct the LLVM IR Dialect function type that the first operand 1004 // should match. 
1005 if (funcType.getNumResults() > 1) 1006 return parser.emitError(trailingTypeLoc, 1007 "expected function with 0 or 1 result"); 1008 1009 Type llvmResultType; 1010 if (funcType.getNumResults() == 0) { 1011 llvmResultType = LLVM::LLVMVoidType::get(builder.getContext()); 1012 } else { 1013 llvmResultType = funcType.getResult(0); 1014 if (!isCompatibleType(llvmResultType)) 1015 return parser.emitError(trailingTypeLoc, 1016 "expected result to have LLVM type"); 1017 } 1018 1019 SmallVector<Type, 8> argTypes; 1020 argTypes.reserve(funcType.getNumInputs()); 1021 for (Type ty : funcType.getInputs()) { 1022 if (isCompatibleType(ty)) 1023 argTypes.push_back(ty); 1024 else 1025 return parser.emitError(trailingTypeLoc, 1026 "expected LLVM types as inputs"); 1027 } 1028 1029 auto llvmFuncType = LLVM::LLVMFunctionType::get(llvmResultType, argTypes); 1030 auto wrappedFuncType = LLVM::LLVMPointerType::get(llvmFuncType); 1031 1032 auto funcArguments = llvm::makeArrayRef(operands).drop_front(); 1033 1034 // Make sure that the first operand (indirect callee) matches the wrapped 1035 // LLVM IR function type, and that the types of the other call operands 1036 // match the types of the function arguments. 
1037 if (parser.resolveOperand(operands[0], wrappedFuncType, result.operands) || 1038 parser.resolveOperands(funcArguments, funcType.getInputs(), 1039 parser.getNameLoc(), result.operands)) 1040 return failure(); 1041 1042 result.addTypes(llvmResultType); 1043 } 1044 result.addSuccessors({normalDest, unwindDest}); 1045 result.addOperands(normalOperands); 1046 result.addOperands(unwindOperands); 1047 1048 result.addAttribute( 1049 InvokeOp::getOperandSegmentSizeAttr(), 1050 builder.getI32VectorAttr({static_cast<int32_t>(operands.size()), 1051 static_cast<int32_t>(normalOperands.size()), 1052 static_cast<int32_t>(unwindOperands.size())})); 1053 return success(); 1054 } 1055 1056 ///===----------------------------------------------------------------------===// 1057 /// Verifying/Printing/Parsing for LLVM::LandingpadOp. 1058 ///===----------------------------------------------------------------------===// 1059 1060 LogicalResult LandingpadOp::verify() { 1061 Value value; 1062 if (LLVMFuncOp func = (*this)->getParentOfType<LLVMFuncOp>()) { 1063 if (!func.getPersonality()) 1064 return emitError( 1065 "llvm.landingpad needs to be in a function with a personality"); 1066 } 1067 1068 if (!getCleanup() && getOperands().empty()) 1069 return emitError("landingpad instruction expects at least one clause or " 1070 "cleanup attribute"); 1071 1072 for (unsigned idx = 0, ie = getNumOperands(); idx < ie; idx++) { 1073 value = getOperand(idx); 1074 bool isFilter = value.getType().isa<LLVMArrayType>(); 1075 if (isFilter) { 1076 // FIXME: Verify filter clauses when arrays are appropriately handled 1077 } else { 1078 // catch - global addresses only. 1079 // Bitcast ops should have global addresses as their args. 
1080 if (auto bcOp = value.getDefiningOp<BitcastOp>()) { 1081 if (auto addrOp = bcOp.getArg().getDefiningOp<AddressOfOp>()) 1082 continue; 1083 return emitError("constant clauses expected").attachNote(bcOp.getLoc()) 1084 << "global addresses expected as operand to " 1085 "bitcast used in clauses for landingpad"; 1086 } 1087 // NullOp and AddressOfOp allowed 1088 if (value.getDefiningOp<NullOp>()) 1089 continue; 1090 if (value.getDefiningOp<AddressOfOp>()) 1091 continue; 1092 return emitError("clause #") 1093 << idx << " is not a known constant - null, addressof, bitcast"; 1094 } 1095 } 1096 return success(); 1097 } 1098 1099 void LandingpadOp::print(OpAsmPrinter &p) { 1100 p << (getCleanup() ? " cleanup " : " "); 1101 1102 // Clauses 1103 for (auto value : getOperands()) { 1104 // Similar to llvm - if clause is an array type then it is filter 1105 // clause else catch clause 1106 bool isArrayTy = value.getType().isa<LLVMArrayType>(); 1107 p << '(' << (isArrayTy ? "filter " : "catch ") << value << " : " 1108 << value.getType() << ") "; 1109 } 1110 1111 p.printOptionalAttrDict((*this)->getAttrs(), {"cleanup"}); 1112 1113 p << ": " << getType(); 1114 } 1115 1116 /// <operation> ::= `llvm.landingpad` `cleanup`? 1117 /// ((`catch` | `filter`) operand-type ssa-use)* attribute-dict? 
1118 ParseResult LandingpadOp::parse(OpAsmParser &parser, OperationState &result) { 1119 // Check for cleanup 1120 if (succeeded(parser.parseOptionalKeyword("cleanup"))) 1121 result.addAttribute("cleanup", parser.getBuilder().getUnitAttr()); 1122 1123 // Parse clauses with types 1124 while (succeeded(parser.parseOptionalLParen()) && 1125 (succeeded(parser.parseOptionalKeyword("filter")) || 1126 succeeded(parser.parseOptionalKeyword("catch")))) { 1127 OpAsmParser::UnresolvedOperand operand; 1128 Type ty; 1129 if (parser.parseOperand(operand) || parser.parseColon() || 1130 parser.parseType(ty) || 1131 parser.resolveOperand(operand, ty, result.operands) || 1132 parser.parseRParen()) 1133 return failure(); 1134 } 1135 1136 Type type; 1137 if (parser.parseColon() || parser.parseType(type)) 1138 return failure(); 1139 1140 result.addTypes(type); 1141 return success(); 1142 } 1143 1144 //===----------------------------------------------------------------------===// 1145 // Verifying/Printing/parsing for LLVM::CallOp. 1146 //===----------------------------------------------------------------------===// 1147 1148 CallInterfaceCallable CallOp::getCallableForCallee() { 1149 // Direct call. 1150 if (FlatSymbolRefAttr calleeAttr = getCalleeAttr()) 1151 return calleeAttr; 1152 // Indirect call, callee Value is the first operand. 1153 return getOperand(0); 1154 } 1155 1156 Operation::operand_range CallOp::getArgOperands() { 1157 return getOperands().drop_front(getCallee().has_value() ? 0 : 1); 1158 } 1159 1160 LogicalResult CallOp::verify() { 1161 if (getNumResults() > 1) 1162 return emitOpError("must have 0 or 1 result"); 1163 1164 // Type for the callee, we'll get it differently depending if it is a direct 1165 // or indirect call. 1166 Type fnType; 1167 1168 bool isIndirect = false; 1169 1170 // If this is an indirect call, the callee attribute is missing. 
1171 FlatSymbolRefAttr calleeName = getCalleeAttr(); 1172 if (!calleeName) { 1173 isIndirect = true; 1174 if (!getNumOperands()) 1175 return emitOpError( 1176 "must have either a `callee` attribute or at least an operand"); 1177 auto ptrType = getOperand(0).getType().dyn_cast<LLVMPointerType>(); 1178 if (!ptrType) 1179 return emitOpError("indirect call expects a pointer as callee: ") 1180 << ptrType; 1181 fnType = ptrType.getElementType(); 1182 } else { 1183 Operation *callee = 1184 SymbolTable::lookupNearestSymbolFrom(*this, calleeName.getAttr()); 1185 if (!callee) 1186 return emitOpError() 1187 << "'" << calleeName.getValue() 1188 << "' does not reference a symbol in the current scope"; 1189 auto fn = dyn_cast<LLVMFuncOp>(callee); 1190 if (!fn) 1191 return emitOpError() << "'" << calleeName.getValue() 1192 << "' does not reference a valid LLVM function"; 1193 1194 fnType = fn.getFunctionType(); 1195 } 1196 1197 LLVMFunctionType funcType = fnType.dyn_cast<LLVMFunctionType>(); 1198 if (!funcType) 1199 return emitOpError("callee does not have a functional type: ") << fnType; 1200 1201 // Verify that the operand and result types match the callee. 
1202 1203 if (!funcType.isVarArg() && 1204 funcType.getNumParams() != (getNumOperands() - isIndirect)) 1205 return emitOpError() << "incorrect number of operands (" 1206 << (getNumOperands() - isIndirect) 1207 << ") for callee (expecting: " 1208 << funcType.getNumParams() << ")"; 1209 1210 if (funcType.getNumParams() > (getNumOperands() - isIndirect)) 1211 return emitOpError() << "incorrect number of operands (" 1212 << (getNumOperands() - isIndirect) 1213 << ") for varargs callee (expecting at least: " 1214 << funcType.getNumParams() << ")"; 1215 1216 for (unsigned i = 0, e = funcType.getNumParams(); i != e; ++i) 1217 if (getOperand(i + isIndirect).getType() != funcType.getParamType(i)) 1218 return emitOpError() << "operand type mismatch for operand " << i << ": " 1219 << getOperand(i + isIndirect).getType() 1220 << " != " << funcType.getParamType(i); 1221 1222 if (getNumResults() == 0 && 1223 !funcType.getReturnType().isa<LLVM::LLVMVoidType>()) 1224 return emitOpError() << "expected function call to produce a value"; 1225 1226 if (getNumResults() != 0 && 1227 funcType.getReturnType().isa<LLVM::LLVMVoidType>()) 1228 return emitOpError() 1229 << "calling function with void result must not produce values"; 1230 1231 if (getNumResults() > 1) 1232 return emitOpError() 1233 << "expected LLVM function call to produce 0 or 1 result"; 1234 1235 if (getNumResults() && getResult(0).getType() != funcType.getReturnType()) 1236 return emitOpError() << "result type mismatch: " << getResult(0).getType() 1237 << " != " << funcType.getReturnType(); 1238 1239 return success(); 1240 } 1241 1242 void CallOp::print(OpAsmPrinter &p) { 1243 auto callee = getCallee(); 1244 bool isDirect = callee.has_value(); 1245 1246 // Print the direct callee if present as a function attribute, or an indirect 1247 // callee (first operand) otherwise. 
1248 p << ' '; 1249 if (isDirect) 1250 p.printSymbolName(callee.value()); 1251 else 1252 p << getOperand(0); 1253 1254 auto args = getOperands().drop_front(isDirect ? 0 : 1); 1255 p << '(' << args << ')'; 1256 p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()), {"callee"}); 1257 1258 // Reconstruct the function MLIR function type from operand and result types. 1259 p << " : "; 1260 p.printFunctionalType(args.getTypes(), getResultTypes()); 1261 } 1262 1263 // <operation> ::= `llvm.call` (function-id | ssa-use) `(` ssa-use-list `)` 1264 // attribute-dict? `:` function-type 1265 ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) { 1266 SmallVector<OpAsmParser::UnresolvedOperand, 8> operands; 1267 Type type; 1268 SymbolRefAttr funcAttr; 1269 SMLoc trailingTypeLoc; 1270 1271 // Parse an operand list that will, in practice, contain 0 or 1 operand. In 1272 // case of an indirect call, there will be 1 operand before `(`. In case of a 1273 // direct call, there will be no operands and the parser will stop at the 1274 // function identifier without complaining. 1275 if (parser.parseOperandList(operands)) 1276 return failure(); 1277 bool isDirect = operands.empty(); 1278 1279 // Optionally parse a function identifier. 1280 if (isDirect) 1281 if (parser.parseAttribute(funcAttr, "callee", result.attributes)) 1282 return failure(); 1283 1284 if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || 1285 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 1286 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type)) 1287 return failure(); 1288 1289 auto funcType = type.dyn_cast<FunctionType>(); 1290 if (!funcType) 1291 return parser.emitError(trailingTypeLoc, "expected function type"); 1292 if (funcType.getNumResults() > 1) 1293 return parser.emitError(trailingTypeLoc, 1294 "expected function with 0 or 1 result"); 1295 if (isDirect) { 1296 // Make sure types match. 
1297 if (parser.resolveOperands(operands, funcType.getInputs(), 1298 parser.getNameLoc(), result.operands)) 1299 return failure(); 1300 if (funcType.getNumResults() != 0 && 1301 !funcType.getResult(0).isa<LLVM::LLVMVoidType>()) 1302 result.addTypes(funcType.getResults()); 1303 } else { 1304 Builder &builder = parser.getBuilder(); 1305 Type llvmResultType; 1306 if (funcType.getNumResults() == 0) { 1307 llvmResultType = LLVM::LLVMVoidType::get(builder.getContext()); 1308 } else { 1309 llvmResultType = funcType.getResult(0); 1310 if (!isCompatibleType(llvmResultType)) 1311 return parser.emitError(trailingTypeLoc, 1312 "expected result to have LLVM type"); 1313 } 1314 1315 SmallVector<Type, 8> argTypes; 1316 argTypes.reserve(funcType.getNumInputs()); 1317 for (int i = 0, e = funcType.getNumInputs(); i < e; ++i) { 1318 auto argType = funcType.getInput(i); 1319 if (!isCompatibleType(argType)) 1320 return parser.emitError(trailingTypeLoc, 1321 "expected LLVM types as inputs"); 1322 argTypes.push_back(argType); 1323 } 1324 auto llvmFuncType = LLVM::LLVMFunctionType::get(llvmResultType, argTypes); 1325 auto wrappedFuncType = LLVM::LLVMPointerType::get(llvmFuncType); 1326 1327 auto funcArguments = 1328 ArrayRef<OpAsmParser::UnresolvedOperand>(operands).drop_front(); 1329 1330 // Make sure that the first operand (indirect callee) matches the wrapped 1331 // LLVM IR function type, and that the types of the other call operands 1332 // match the types of the function arguments. 1333 if (parser.resolveOperand(operands[0], wrappedFuncType, result.operands) || 1334 parser.resolveOperands(funcArguments, funcType.getInputs(), 1335 parser.getNameLoc(), result.operands)) 1336 return failure(); 1337 1338 if (!llvmResultType.isa<LLVM::LLVMVoidType>()) 1339 result.addTypes(llvmResultType); 1340 } 1341 1342 return success(); 1343 } 1344 1345 //===----------------------------------------------------------------------===// 1346 // Printing/parsing for LLVM::ExtractElementOp. 
1347 //===----------------------------------------------------------------------===// 1348 // Expects vector to be of wrapped LLVM vector type and position to be of 1349 // wrapped LLVM i32 type. 1350 void LLVM::ExtractElementOp::build(OpBuilder &b, OperationState &result, 1351 Value vector, Value position, 1352 ArrayRef<NamedAttribute> attrs) { 1353 auto vectorType = vector.getType(); 1354 auto llvmType = LLVM::getVectorElementType(vectorType); 1355 build(b, result, llvmType, vector, position); 1356 result.addAttributes(attrs); 1357 } 1358 1359 void ExtractElementOp::print(OpAsmPrinter &p) { 1360 p << ' ' << getVector() << "[" << getPosition() << " : " 1361 << getPosition().getType() << "]"; 1362 p.printOptionalAttrDict((*this)->getAttrs()); 1363 p << " : " << getVector().getType(); 1364 } 1365 1366 // <operation> ::= `llvm.extractelement` ssa-use `, ` ssa-use 1367 // attribute-dict? `:` type 1368 ParseResult ExtractElementOp::parse(OpAsmParser &parser, 1369 OperationState &result) { 1370 SMLoc loc; 1371 OpAsmParser::UnresolvedOperand vector, position; 1372 Type type, positionType; 1373 if (parser.getCurrentLocation(&loc) || parser.parseOperand(vector) || 1374 parser.parseLSquare() || parser.parseOperand(position) || 1375 parser.parseColonType(positionType) || parser.parseRSquare() || 1376 parser.parseOptionalAttrDict(result.attributes) || 1377 parser.parseColonType(type) || 1378 parser.resolveOperand(vector, type, result.operands) || 1379 parser.resolveOperand(position, positionType, result.operands)) 1380 return failure(); 1381 if (!LLVM::isCompatibleVectorType(type)) 1382 return parser.emitError( 1383 loc, "expected LLVM dialect-compatible vector type for operand #1"); 1384 result.addTypes(LLVM::getVectorElementType(type)); 1385 return success(); 1386 } 1387 1388 LogicalResult ExtractElementOp::verify() { 1389 Type vectorType = getVector().getType(); 1390 if (!LLVM::isCompatibleVectorType(vectorType)) 1391 return emitOpError("expected LLVM dialect-compatible 
vector type for " 1392 "operand #1, got") 1393 << vectorType; 1394 Type valueType = LLVM::getVectorElementType(vectorType); 1395 if (valueType != getRes().getType()) 1396 return emitOpError() << "Type mismatch: extracting from " << vectorType 1397 << " should produce " << valueType 1398 << " but this op returns " << getRes().getType(); 1399 return success(); 1400 } 1401 1402 //===----------------------------------------------------------------------===// 1403 // Printing/parsing for LLVM::ExtractValueOp. 1404 //===----------------------------------------------------------------------===// 1405 1406 void ExtractValueOp::print(OpAsmPrinter &p) { 1407 p << ' ' << getContainer() << getPosition(); 1408 p.printOptionalAttrDict((*this)->getAttrs(), {"position"}); 1409 p << " : " << getContainer().getType(); 1410 } 1411 1412 // Extract the type at `position` in the wrapped LLVM IR aggregate type 1413 // `containerType`. Position is an integer array attribute where each value 1414 // is a zero-based position of the element in the aggregate type. Return the 1415 // resulting type wrapped in MLIR, or nullptr on error. 1416 static Type getInsertExtractValueElementType(OpAsmParser &parser, 1417 Type containerType, 1418 ArrayAttr positionAttr, 1419 SMLoc attributeLoc, 1420 SMLoc typeLoc) { 1421 Type llvmType = containerType; 1422 if (!isCompatibleType(containerType)) 1423 return parser.emitError(typeLoc, "expected LLVM IR Dialect type"), nullptr; 1424 1425 // Infer the element type from the structure type: iteratively step inside the 1426 // type by taking the element type, indexed by the position attribute for 1427 // structures. Check the position index before accessing, it is supposed to 1428 // be in bounds. 
1429 for (Attribute subAttr : positionAttr) { 1430 auto positionElementAttr = subAttr.dyn_cast<IntegerAttr>(); 1431 if (!positionElementAttr) 1432 return parser.emitError(attributeLoc, 1433 "expected an array of integer literals"), 1434 nullptr; 1435 int position = positionElementAttr.getInt(); 1436 if (auto arrayType = llvmType.dyn_cast<LLVMArrayType>()) { 1437 if (position < 0 || 1438 static_cast<unsigned>(position) >= arrayType.getNumElements()) 1439 return parser.emitError(attributeLoc, "position out of bounds"), 1440 nullptr; 1441 llvmType = arrayType.getElementType(); 1442 } else if (auto structType = llvmType.dyn_cast<LLVMStructType>()) { 1443 if (position < 0 || 1444 static_cast<unsigned>(position) >= structType.getBody().size()) 1445 return parser.emitError(attributeLoc, "position out of bounds"), 1446 nullptr; 1447 llvmType = structType.getBody()[position]; 1448 } else { 1449 return parser.emitError(typeLoc, "expected LLVM IR structure/array type"), 1450 nullptr; 1451 } 1452 } 1453 return llvmType; 1454 } 1455 1456 // Extract the type at `position` in the wrapped LLVM IR aggregate type 1457 // `containerType`. Returns null on failure. 1458 static Type getInsertExtractValueElementType(Type containerType, 1459 ArrayAttr positionAttr, 1460 Operation *op) { 1461 Type llvmType = containerType; 1462 if (!isCompatibleType(containerType)) { 1463 op->emitError("expected LLVM IR Dialect type, got ") << containerType; 1464 return {}; 1465 } 1466 1467 // Infer the element type from the structure type: iteratively step inside the 1468 // type by taking the element type, indexed by the position attribute for 1469 // structures. Check the position index before accessing, it is supposed to 1470 // be in bounds. 
1471 for (Attribute subAttr : positionAttr) { 1472 auto positionElementAttr = subAttr.dyn_cast<IntegerAttr>(); 1473 if (!positionElementAttr) { 1474 op->emitOpError("expected an array of integer literals, got: ") 1475 << subAttr; 1476 return {}; 1477 } 1478 int position = positionElementAttr.getInt(); 1479 if (auto arrayType = llvmType.dyn_cast<LLVMArrayType>()) { 1480 if (position < 0 || 1481 static_cast<unsigned>(position) >= arrayType.getNumElements()) { 1482 op->emitOpError("position out of bounds: ") << position; 1483 return {}; 1484 } 1485 llvmType = arrayType.getElementType(); 1486 } else if (auto structType = llvmType.dyn_cast<LLVMStructType>()) { 1487 if (position < 0 || 1488 static_cast<unsigned>(position) >= structType.getBody().size()) { 1489 op->emitOpError("position out of bounds") << position; 1490 return {}; 1491 } 1492 llvmType = structType.getBody()[position]; 1493 } else { 1494 op->emitOpError("expected LLVM IR structure/array type, got: ") 1495 << llvmType; 1496 return {}; 1497 } 1498 } 1499 return llvmType; 1500 } 1501 1502 // <operation> ::= `llvm.extractvalue` ssa-use 1503 // `[` integer-literal (`,` integer-literal)* `]` 1504 // attribute-dict? 
`:` type 1505 ParseResult ExtractValueOp::parse(OpAsmParser &parser, OperationState &result) { 1506 OpAsmParser::UnresolvedOperand container; 1507 Type containerType; 1508 ArrayAttr positionAttr; 1509 SMLoc attributeLoc, trailingTypeLoc; 1510 1511 if (parser.parseOperand(container) || 1512 parser.getCurrentLocation(&attributeLoc) || 1513 parser.parseAttribute(positionAttr, "position", result.attributes) || 1514 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 1515 parser.getCurrentLocation(&trailingTypeLoc) || 1516 parser.parseType(containerType) || 1517 parser.resolveOperand(container, containerType, result.operands)) 1518 return failure(); 1519 1520 auto elementType = getInsertExtractValueElementType( 1521 parser, containerType, positionAttr, attributeLoc, trailingTypeLoc); 1522 if (!elementType) 1523 return failure(); 1524 1525 result.addTypes(elementType); 1526 return success(); 1527 } 1528 1529 OpFoldResult LLVM::ExtractValueOp::fold(ArrayRef<Attribute> operands) { 1530 auto insertValueOp = getContainer().getDefiningOp<InsertValueOp>(); 1531 OpFoldResult result = {}; 1532 while (insertValueOp) { 1533 if (getPosition() == insertValueOp.getPosition()) 1534 return insertValueOp.getValue(); 1535 unsigned min = 1536 std::min(getPosition().size(), insertValueOp.getPosition().size()); 1537 // If one is fully prefix of the other, stop propagating back as it will 1538 // miss dependencies. 
For instance, %3 should not fold to %f0 in the 1539 // following example: 1540 // ``` 1541 // %1 = llvm.insertvalue %f0, %0[0, 0] : 1542 // !llvm.array<4 x !llvm.array<4xf32>> 1543 // %2 = llvm.insertvalue %arr, %1[0] : 1544 // !llvm.array<4 x !llvm.array<4xf32>> 1545 // %3 = llvm.extractvalue %2[0, 0] : !llvm.array<4 x !llvm.array<4xf32>> 1546 // ``` 1547 if (getPosition().getValue().take_front(min) == 1548 insertValueOp.getPosition().getValue().take_front(min)) 1549 return result; 1550 1551 // If neither a prefix, nor the exact position, we can extract out of the 1552 // value being inserted into. Moreover, we can try again if that operand 1553 // is itself an insertvalue expression. 1554 getContainerMutable().assign(insertValueOp.getContainer()); 1555 result = getResult(); 1556 insertValueOp = insertValueOp.getContainer().getDefiningOp<InsertValueOp>(); 1557 } 1558 return result; 1559 } 1560 1561 LogicalResult ExtractValueOp::verify() { 1562 Type valueType = getInsertExtractValueElementType(getContainer().getType(), 1563 getPositionAttr(), *this); 1564 if (!valueType) 1565 return failure(); 1566 1567 if (getRes().getType() != valueType) 1568 return emitOpError() << "Type mismatch: extracting from " 1569 << getContainer().getType() << " should produce " 1570 << valueType << " but this op returns " 1571 << getRes().getType(); 1572 return success(); 1573 } 1574 1575 //===----------------------------------------------------------------------===// 1576 // Printing/parsing for LLVM::InsertElementOp. 1577 //===----------------------------------------------------------------------===// 1578 1579 void InsertElementOp::print(OpAsmPrinter &p) { 1580 p << ' ' << getValue() << ", " << getVector() << "[" << getPosition() << " : " 1581 << getPosition().getType() << "]"; 1582 p.printOptionalAttrDict((*this)->getAttrs()); 1583 p << " : " << getVector().getType(); 1584 } 1585 1586 // <operation> ::= `llvm.insertelement` ssa-use `,` ssa-use `,` ssa-use 1587 // attribute-dict? 
`:` type 1588 ParseResult InsertElementOp::parse(OpAsmParser &parser, 1589 OperationState &result) { 1590 SMLoc loc; 1591 OpAsmParser::UnresolvedOperand vector, value, position; 1592 Type vectorType, positionType; 1593 if (parser.getCurrentLocation(&loc) || parser.parseOperand(value) || 1594 parser.parseComma() || parser.parseOperand(vector) || 1595 parser.parseLSquare() || parser.parseOperand(position) || 1596 parser.parseColonType(positionType) || parser.parseRSquare() || 1597 parser.parseOptionalAttrDict(result.attributes) || 1598 parser.parseColonType(vectorType)) 1599 return failure(); 1600 1601 if (!LLVM::isCompatibleVectorType(vectorType)) 1602 return parser.emitError( 1603 loc, "expected LLVM dialect-compatible vector type for operand #1"); 1604 Type valueType = LLVM::getVectorElementType(vectorType); 1605 if (!valueType) 1606 return failure(); 1607 1608 if (parser.resolveOperand(vector, vectorType, result.operands) || 1609 parser.resolveOperand(value, valueType, result.operands) || 1610 parser.resolveOperand(position, positionType, result.operands)) 1611 return failure(); 1612 1613 result.addTypes(vectorType); 1614 return success(); 1615 } 1616 1617 LogicalResult InsertElementOp::verify() { 1618 Type valueType = LLVM::getVectorElementType(getVector().getType()); 1619 if (valueType != getValue().getType()) 1620 return emitOpError() << "Type mismatch: cannot insert " 1621 << getValue().getType() << " into " 1622 << getVector().getType(); 1623 return success(); 1624 } 1625 1626 //===----------------------------------------------------------------------===// 1627 // Printing/parsing for LLVM::InsertValueOp. 
1628 //===----------------------------------------------------------------------===// 1629 1630 void InsertValueOp::print(OpAsmPrinter &p) { 1631 p << ' ' << getValue() << ", " << getContainer() << getPosition(); 1632 p.printOptionalAttrDict((*this)->getAttrs(), {"position"}); 1633 p << " : " << getContainer().getType(); 1634 } 1635 1636 // <operation> ::= `llvm.insertvaluevalue` ssa-use `,` ssa-use 1637 // `[` integer-literal (`,` integer-literal)* `]` 1638 // attribute-dict? `:` type 1639 ParseResult InsertValueOp::parse(OpAsmParser &parser, OperationState &result) { 1640 OpAsmParser::UnresolvedOperand container, value; 1641 Type containerType; 1642 ArrayAttr positionAttr; 1643 SMLoc attributeLoc, trailingTypeLoc; 1644 1645 if (parser.parseOperand(value) || parser.parseComma() || 1646 parser.parseOperand(container) || 1647 parser.getCurrentLocation(&attributeLoc) || 1648 parser.parseAttribute(positionAttr, "position", result.attributes) || 1649 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || 1650 parser.getCurrentLocation(&trailingTypeLoc) || 1651 parser.parseType(containerType)) 1652 return failure(); 1653 1654 auto valueType = getInsertExtractValueElementType( 1655 parser, containerType, positionAttr, attributeLoc, trailingTypeLoc); 1656 if (!valueType) 1657 return failure(); 1658 1659 if (parser.resolveOperand(container, containerType, result.operands) || 1660 parser.resolveOperand(value, valueType, result.operands)) 1661 return failure(); 1662 1663 result.addTypes(containerType); 1664 return success(); 1665 } 1666 1667 LogicalResult InsertValueOp::verify() { 1668 Type valueType = getInsertExtractValueElementType(getContainer().getType(), 1669 getPositionAttr(), *this); 1670 if (!valueType) 1671 return failure(); 1672 1673 if (getValue().getType() != valueType) 1674 return emitOpError() << "Type mismatch: cannot insert " 1675 << getValue().getType() << " into " 1676 << getContainer().getType(); 1677 1678 return success(); 1679 } 1680 
1681 //===----------------------------------------------------------------------===// 1682 // Printing, parsing and verification for LLVM::ReturnOp. 1683 //===----------------------------------------------------------------------===// 1684 1685 LogicalResult ReturnOp::verify() { 1686 if (getNumOperands() > 1) 1687 return emitOpError("expected at most 1 operand"); 1688 1689 if (auto parent = (*this)->getParentOfType<LLVMFuncOp>()) { 1690 Type expectedType = parent.getFunctionType().getReturnType(); 1691 if (expectedType.isa<LLVMVoidType>()) { 1692 if (getNumOperands() == 0) 1693 return success(); 1694 InFlightDiagnostic diag = emitOpError("expected no operands"); 1695 diag.attachNote(parent->getLoc()) << "when returning from function"; 1696 return diag; 1697 } 1698 if (getNumOperands() == 0) { 1699 if (expectedType.isa<LLVMVoidType>()) 1700 return success(); 1701 InFlightDiagnostic diag = emitOpError("expected 1 operand"); 1702 diag.attachNote(parent->getLoc()) << "when returning from function"; 1703 return diag; 1704 } 1705 if (expectedType != getOperand(0).getType()) { 1706 InFlightDiagnostic diag = emitOpError("mismatching result types"); 1707 diag.attachNote(parent->getLoc()) << "when returning from function"; 1708 return diag; 1709 } 1710 } 1711 return success(); 1712 } 1713 1714 //===----------------------------------------------------------------------===// 1715 // ResumeOp 1716 //===----------------------------------------------------------------------===// 1717 1718 LogicalResult ResumeOp::verify() { 1719 if (!getValue().getDefiningOp<LandingpadOp>()) 1720 return emitOpError("expects landingpad value as operand"); 1721 // No check for personality of function - landingpad op verifies it. 1722 return success(); 1723 } 1724 1725 //===----------------------------------------------------------------------===// 1726 // Verifier for LLVM::AddressOfOp. 
1727 //===----------------------------------------------------------------------===// 1728 1729 template <typename OpTy> 1730 static OpTy lookupSymbolInModule(Operation *parent, StringRef name) { 1731 Operation *module = parent; 1732 while (module && !satisfiesLLVMModule(module)) 1733 module = module->getParentOp(); 1734 assert(module && "unexpected operation outside of a module"); 1735 return dyn_cast_or_null<OpTy>( 1736 mlir::SymbolTable::lookupSymbolIn(module, name)); 1737 } 1738 1739 GlobalOp AddressOfOp::getGlobal() { 1740 return lookupSymbolInModule<LLVM::GlobalOp>((*this)->getParentOp(), 1741 getGlobalName()); 1742 } 1743 1744 LLVMFuncOp AddressOfOp::getFunction() { 1745 return lookupSymbolInModule<LLVM::LLVMFuncOp>((*this)->getParentOp(), 1746 getGlobalName()); 1747 } 1748 1749 LogicalResult AddressOfOp::verify() { 1750 auto global = getGlobal(); 1751 auto function = getFunction(); 1752 if (!global && !function) 1753 return emitOpError( 1754 "must reference a global defined by 'llvm.mlir.global' or 'llvm.func'"); 1755 1756 LLVMPointerType type = getType(); 1757 if (global && global.getAddrSpace() != type.getAddressSpace()) 1758 return emitOpError("pointer address space must match address space of the " 1759 "referenced global"); 1760 1761 if (type.isOpaque()) 1762 return success(); 1763 1764 if (global && type.getElementType() != global.getType()) 1765 return emitOpError( 1766 "the type must be a pointer to the type of the referenced global"); 1767 1768 if (function && type.getElementType() != function.getFunctionType()) 1769 return emitOpError( 1770 "the type must be a pointer to the type of the referenced function"); 1771 1772 return success(); 1773 } 1774 1775 //===----------------------------------------------------------------------===// 1776 // Builder, printer and verifier for LLVM::GlobalOp. 
1777 //===----------------------------------------------------------------------===// 1778 1779 void GlobalOp::build(OpBuilder &builder, OperationState &result, Type type, 1780 bool isConstant, Linkage linkage, StringRef name, 1781 Attribute value, uint64_t alignment, unsigned addrSpace, 1782 bool dsoLocal, bool threadLocal, 1783 ArrayRef<NamedAttribute> attrs) { 1784 result.addAttribute(getSymNameAttrName(result.name), 1785 builder.getStringAttr(name)); 1786 result.addAttribute(getGlobalTypeAttrName(result.name), TypeAttr::get(type)); 1787 if (isConstant) 1788 result.addAttribute(getConstantAttrName(result.name), 1789 builder.getUnitAttr()); 1790 if (value) 1791 result.addAttribute(getValueAttrName(result.name), value); 1792 if (dsoLocal) 1793 result.addAttribute(getDsoLocalAttrName(result.name), 1794 builder.getUnitAttr()); 1795 if (threadLocal) 1796 result.addAttribute(getThreadLocal_AttrName(result.name), 1797 builder.getUnitAttr()); 1798 1799 // Only add an alignment attribute if the "alignment" input 1800 // is different from 0. The value must also be a power of two, but 1801 // this is tested in GlobalOp::verify, not here. 
1802 if (alignment != 0) 1803 result.addAttribute(getAlignmentAttrName(result.name), 1804 builder.getI64IntegerAttr(alignment)); 1805 1806 result.addAttribute(getLinkageAttrName(result.name), 1807 LinkageAttr::get(builder.getContext(), linkage)); 1808 if (addrSpace != 0) 1809 result.addAttribute(getAddrSpaceAttrName(result.name), 1810 builder.getI32IntegerAttr(addrSpace)); 1811 result.attributes.append(attrs.begin(), attrs.end()); 1812 result.addRegion(); 1813 } 1814 1815 void GlobalOp::print(OpAsmPrinter &p) { 1816 p << ' ' << stringifyLinkage(getLinkage()) << ' '; 1817 if (auto unnamedAddr = getUnnamedAddr()) { 1818 StringRef str = stringifyUnnamedAddr(*unnamedAddr); 1819 if (!str.empty()) 1820 p << str << ' '; 1821 } 1822 if (getThreadLocal_()) 1823 p << "thread_local "; 1824 if (getConstant()) 1825 p << "constant "; 1826 p.printSymbolName(getSymName()); 1827 p << '('; 1828 if (auto value = getValueOrNull()) 1829 p.printAttribute(value); 1830 p << ')'; 1831 // Note that the alignment attribute is printed using the 1832 // default syntax here, even though it is an inherent attribute 1833 // (as defined in https://mlir.llvm.org/docs/LangRef/#attributes) 1834 p.printOptionalAttrDict( 1835 (*this)->getAttrs(), 1836 {SymbolTable::getSymbolAttrName(), getGlobalTypeAttrName(), 1837 getConstantAttrName(), getValueAttrName(), getLinkageAttrName(), 1838 getUnnamedAddrAttrName(), getThreadLocal_AttrName()}); 1839 1840 // Print the trailing type unless it's a string global. 1841 if (getValueOrNull().dyn_cast_or_null<StringAttr>()) 1842 return; 1843 p << " : " << getType(); 1844 1845 Region &initializer = getInitializerRegion(); 1846 if (!initializer.empty()) { 1847 p << ' '; 1848 p.printRegion(initializer, /*printEntryBlockArgs=*/false); 1849 } 1850 } 1851 1852 // Parses one of the keywords provided in the list `keywords` and returns the 1853 // position of the parsed keyword in the list. If none of the keywords from the 1854 // list is parsed, returns -1. 
1855 static int parseOptionalKeywordAlternative(OpAsmParser &parser, 1856 ArrayRef<StringRef> keywords) { 1857 for (const auto &en : llvm::enumerate(keywords)) { 1858 if (succeeded(parser.parseOptionalKeyword(en.value()))) 1859 return en.index(); 1860 } 1861 return -1; 1862 } 1863 1864 namespace { 1865 template <typename Ty> 1866 struct EnumTraits {}; 1867 1868 #define REGISTER_ENUM_TYPE(Ty) \ 1869 template <> \ 1870 struct EnumTraits<Ty> { \ 1871 static StringRef stringify(Ty value) { return stringify##Ty(value); } \ 1872 static unsigned getMaxEnumVal() { return getMaxEnumValFor##Ty(); } \ 1873 } 1874 1875 REGISTER_ENUM_TYPE(Linkage); 1876 REGISTER_ENUM_TYPE(UnnamedAddr); 1877 REGISTER_ENUM_TYPE(CConv); 1878 } // namespace 1879 1880 /// Parse an enum from the keyword, or default to the provided default value. 1881 /// The return type is the enum type by default, unless overriden with the 1882 /// second template argument. 1883 template <typename EnumTy, typename RetTy = EnumTy> 1884 static RetTy parseOptionalLLVMKeyword(OpAsmParser &parser, 1885 OperationState &result, 1886 EnumTy defaultValue) { 1887 SmallVector<StringRef, 10> names; 1888 for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i) 1889 names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i))); 1890 1891 int index = parseOptionalKeywordAlternative(parser, names); 1892 if (index == -1) 1893 return static_cast<RetTy>(defaultValue); 1894 return static_cast<RetTy>(index); 1895 } 1896 1897 // operation ::= `llvm.mlir.global` linkage? `constant`? `@` identifier 1898 // `(` attribute? `)` align? attribute-list? (`:` type)? region? 1899 // align ::= `align` `=` UINT64 1900 // 1901 // The type can be omitted for string attributes, in which case it will be 1902 // inferred from the value of the string as [strlen(value) x i8]. 
1903 ParseResult GlobalOp::parse(OpAsmParser &parser, OperationState &result) { 1904 MLIRContext *ctx = parser.getContext(); 1905 // Parse optional linkage, default to External. 1906 result.addAttribute(getLinkageAttrName(result.name), 1907 LLVM::LinkageAttr::get( 1908 ctx, parseOptionalLLVMKeyword<Linkage>( 1909 parser, result, LLVM::Linkage::External))); 1910 1911 if (succeeded(parser.parseOptionalKeyword("thread_local"))) 1912 result.addAttribute(getThreadLocal_AttrName(result.name), 1913 parser.getBuilder().getUnitAttr()); 1914 1915 // Parse optional UnnamedAddr, default to None. 1916 result.addAttribute(getUnnamedAddrAttrName(result.name), 1917 parser.getBuilder().getI64IntegerAttr( 1918 parseOptionalLLVMKeyword<UnnamedAddr, int64_t>( 1919 parser, result, LLVM::UnnamedAddr::None))); 1920 1921 if (succeeded(parser.parseOptionalKeyword("constant"))) 1922 result.addAttribute(getConstantAttrName(result.name), 1923 parser.getBuilder().getUnitAttr()); 1924 1925 StringAttr name; 1926 if (parser.parseSymbolName(name, getSymNameAttrName(result.name), 1927 result.attributes) || 1928 parser.parseLParen()) 1929 return failure(); 1930 1931 Attribute value; 1932 if (parser.parseOptionalRParen()) { 1933 if (parser.parseAttribute(value, getValueAttrName(result.name), 1934 result.attributes) || 1935 parser.parseRParen()) 1936 return failure(); 1937 } 1938 1939 SmallVector<Type, 1> types; 1940 if (parser.parseOptionalAttrDict(result.attributes) || 1941 parser.parseOptionalColonTypeList(types)) 1942 return failure(); 1943 1944 if (types.size() > 1) 1945 return parser.emitError(parser.getNameLoc(), "expected zero or one type"); 1946 1947 Region &initRegion = *result.addRegion(); 1948 if (types.empty()) { 1949 if (auto strAttr = value.dyn_cast_or_null<StringAttr>()) { 1950 MLIRContext *context = parser.getContext(); 1951 auto arrayType = LLVM::LLVMArrayType::get(IntegerType::get(context, 8), 1952 strAttr.getValue().size()); 1953 types.push_back(arrayType); 1954 } else { 1955 
return parser.emitError(parser.getNameLoc(), 1956 "type can only be omitted for string globals"); 1957 } 1958 } else { 1959 OptionalParseResult parseResult = 1960 parser.parseOptionalRegion(initRegion, /*arguments=*/{}, 1961 /*argTypes=*/{}); 1962 if (parseResult.hasValue() && failed(*parseResult)) 1963 return failure(); 1964 } 1965 1966 result.addAttribute(getGlobalTypeAttrName(result.name), 1967 TypeAttr::get(types[0])); 1968 return success(); 1969 } 1970 1971 static bool isZeroAttribute(Attribute value) { 1972 if (auto intValue = value.dyn_cast<IntegerAttr>()) 1973 return intValue.getValue().isNullValue(); 1974 if (auto fpValue = value.dyn_cast<FloatAttr>()) 1975 return fpValue.getValue().isZero(); 1976 if (auto splatValue = value.dyn_cast<SplatElementsAttr>()) 1977 return isZeroAttribute(splatValue.getSplatValue<Attribute>()); 1978 if (auto elementsValue = value.dyn_cast<ElementsAttr>()) 1979 return llvm::all_of(elementsValue.getValues<Attribute>(), isZeroAttribute); 1980 if (auto arrayValue = value.dyn_cast<ArrayAttr>()) 1981 return llvm::all_of(arrayValue.getValue(), isZeroAttribute); 1982 return false; 1983 } 1984 1985 LogicalResult GlobalOp::verify() { 1986 if (!LLVMPointerType::isValidElementType(getType())) 1987 return emitOpError( 1988 "expects type to be a valid element type for an LLVM pointer"); 1989 if ((*this)->getParentOp() && !satisfiesLLVMModule((*this)->getParentOp())) 1990 return emitOpError("must appear at the module level"); 1991 1992 if (auto strAttr = getValueOrNull().dyn_cast_or_null<StringAttr>()) { 1993 auto type = getType().dyn_cast<LLVMArrayType>(); 1994 IntegerType elementType = 1995 type ? 
type.getElementType().dyn_cast<IntegerType>() : nullptr; 1996 if (!elementType || elementType.getWidth() != 8 || 1997 type.getNumElements() != strAttr.getValue().size()) 1998 return emitOpError( 1999 "requires an i8 array type of the length equal to that of the string " 2000 "attribute"); 2001 } 2002 2003 if (getLinkage() == Linkage::Common) { 2004 if (Attribute value = getValueOrNull()) { 2005 if (!isZeroAttribute(value)) { 2006 return emitOpError() 2007 << "expected zero value for '" 2008 << stringifyLinkage(Linkage::Common) << "' linkage"; 2009 } 2010 } 2011 } 2012 2013 if (getLinkage() == Linkage::Appending) { 2014 if (!getType().isa<LLVMArrayType>()) { 2015 return emitOpError() << "expected array type for '" 2016 << stringifyLinkage(Linkage::Appending) 2017 << "' linkage"; 2018 } 2019 } 2020 2021 Optional<uint64_t> alignAttr = getAlignment(); 2022 if (alignAttr.has_value()) { 2023 uint64_t value = alignAttr.value(); 2024 if (!llvm::isPowerOf2_64(value)) 2025 return emitError() << "alignment attribute is not a power of 2"; 2026 } 2027 2028 return success(); 2029 } 2030 2031 LogicalResult GlobalOp::verifyRegions() { 2032 if (Block *b = getInitializerBlock()) { 2033 ReturnOp ret = cast<ReturnOp>(b->getTerminator()); 2034 if (ret.operand_type_begin() == ret.operand_type_end()) 2035 return emitOpError("initializer region cannot return void"); 2036 if (*ret.operand_type_begin() != getType()) 2037 return emitOpError("initializer region type ") 2038 << *ret.operand_type_begin() << " does not match global type " 2039 << getType(); 2040 2041 for (Operation &op : *b) { 2042 auto iface = dyn_cast<MemoryEffectOpInterface>(op); 2043 if (!iface || !iface.hasNoEffect()) 2044 return op.emitError() 2045 << "ops with side effects not allowed in global initializers"; 2046 } 2047 2048 if (getValueOrNull()) 2049 return emitOpError("cannot have both initializer value and region"); 2050 } 2051 2052 return success(); 2053 } 2054 2055 
//===----------------------------------------------------------------------===// 2056 // LLVM::GlobalCtorsOp 2057 //===----------------------------------------------------------------------===// 2058 2059 LogicalResult 2060 GlobalCtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 2061 for (Attribute ctor : getCtors()) { 2062 if (failed(verifySymbolAttrUse(ctor.cast<FlatSymbolRefAttr>(), *this, 2063 symbolTable))) 2064 return failure(); 2065 } 2066 return success(); 2067 } 2068 2069 LogicalResult GlobalCtorsOp::verify() { 2070 if (getCtors().size() != getPriorities().size()) 2071 return emitError( 2072 "mismatch between the number of ctors and the number of priorities"); 2073 return success(); 2074 } 2075 2076 //===----------------------------------------------------------------------===// 2077 // LLVM::GlobalDtorsOp 2078 //===----------------------------------------------------------------------===// 2079 2080 LogicalResult 2081 GlobalDtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 2082 for (Attribute dtor : getDtors()) { 2083 if (failed(verifySymbolAttrUse(dtor.cast<FlatSymbolRefAttr>(), *this, 2084 symbolTable))) 2085 return failure(); 2086 } 2087 return success(); 2088 } 2089 2090 LogicalResult GlobalDtorsOp::verify() { 2091 if (getDtors().size() != getPriorities().size()) 2092 return emitError( 2093 "mismatch between the number of dtors and the number of priorities"); 2094 return success(); 2095 } 2096 2097 //===----------------------------------------------------------------------===// 2098 // Printing/parsing for LLVM::ShuffleVectorOp. 2099 //===----------------------------------------------------------------------===// 2100 // Expects vector to be of wrapped LLVM vector type and position to be of 2101 // wrapped LLVM i32 type. 
2102 void LLVM::ShuffleVectorOp::build(OpBuilder &b, OperationState &result, 2103 Value v1, Value v2, ArrayAttr mask, 2104 ArrayRef<NamedAttribute> attrs) { 2105 auto containerType = v1.getType(); 2106 auto vType = LLVM::getVectorType(LLVM::getVectorElementType(containerType), 2107 mask.size(), 2108 LLVM::isScalableVectorType(containerType)); 2109 build(b, result, vType, v1, v2, mask); 2110 result.addAttributes(attrs); 2111 } 2112 2113 void ShuffleVectorOp::print(OpAsmPrinter &p) { 2114 p << ' ' << getV1() << ", " << getV2() << " " << getMask(); 2115 p.printOptionalAttrDict((*this)->getAttrs(), {"mask"}); 2116 p << " : " << getV1().getType() << ", " << getV2().getType(); 2117 } 2118 2119 // <operation> ::= `llvm.shufflevector` ssa-use `, ` ssa-use 2120 // `[` integer-literal (`,` integer-literal)* `]` 2121 // attribute-dict? `:` type 2122 ParseResult ShuffleVectorOp::parse(OpAsmParser &parser, 2123 OperationState &result) { 2124 SMLoc loc; 2125 OpAsmParser::UnresolvedOperand v1, v2; 2126 ArrayAttr maskAttr; 2127 Type typeV1, typeV2; 2128 if (parser.getCurrentLocation(&loc) || parser.parseOperand(v1) || 2129 parser.parseComma() || parser.parseOperand(v2) || 2130 parser.parseAttribute(maskAttr, "mask", result.attributes) || 2131 parser.parseOptionalAttrDict(result.attributes) || 2132 parser.parseColonType(typeV1) || parser.parseComma() || 2133 parser.parseType(typeV2) || 2134 parser.resolveOperand(v1, typeV1, result.operands) || 2135 parser.resolveOperand(v2, typeV2, result.operands)) 2136 return failure(); 2137 if (!LLVM::isCompatibleVectorType(typeV1)) 2138 return parser.emitError( 2139 loc, "expected LLVM IR dialect vector type for operand #1"); 2140 auto vType = 2141 LLVM::getVectorType(LLVM::getVectorElementType(typeV1), maskAttr.size(), 2142 LLVM::isScalableVectorType(typeV1)); 2143 result.addTypes(vType); 2144 return success(); 2145 } 2146 2147 LogicalResult ShuffleVectorOp::verify() { 2148 Type type1 = getV1().getType(); 2149 Type type2 = getV2().getType(); 
2150 if (LLVM::getVectorElementType(type1) != LLVM::getVectorElementType(type2)) 2151 return emitOpError("expected matching LLVM IR Dialect element types"); 2152 if (LLVM::isScalableVectorType(type1)) 2153 if (llvm::any_of(getMask(), [](Attribute attr) { 2154 return attr.cast<IntegerAttr>().getInt() != 0; 2155 })) 2156 return emitOpError("expected a splat operation for scalable vectors"); 2157 return success(); 2158 } 2159 2160 //===----------------------------------------------------------------------===// 2161 // Implementations for LLVM::LLVMFuncOp. 2162 //===----------------------------------------------------------------------===// 2163 2164 // Add the entry block to the function. 2165 Block *LLVMFuncOp::addEntryBlock() { 2166 assert(empty() && "function already has an entry block"); 2167 2168 auto *entry = new Block; 2169 push_back(entry); 2170 2171 // FIXME: Allow passing in proper locations for the entry arguments. 2172 LLVMFunctionType type = getFunctionType(); 2173 for (unsigned i = 0, e = type.getNumParams(); i < e; ++i) 2174 entry->addArgument(type.getParamType(i), getLoc()); 2175 return entry; 2176 } 2177 2178 void LLVMFuncOp::build(OpBuilder &builder, OperationState &result, 2179 StringRef name, Type type, LLVM::Linkage linkage, 2180 bool dsoLocal, CConv cconv, 2181 ArrayRef<NamedAttribute> attrs, 2182 ArrayRef<DictionaryAttr> argAttrs) { 2183 result.addRegion(); 2184 result.addAttribute(SymbolTable::getSymbolAttrName(), 2185 builder.getStringAttr(name)); 2186 result.addAttribute(getFunctionTypeAttrName(result.name), 2187 TypeAttr::get(type)); 2188 result.addAttribute(getLinkageAttrName(result.name), 2189 LinkageAttr::get(builder.getContext(), linkage)); 2190 result.addAttribute(getCConvAttrName(result.name), 2191 CConvAttr::get(builder.getContext(), cconv)); 2192 result.attributes.append(attrs.begin(), attrs.end()); 2193 if (dsoLocal) 2194 result.addAttribute("dso_local", builder.getUnitAttr()); 2195 if (argAttrs.empty()) 2196 return; 2197 2198 
assert(type.cast<LLVMFunctionType>().getNumParams() == argAttrs.size() && 2199 "expected as many argument attribute lists as arguments"); 2200 function_interface_impl::addArgAndResultAttrs(builder, result, argAttrs, 2201 /*resultAttrs=*/llvm::None); 2202 } 2203 2204 // Builds an LLVM function type from the given lists of input and output types. 2205 // Returns a null type if any of the types provided are non-LLVM types, or if 2206 // there is more than one output type. 2207 static Type 2208 buildLLVMFunctionType(OpAsmParser &parser, SMLoc loc, ArrayRef<Type> inputs, 2209 ArrayRef<Type> outputs, 2210 function_interface_impl::VariadicFlag variadicFlag) { 2211 Builder &b = parser.getBuilder(); 2212 if (outputs.size() > 1) { 2213 parser.emitError(loc, "failed to construct function type: expected zero or " 2214 "one function result"); 2215 return {}; 2216 } 2217 2218 // Convert inputs to LLVM types, exit early on error. 2219 SmallVector<Type, 4> llvmInputs; 2220 for (auto t : inputs) { 2221 if (!isCompatibleType(t)) { 2222 parser.emitError(loc, "failed to construct function type: expected LLVM " 2223 "type for function arguments"); 2224 return {}; 2225 } 2226 llvmInputs.push_back(t); 2227 } 2228 2229 // No output is denoted as "void" in LLVM type system. 2230 Type llvmOutput = 2231 outputs.empty() ? LLVMVoidType::get(b.getContext()) : outputs.front(); 2232 if (!isCompatibleType(llvmOutput)) { 2233 parser.emitError(loc, "failed to construct function type: expected LLVM " 2234 "type for function results") 2235 << llvmOutput; 2236 return {}; 2237 } 2238 return LLVMFunctionType::get(llvmOutput, llvmInputs, 2239 variadicFlag.isVariadic()); 2240 } 2241 2242 // Parses an LLVM function. 2243 // 2244 // operation ::= `llvm.func` linkage? cconv? function-signature 2245 // function-attributes? 2246 // function-body 2247 // 2248 ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) { 2249 // Default to external linkage if no keyword is provided. 
2250 result.addAttribute( 2251 getLinkageAttrName(result.name), 2252 LinkageAttr::get(parser.getContext(), 2253 parseOptionalLLVMKeyword<Linkage>( 2254 parser, result, LLVM::Linkage::External))); 2255 2256 // Default to C Calling Convention if no keyword is provided. 2257 result.addAttribute( 2258 getCConvAttrName(result.name), 2259 CConvAttr::get(parser.getContext(), parseOptionalLLVMKeyword<CConv>( 2260 parser, result, LLVM::CConv::C))); 2261 2262 StringAttr nameAttr; 2263 SmallVector<OpAsmParser::Argument> entryArgs; 2264 SmallVector<DictionaryAttr> resultAttrs; 2265 SmallVector<Type> resultTypes; 2266 bool isVariadic; 2267 2268 auto signatureLocation = parser.getCurrentLocation(); 2269 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), 2270 result.attributes) || 2271 function_interface_impl::parseFunctionSignature( 2272 parser, /*allowVariadic=*/true, entryArgs, isVariadic, resultTypes, 2273 resultAttrs)) 2274 return failure(); 2275 2276 SmallVector<Type> argTypes; 2277 for (auto &arg : entryArgs) 2278 argTypes.push_back(arg.type); 2279 auto type = 2280 buildLLVMFunctionType(parser, signatureLocation, argTypes, resultTypes, 2281 function_interface_impl::VariadicFlag(isVariadic)); 2282 if (!type) 2283 return failure(); 2284 result.addAttribute(FunctionOpInterface::getTypeAttrName(), 2285 TypeAttr::get(type)); 2286 2287 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) 2288 return failure(); 2289 function_interface_impl::addArgAndResultAttrs(parser.getBuilder(), result, 2290 entryArgs, resultAttrs); 2291 2292 auto *body = result.addRegion(); 2293 OptionalParseResult parseResult = 2294 parser.parseOptionalRegion(*body, entryArgs); 2295 return failure(parseResult.hasValue() && failed(*parseResult)); 2296 } 2297 2298 // Print the LLVMFuncOp. Collects argument and result types and passes them to 2299 // helper functions. Drops "void" result since it cannot be parsed back. 
Skips 2300 // the external linkage since it is the default value. 2301 void LLVMFuncOp::print(OpAsmPrinter &p) { 2302 p << ' '; 2303 if (getLinkage() != LLVM::Linkage::External) 2304 p << stringifyLinkage(getLinkage()) << ' '; 2305 if (getCConv() != LLVM::CConv::C) 2306 p << stringifyCConv(getCConv()) << ' '; 2307 2308 p.printSymbolName(getName()); 2309 2310 LLVMFunctionType fnType = getFunctionType(); 2311 SmallVector<Type, 8> argTypes; 2312 SmallVector<Type, 1> resTypes; 2313 argTypes.reserve(fnType.getNumParams()); 2314 for (unsigned i = 0, e = fnType.getNumParams(); i < e; ++i) 2315 argTypes.push_back(fnType.getParamType(i)); 2316 2317 Type returnType = fnType.getReturnType(); 2318 if (!returnType.isa<LLVMVoidType>()) 2319 resTypes.push_back(returnType); 2320 2321 function_interface_impl::printFunctionSignature(p, *this, argTypes, 2322 isVarArg(), resTypes); 2323 function_interface_impl::printFunctionAttributes( 2324 p, *this, argTypes.size(), resTypes.size(), 2325 {getLinkageAttrName(), getCConvAttrName()}); 2326 2327 // Print the body if this is not an external function. 2328 Region &body = getBody(); 2329 if (!body.empty()) { 2330 p << ' '; 2331 p.printRegion(body, /*printEntryBlockArgs=*/false, 2332 /*printBlockTerminators=*/true); 2333 } 2334 } 2335 2336 // Verifies LLVM- and implementation-specific properties of the LLVM func Op: 2337 // - functions don't have 'common' linkage 2338 // - external functions have 'external' or 'extern_weak' linkage; 2339 // - vararg is (currently) only supported for external functions; 2340 LogicalResult LLVMFuncOp::verify() { 2341 if (getLinkage() == LLVM::Linkage::Common) 2342 return emitOpError() << "functions cannot have '" 2343 << stringifyLinkage(LLVM::Linkage::Common) 2344 << "' linkage"; 2345 2346 // Check to see if this function has a void return with a result attribute to 2347 // it. It isn't clear what semantics we would assign to that. 
2348 if (getFunctionType().getReturnType().isa<LLVMVoidType>() && 2349 !getResultAttrs(0).empty()) { 2350 return emitOpError() 2351 << "cannot attach result attributes to functions with a void return"; 2352 } 2353 2354 if (isExternal()) { 2355 if (getLinkage() != LLVM::Linkage::External && 2356 getLinkage() != LLVM::Linkage::ExternWeak) 2357 return emitOpError() << "external functions must have '" 2358 << stringifyLinkage(LLVM::Linkage::External) 2359 << "' or '" 2360 << stringifyLinkage(LLVM::Linkage::ExternWeak) 2361 << "' linkage"; 2362 return success(); 2363 } 2364 2365 return success(); 2366 } 2367 2368 /// Verifies LLVM- and implementation-specific properties of the LLVM func Op: 2369 /// - entry block arguments are of LLVM types. 2370 LogicalResult LLVMFuncOp::verifyRegions() { 2371 if (isExternal()) 2372 return success(); 2373 2374 unsigned numArguments = getFunctionType().getNumParams(); 2375 Block &entryBlock = front(); 2376 for (unsigned i = 0; i < numArguments; ++i) { 2377 Type argType = entryBlock.getArgument(i).getType(); 2378 if (!isCompatibleType(argType)) 2379 return emitOpError("entry block argument #") 2380 << i << " is not of LLVM type"; 2381 } 2382 2383 return success(); 2384 } 2385 2386 //===----------------------------------------------------------------------===// 2387 // Verification for LLVM::ConstantOp. 
2388 //===----------------------------------------------------------------------===// 2389 2390 LogicalResult LLVM::ConstantOp::verify() { 2391 if (StringAttr sAttr = getValue().dyn_cast<StringAttr>()) { 2392 auto arrayType = getType().dyn_cast<LLVMArrayType>(); 2393 if (!arrayType || arrayType.getNumElements() != sAttr.getValue().size() || 2394 !arrayType.getElementType().isInteger(8)) { 2395 return emitOpError() << "expected array type of " 2396 << sAttr.getValue().size() 2397 << " i8 elements for the string constant"; 2398 } 2399 return success(); 2400 } 2401 if (auto structType = getType().dyn_cast<LLVMStructType>()) { 2402 if (structType.getBody().size() != 2 || 2403 structType.getBody()[0] != structType.getBody()[1]) { 2404 return emitError() << "expected struct type with two elements of the " 2405 "same type, the type of a complex constant"; 2406 } 2407 2408 auto arrayAttr = getValue().dyn_cast<ArrayAttr>(); 2409 if (!arrayAttr || arrayAttr.size() != 2 || 2410 arrayAttr[0].getType() != arrayAttr[1].getType()) { 2411 return emitOpError() << "expected array attribute with two elements, " 2412 "representing a complex constant"; 2413 } 2414 2415 Type elementType = structType.getBody()[0]; 2416 if (!elementType 2417 .isa<IntegerType, Float16Type, Float32Type, Float64Type>()) { 2418 return emitError() 2419 << "expected struct element types to be floating point type or " 2420 "integer type"; 2421 } 2422 return success(); 2423 } 2424 if (!getValue().isa<IntegerAttr, ArrayAttr, FloatAttr, ElementsAttr>()) 2425 return emitOpError() 2426 << "only supports integer, float, string or elements attributes"; 2427 return success(); 2428 } 2429 2430 // Constant op constant-folds to its value. 
2431 OpFoldResult LLVM::ConstantOp::fold(ArrayRef<Attribute>) { return getValue(); } 2432 2433 //===----------------------------------------------------------------------===// 2434 // Utility functions for parsing atomic ops 2435 //===----------------------------------------------------------------------===// 2436 2437 // Helper function to parse a keyword into the specified attribute named by 2438 // `attrName`. The keyword must match one of the string values defined by the 2439 // AtomicBinOp enum. The resulting I64 attribute is added to the `result` 2440 // state. 2441 static ParseResult parseAtomicBinOp(OpAsmParser &parser, OperationState &result, 2442 StringRef attrName) { 2443 SMLoc loc; 2444 StringRef keyword; 2445 if (parser.getCurrentLocation(&loc) || parser.parseKeyword(&keyword)) 2446 return failure(); 2447 2448 // Replace the keyword `keyword` with an integer attribute. 2449 auto kind = symbolizeAtomicBinOp(keyword); 2450 if (!kind) { 2451 return parser.emitError(loc) 2452 << "'" << keyword << "' is an incorrect value of the '" << attrName 2453 << "' attribute"; 2454 } 2455 2456 auto value = static_cast<int64_t>(*kind); 2457 auto attr = parser.getBuilder().getI64IntegerAttr(value); 2458 result.addAttribute(attrName, attr); 2459 2460 return success(); 2461 } 2462 2463 // Helper function to parse a keyword into the specified attribute named by 2464 // `attrName`. The keyword must match one of the string values defined by the 2465 // AtomicOrdering enum. The resulting I64 attribute is added to the `result` 2466 // state. 2467 static ParseResult parseAtomicOrdering(OpAsmParser &parser, 2468 OperationState &result, 2469 StringRef attrName) { 2470 SMLoc loc; 2471 StringRef ordering; 2472 if (parser.getCurrentLocation(&loc) || parser.parseKeyword(&ordering)) 2473 return failure(); 2474 2475 // Replace the keyword `ordering` with an integer attribute. 
2476 auto kind = symbolizeAtomicOrdering(ordering); 2477 if (!kind) { 2478 return parser.emitError(loc) 2479 << "'" << ordering << "' is an incorrect value of the '" << attrName 2480 << "' attribute"; 2481 } 2482 2483 auto value = static_cast<int64_t>(*kind); 2484 auto attr = parser.getBuilder().getI64IntegerAttr(value); 2485 result.addAttribute(attrName, attr); 2486 2487 return success(); 2488 } 2489 2490 //===----------------------------------------------------------------------===// 2491 // Printer, parser and verifier for LLVM::AtomicRMWOp. 2492 //===----------------------------------------------------------------------===// 2493 2494 void AtomicRMWOp::print(OpAsmPrinter &p) { 2495 p << ' ' << stringifyAtomicBinOp(getBinOp()) << ' ' << getPtr() << ", " 2496 << getVal() << ' ' << stringifyAtomicOrdering(getOrdering()) << ' '; 2497 p.printOptionalAttrDict((*this)->getAttrs(), {"bin_op", "ordering"}); 2498 p << " : " << getRes().getType(); 2499 } 2500 2501 // <operation> ::= `llvm.atomicrmw` keyword ssa-use `,` ssa-use keyword 2502 // attribute-dict? 
`:` type 2503 ParseResult AtomicRMWOp::parse(OpAsmParser &parser, OperationState &result) { 2504 Type type; 2505 OpAsmParser::UnresolvedOperand ptr, val; 2506 if (parseAtomicBinOp(parser, result, "bin_op") || parser.parseOperand(ptr) || 2507 parser.parseComma() || parser.parseOperand(val) || 2508 parseAtomicOrdering(parser, result, "ordering") || 2509 parser.parseOptionalAttrDict(result.attributes) || 2510 parser.parseColonType(type) || 2511 parser.resolveOperand(ptr, LLVM::LLVMPointerType::get(type), 2512 result.operands) || 2513 parser.resolveOperand(val, type, result.operands)) 2514 return failure(); 2515 2516 result.addTypes(type); 2517 return success(); 2518 } 2519 2520 LogicalResult AtomicRMWOp::verify() { 2521 auto ptrType = getPtr().getType().cast<LLVM::LLVMPointerType>(); 2522 auto valType = getVal().getType(); 2523 if (valType != ptrType.getElementType()) 2524 return emitOpError("expected LLVM IR element type for operand #0 to " 2525 "match type for operand #1"); 2526 auto resType = getRes().getType(); 2527 if (resType != valType) 2528 return emitOpError( 2529 "expected LLVM IR result type to match type for operand #1"); 2530 if (getBinOp() == AtomicBinOp::fadd || getBinOp() == AtomicBinOp::fsub) { 2531 if (!mlir::LLVM::isCompatibleFloatingPointType(valType)) 2532 return emitOpError("expected LLVM IR floating point type"); 2533 } else if (getBinOp() == AtomicBinOp::xchg) { 2534 auto intType = valType.dyn_cast<IntegerType>(); 2535 unsigned intBitWidth = intType ? intType.getWidth() : 0; 2536 if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 && 2537 intBitWidth != 64 && !valType.isa<BFloat16Type>() && 2538 !valType.isa<Float16Type>() && !valType.isa<Float32Type>() && 2539 !valType.isa<Float64Type>()) 2540 return emitOpError("unexpected LLVM IR type for 'xchg' bin_op"); 2541 } else { 2542 auto intType = valType.dyn_cast<IntegerType>(); 2543 unsigned intBitWidth = intType ? 
intType.getWidth() : 0; 2544 if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 && 2545 intBitWidth != 64) 2546 return emitOpError("expected LLVM IR integer type"); 2547 } 2548 2549 if (static_cast<unsigned>(getOrdering()) < 2550 static_cast<unsigned>(AtomicOrdering::monotonic)) 2551 return emitOpError() << "expected at least '" 2552 << stringifyAtomicOrdering(AtomicOrdering::monotonic) 2553 << "' ordering"; 2554 2555 return success(); 2556 } 2557 2558 //===----------------------------------------------------------------------===// 2559 // Printer, parser and verifier for LLVM::AtomicCmpXchgOp. 2560 //===----------------------------------------------------------------------===// 2561 2562 void AtomicCmpXchgOp::print(OpAsmPrinter &p) { 2563 p << ' ' << getPtr() << ", " << getCmp() << ", " << getVal() << ' ' 2564 << stringifyAtomicOrdering(getSuccessOrdering()) << ' ' 2565 << stringifyAtomicOrdering(getFailureOrdering()); 2566 p.printOptionalAttrDict((*this)->getAttrs(), 2567 {"success_ordering", "failure_ordering"}); 2568 p << " : " << getVal().getType(); 2569 } 2570 2571 // <operation> ::= `llvm.cmpxchg` ssa-use `,` ssa-use `,` ssa-use 2572 // keyword keyword attribute-dict? 
`:` type 2573 ParseResult AtomicCmpXchgOp::parse(OpAsmParser &parser, 2574 OperationState &result) { 2575 auto &builder = parser.getBuilder(); 2576 Type type; 2577 OpAsmParser::UnresolvedOperand ptr, cmp, val; 2578 if (parser.parseOperand(ptr) || parser.parseComma() || 2579 parser.parseOperand(cmp) || parser.parseComma() || 2580 parser.parseOperand(val) || 2581 parseAtomicOrdering(parser, result, "success_ordering") || 2582 parseAtomicOrdering(parser, result, "failure_ordering") || 2583 parser.parseOptionalAttrDict(result.attributes) || 2584 parser.parseColonType(type) || 2585 parser.resolveOperand(ptr, LLVM::LLVMPointerType::get(type), 2586 result.operands) || 2587 parser.resolveOperand(cmp, type, result.operands) || 2588 parser.resolveOperand(val, type, result.operands)) 2589 return failure(); 2590 2591 auto boolType = IntegerType::get(builder.getContext(), 1); 2592 auto resultType = 2593 LLVMStructType::getLiteral(builder.getContext(), {type, boolType}); 2594 result.addTypes(resultType); 2595 2596 return success(); 2597 } 2598 2599 LogicalResult AtomicCmpXchgOp::verify() { 2600 auto ptrType = getPtr().getType().cast<LLVM::LLVMPointerType>(); 2601 if (!ptrType) 2602 return emitOpError("expected LLVM IR pointer type for operand #0"); 2603 auto cmpType = getCmp().getType(); 2604 auto valType = getVal().getType(); 2605 if (cmpType != ptrType.getElementType() || cmpType != valType) 2606 return emitOpError("expected LLVM IR element type for operand #0 to " 2607 "match type for all other operands"); 2608 auto intType = valType.dyn_cast<IntegerType>(); 2609 unsigned intBitWidth = intType ? 
intType.getWidth() : 0; 2610 if (!valType.isa<LLVMPointerType>() && intBitWidth != 8 && 2611 intBitWidth != 16 && intBitWidth != 32 && intBitWidth != 64 && 2612 !valType.isa<BFloat16Type>() && !valType.isa<Float16Type>() && 2613 !valType.isa<Float32Type>() && !valType.isa<Float64Type>()) 2614 return emitOpError("unexpected LLVM IR type"); 2615 if (getSuccessOrdering() < AtomicOrdering::monotonic || 2616 getFailureOrdering() < AtomicOrdering::monotonic) 2617 return emitOpError("ordering must be at least 'monotonic'"); 2618 if (getFailureOrdering() == AtomicOrdering::release || 2619 getFailureOrdering() == AtomicOrdering::acq_rel) 2620 return emitOpError("failure ordering cannot be 'release' or 'acq_rel'"); 2621 return success(); 2622 } 2623 2624 //===----------------------------------------------------------------------===// 2625 // Printer, parser and verifier for LLVM::FenceOp. 2626 //===----------------------------------------------------------------------===// 2627 2628 // <operation> ::= `llvm.fence` (`syncscope(`strAttr`)`)? keyword 2629 // attribute-dict? 
2630 ParseResult FenceOp::parse(OpAsmParser &parser, OperationState &result) { 2631 StringAttr sScope; 2632 StringRef syncscopeKeyword = "syncscope"; 2633 if (!failed(parser.parseOptionalKeyword(syncscopeKeyword))) { 2634 if (parser.parseLParen() || 2635 parser.parseAttribute(sScope, syncscopeKeyword, result.attributes) || 2636 parser.parseRParen()) 2637 return failure(); 2638 } else { 2639 result.addAttribute(syncscopeKeyword, 2640 parser.getBuilder().getStringAttr("")); 2641 } 2642 if (parseAtomicOrdering(parser, result, "ordering") || 2643 parser.parseOptionalAttrDict(result.attributes)) 2644 return failure(); 2645 return success(); 2646 } 2647 2648 void FenceOp::print(OpAsmPrinter &p) { 2649 StringRef syncscopeKeyword = "syncscope"; 2650 p << ' '; 2651 if (!(*this)->getAttr(syncscopeKeyword).cast<StringAttr>().getValue().empty()) 2652 p << "syncscope(" << (*this)->getAttr(syncscopeKeyword) << ") "; 2653 p << stringifyAtomicOrdering(getOrdering()); 2654 } 2655 2656 LogicalResult FenceOp::verify() { 2657 if (getOrdering() == AtomicOrdering::not_atomic || 2658 getOrdering() == AtomicOrdering::unordered || 2659 getOrdering() == AtomicOrdering::monotonic) 2660 return emitOpError("can be given only acquire, release, acq_rel, " 2661 "and seq_cst orderings"); 2662 return success(); 2663 } 2664 2665 //===----------------------------------------------------------------------===// 2666 // Folder for LLVM::BitcastOp 2667 //===----------------------------------------------------------------------===// 2668 2669 OpFoldResult LLVM::BitcastOp::fold(ArrayRef<Attribute> operands) { 2670 // bitcast(x : T0, T0) -> x 2671 if (getArg().getType() == getType()) 2672 return getArg(); 2673 // bitcast(bitcast(x : T0, T1), T0) -> x 2674 if (auto prev = getArg().getDefiningOp<BitcastOp>()) 2675 if (prev.getArg().getType() == getType()) 2676 return prev.getArg(); 2677 return {}; 2678 } 2679 2680 //===----------------------------------------------------------------------===// 2681 // Folder 
for LLVM::AddrSpaceCastOp 2682 //===----------------------------------------------------------------------===// 2683 2684 OpFoldResult LLVM::AddrSpaceCastOp::fold(ArrayRef<Attribute> operands) { 2685 // addrcast(x : T0, T0) -> x 2686 if (getArg().getType() == getType()) 2687 return getArg(); 2688 // addrcast(addrcast(x : T0, T1), T0) -> x 2689 if (auto prev = getArg().getDefiningOp<AddrSpaceCastOp>()) 2690 if (prev.getArg().getType() == getType()) 2691 return prev.getArg(); 2692 return {}; 2693 } 2694 2695 //===----------------------------------------------------------------------===// 2696 // Folder for LLVM::GEPOp 2697 //===----------------------------------------------------------------------===// 2698 2699 OpFoldResult LLVM::GEPOp::fold(ArrayRef<Attribute> operands) { 2700 // gep %x:T, 0 -> %x 2701 if (getBase().getType() == getType() && getIndices().size() == 1 && 2702 matchPattern(getIndices()[0], m_Zero())) 2703 return getBase(); 2704 return {}; 2705 } 2706 2707 //===----------------------------------------------------------------------===// 2708 // LLVMDialect initialization, type parsing, and registration. 2709 //===----------------------------------------------------------------------===// 2710 2711 void LLVMDialect::initialize() { 2712 addAttributes<FMFAttr, LinkageAttr, CConvAttr, LoopOptionsAttr>(); 2713 2714 // clang-format off 2715 addTypes<LLVMVoidType, 2716 LLVMPPCFP128Type, 2717 LLVMX86MMXType, 2718 LLVMTokenType, 2719 LLVMLabelType, 2720 LLVMMetadataType, 2721 LLVMFunctionType, 2722 LLVMPointerType, 2723 LLVMFixedVectorType, 2724 LLVMScalableVectorType, 2725 LLVMArrayType, 2726 LLVMStructType>(); 2727 // clang-format on 2728 addOperations< 2729 #define GET_OP_LIST 2730 #include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc" 2731 , 2732 #define GET_OP_LIST 2733 #include "mlir/Dialect/LLVMIR/LLVMIntrinsicOps.cpp.inc" 2734 >(); 2735 2736 // Support unknown operations because not all LLVM operations are registered. 
2737 allowUnknownOperations(); 2738 } 2739 2740 #define GET_OP_CLASSES 2741 #include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc" 2742 2743 /// Parse a type registered to this dialect. 2744 Type LLVMDialect::parseType(DialectAsmParser &parser) const { 2745 return detail::parseType(parser); 2746 } 2747 2748 /// Print a type registered to this dialect. 2749 void LLVMDialect::printType(Type type, DialectAsmPrinter &os) const { 2750 return detail::printType(type, os); 2751 } 2752 2753 LogicalResult LLVMDialect::verifyDataLayoutString( 2754 StringRef descr, llvm::function_ref<void(const Twine &)> reportError) { 2755 llvm::Expected<llvm::DataLayout> maybeDataLayout = 2756 llvm::DataLayout::parse(descr); 2757 if (maybeDataLayout) 2758 return success(); 2759 2760 std::string message; 2761 llvm::raw_string_ostream messageStream(message); 2762 llvm::logAllUnhandledErrors(maybeDataLayout.takeError(), messageStream); 2763 reportError("invalid data layout descriptor: " + messageStream.str()); 2764 return failure(); 2765 } 2766 2767 /// Verify LLVM dialect attributes. 2768 LogicalResult LLVMDialect::verifyOperationAttribute(Operation *op, 2769 NamedAttribute attr) { 2770 // If the `llvm.loop` attribute is present, enforce the following structure, 2771 // which the module translation can assume. 
2772 if (attr.getName() == LLVMDialect::getLoopAttrName()) { 2773 auto loopAttr = attr.getValue().dyn_cast<DictionaryAttr>(); 2774 if (!loopAttr) 2775 return op->emitOpError() << "expected '" << LLVMDialect::getLoopAttrName() 2776 << "' to be a dictionary attribute"; 2777 Optional<NamedAttribute> parallelAccessGroup = 2778 loopAttr.getNamed(LLVMDialect::getParallelAccessAttrName()); 2779 if (parallelAccessGroup) { 2780 auto accessGroups = parallelAccessGroup->getValue().dyn_cast<ArrayAttr>(); 2781 if (!accessGroups) 2782 return op->emitOpError() 2783 << "expected '" << LLVMDialect::getParallelAccessAttrName() 2784 << "' to be an array attribute"; 2785 for (Attribute attr : accessGroups) { 2786 auto accessGroupRef = attr.dyn_cast<SymbolRefAttr>(); 2787 if (!accessGroupRef) 2788 return op->emitOpError() 2789 << "expected '" << attr << "' to be a symbol reference"; 2790 StringAttr metadataName = accessGroupRef.getRootReference(); 2791 auto metadataOp = 2792 SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>( 2793 op->getParentOp(), metadataName); 2794 if (!metadataOp) 2795 return op->emitOpError() 2796 << "expected '" << attr << "' to reference a metadata op"; 2797 StringAttr accessGroupName = accessGroupRef.getLeafReference(); 2798 Operation *accessGroupOp = 2799 SymbolTable::lookupNearestSymbolFrom(metadataOp, accessGroupName); 2800 if (!accessGroupOp) 2801 return op->emitOpError() 2802 << "expected '" << attr << "' to reference an access_group op"; 2803 } 2804 } 2805 2806 Optional<NamedAttribute> loopOptions = 2807 loopAttr.getNamed(LLVMDialect::getLoopOptionsAttrName()); 2808 if (loopOptions && !loopOptions->getValue().isa<LoopOptionsAttr>()) 2809 return op->emitOpError() 2810 << "expected '" << LLVMDialect::getLoopOptionsAttrName() 2811 << "' to be a `loopopts` attribute"; 2812 } 2813 2814 if (attr.getName() == LLVMDialect::getStructAttrsAttrName()) { 2815 return op->emitOpError() 2816 << "'" << LLVM::LLVMDialect::getStructAttrsAttrName() 2817 << "' is 
permitted only in argument or result attributes"; 2818 } 2819 2820 // If the data layout attribute is present, it must use the LLVM data layout 2821 // syntax. Try parsing it and report errors in case of failure. Users of this 2822 // attribute may assume it is well-formed and can pass it to the (asserting) 2823 // llvm::DataLayout constructor. 2824 if (attr.getName() != LLVM::LLVMDialect::getDataLayoutAttrName()) 2825 return success(); 2826 if (auto stringAttr = attr.getValue().dyn_cast<StringAttr>()) 2827 return verifyDataLayoutString( 2828 stringAttr.getValue(), 2829 [op](const Twine &message) { op->emitOpError() << message.str(); }); 2830 2831 return op->emitOpError() << "expected '" 2832 << LLVM::LLVMDialect::getDataLayoutAttrName() 2833 << "' to be a string attributes"; 2834 } 2835 2836 LogicalResult LLVMDialect::verifyStructAttr(Operation *op, Attribute attr, 2837 Type annotatedType) { 2838 auto structType = annotatedType.dyn_cast<LLVMStructType>(); 2839 if (!structType) { 2840 const auto emitIncorrectAnnotatedType = [&op]() { 2841 return op->emitError() 2842 << "expected '" << LLVMDialect::getStructAttrsAttrName() 2843 << "' to annotate '!llvm.struct' or '!llvm.ptr<struct<...>>'"; 2844 }; 2845 const auto ptrType = annotatedType.dyn_cast<LLVMPointerType>(); 2846 if (!ptrType) 2847 return emitIncorrectAnnotatedType(); 2848 structType = ptrType.getElementType().dyn_cast<LLVMStructType>(); 2849 if (!structType) 2850 return emitIncorrectAnnotatedType(); 2851 } 2852 2853 const auto arrAttrs = attr.dyn_cast<ArrayAttr>(); 2854 if (!arrAttrs) 2855 return op->emitError() << "expected '" 2856 << LLVMDialect::getStructAttrsAttrName() 2857 << "' to be an array attribute"; 2858 2859 if (structType.getBody().size() != arrAttrs.size()) 2860 return op->emitError() 2861 << "size of '" << LLVMDialect::getStructAttrsAttrName() 2862 << "' must match the size of the annotated '!llvm.struct'"; 2863 return success(); 2864 } 2865 2866 static LogicalResult 
verifyFuncOpInterfaceStructAttr( 2867 Operation *op, Attribute attr, 2868 const std::function<Type(FunctionOpInterface)> &getAnnotatedType) { 2869 if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) 2870 return LLVMDialect::verifyStructAttr(op, attr, getAnnotatedType(funcOp)); 2871 return op->emitError() << "expected '" 2872 << LLVMDialect::getStructAttrsAttrName() 2873 << "' to be used on function-like operations"; 2874 } 2875 2876 /// Verify LLVMIR function argument attributes. 2877 LogicalResult LLVMDialect::verifyRegionArgAttribute(Operation *op, 2878 unsigned regionIdx, 2879 unsigned argIdx, 2880 NamedAttribute argAttr) { 2881 // Check that llvm.noalias is a unit attribute. 2882 if (argAttr.getName() == LLVMDialect::getNoAliasAttrName() && 2883 !argAttr.getValue().isa<UnitAttr>()) 2884 return op->emitError() 2885 << "expected llvm.noalias argument attribute to be a unit attribute"; 2886 // Check that llvm.align is an integer attribute. 2887 if (argAttr.getName() == LLVMDialect::getAlignAttrName() && 2888 !argAttr.getValue().isa<IntegerAttr>()) 2889 return op->emitError() 2890 << "llvm.align argument attribute of non integer type"; 2891 if (argAttr.getName() == LLVMDialect::getStructAttrsAttrName()) { 2892 return verifyFuncOpInterfaceStructAttr( 2893 op, argAttr.getValue(), [argIdx](FunctionOpInterface funcOp) { 2894 return funcOp.getArgumentTypes()[argIdx]; 2895 }); 2896 } 2897 return success(); 2898 } 2899 2900 LogicalResult LLVMDialect::verifyRegionResultAttribute(Operation *op, 2901 unsigned regionIdx, 2902 unsigned resIdx, 2903 NamedAttribute resAttr) { 2904 if (resAttr.getName() == LLVMDialect::getStructAttrsAttrName()) { 2905 return verifyFuncOpInterfaceStructAttr( 2906 op, resAttr.getValue(), [resIdx](FunctionOpInterface funcOp) { 2907 return funcOp.getResultTypes()[resIdx]; 2908 }); 2909 } 2910 return success(); 2911 } 2912 2913 //===----------------------------------------------------------------------===// 2914 // Utility functions. 
2915 //===----------------------------------------------------------------------===// 2916 2917 Value mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder, 2918 StringRef name, StringRef value, 2919 LLVM::Linkage linkage) { 2920 assert(builder.getInsertionBlock() && 2921 builder.getInsertionBlock()->getParentOp() && 2922 "expected builder to point to a block constrained in an op"); 2923 auto module = 2924 builder.getInsertionBlock()->getParentOp()->getParentOfType<ModuleOp>(); 2925 assert(module && "builder points to an op outside of a module"); 2926 2927 // Create the global at the entry of the module. 2928 OpBuilder moduleBuilder(module.getBodyRegion(), builder.getListener()); 2929 MLIRContext *ctx = builder.getContext(); 2930 auto type = LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), value.size()); 2931 auto global = moduleBuilder.create<LLVM::GlobalOp>( 2932 loc, type, /*isConstant=*/true, linkage, name, 2933 builder.getStringAttr(value), /*alignment=*/0); 2934 2935 // Get the pointer to the first character in the global string. 
2936 Value globalPtr = builder.create<LLVM::AddressOfOp>(loc, global); 2937 Value cst0 = builder.create<LLVM::ConstantOp>( 2938 loc, IntegerType::get(ctx, 64), 2939 builder.getIntegerAttr(builder.getIndexType(), 0)); 2940 return builder.create<LLVM::GEPOp>( 2941 loc, LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)), globalPtr, 2942 ValueRange{cst0, cst0}); 2943 } 2944 2945 bool mlir::LLVM::satisfiesLLVMModule(Operation *op) { 2946 return op->hasTrait<OpTrait::SymbolTable>() && 2947 op->hasTrait<OpTrait::IsIsolatedFromAbove>(); 2948 } 2949 2950 void FMFAttr::print(AsmPrinter &printer) const { 2951 printer << "<"; 2952 printer << stringifyFastmathFlags(this->getFlags()); 2953 printer << ">"; 2954 } 2955 2956 Attribute FMFAttr::parse(AsmParser &parser, Type type) { 2957 if (failed(parser.parseLess())) 2958 return {}; 2959 2960 FastmathFlags flags = {}; 2961 if (failed(parser.parseOptionalGreater())) { 2962 auto parseFlags = [&]() -> ParseResult { 2963 StringRef elemName; 2964 if (failed(parser.parseKeyword(&elemName))) 2965 return failure(); 2966 2967 auto elem = symbolizeFastmathFlags(elemName); 2968 if (!elem) 2969 return parser.emitError(parser.getNameLoc(), "Unknown fastmath flag: ") 2970 << elemName; 2971 2972 flags = flags | *elem; 2973 return success(); 2974 }; 2975 if (failed(parser.parseCommaSeparatedList(parseFlags)) || 2976 parser.parseGreater()) 2977 return {}; 2978 } 2979 2980 return FMFAttr::get(parser.getContext(), flags); 2981 } 2982 2983 void LinkageAttr::print(AsmPrinter &printer) const { 2984 printer << "<"; 2985 if (static_cast<uint64_t>(getLinkage()) <= getMaxEnumValForLinkage()) 2986 printer << stringifyEnum(getLinkage()); 2987 else 2988 printer << static_cast<uint64_t>(getLinkage()); 2989 printer << ">"; 2990 } 2991 2992 Attribute LinkageAttr::parse(AsmParser &parser, Type type) { 2993 StringRef elemName; 2994 if (parser.parseLess() || parser.parseKeyword(&elemName) || 2995 parser.parseGreater()) 2996 return {}; 2997 auto elem = 
linkage::symbolizeLinkage(elemName); 2998 if (!elem) { 2999 parser.emitError(parser.getNameLoc(), "Unknown linkage: ") << elemName; 3000 return {}; 3001 } 3002 Linkage linkage = *elem; 3003 return LinkageAttr::get(parser.getContext(), linkage); 3004 } 3005 3006 void CConvAttr::print(AsmPrinter &printer) const { 3007 printer << "<"; 3008 if (static_cast<uint64_t>(getCallingConv()) <= cconv::getMaxEnumValForCConv()) 3009 printer << stringifyEnum(getCallingConv()); 3010 else 3011 printer << "INVALID_cc_" << static_cast<uint64_t>(getCallingConv()); 3012 printer << ">"; 3013 } 3014 3015 Attribute CConvAttr::parse(AsmParser &parser, Type type) { 3016 StringRef convName; 3017 3018 if (parser.parseLess() || parser.parseKeyword(&convName) || 3019 parser.parseGreater()) 3020 return {}; 3021 auto cconv = cconv::symbolizeCConv(convName); 3022 if (!cconv) { 3023 parser.emitError(parser.getNameLoc(), "unknown calling convention: ") 3024 << convName; 3025 return {}; 3026 } 3027 CConv cconvVal = *cconv; 3028 return CConvAttr::get(parser.getContext(), cconvVal); 3029 } 3030 3031 LoopOptionsAttrBuilder::LoopOptionsAttrBuilder(LoopOptionsAttr attr) 3032 : options(attr.getOptions().begin(), attr.getOptions().end()) {} 3033 3034 template <typename T> 3035 LoopOptionsAttrBuilder &LoopOptionsAttrBuilder::setOption(LoopOptionCase tag, 3036 Optional<T> value) { 3037 auto option = llvm::find_if( 3038 options, [tag](auto option) { return option.first == tag; }); 3039 if (option != options.end()) { 3040 if (value) 3041 option->second = *value; 3042 else 3043 options.erase(option); 3044 } else { 3045 options.push_back(LoopOptionsAttr::OptionValuePair(tag, *value)); 3046 } 3047 return *this; 3048 } 3049 3050 LoopOptionsAttrBuilder & 3051 LoopOptionsAttrBuilder::setDisableLICM(Optional<bool> value) { 3052 return setOption(LoopOptionCase::disable_licm, value); 3053 } 3054 3055 /// Set the `interleave_count` option to the provided value. If no value 3056 /// is provided the option is deleted. 
3057 LoopOptionsAttrBuilder & 3058 LoopOptionsAttrBuilder::setInterleaveCount(Optional<uint64_t> count) { 3059 return setOption(LoopOptionCase::interleave_count, count); 3060 } 3061 3062 /// Set the `disable_unroll` option to the provided value. If no value 3063 /// is provided the option is deleted. 3064 LoopOptionsAttrBuilder & 3065 LoopOptionsAttrBuilder::setDisableUnroll(Optional<bool> value) { 3066 return setOption(LoopOptionCase::disable_unroll, value); 3067 } 3068 3069 /// Set the `disable_pipeline` option to the provided value. If no value 3070 /// is provided the option is deleted. 3071 LoopOptionsAttrBuilder & 3072 LoopOptionsAttrBuilder::setDisablePipeline(Optional<bool> value) { 3073 return setOption(LoopOptionCase::disable_pipeline, value); 3074 } 3075 3076 /// Set the `pipeline_initiation_interval` option to the provided value. 3077 /// If no value is provided the option is deleted. 3078 LoopOptionsAttrBuilder &LoopOptionsAttrBuilder::setPipelineInitiationInterval( 3079 Optional<uint64_t> count) { 3080 return setOption(LoopOptionCase::pipeline_initiation_interval, count); 3081 } 3082 3083 template <typename T> 3084 static Optional<T> 3085 getOption(ArrayRef<std::pair<LoopOptionCase, int64_t>> options, 3086 LoopOptionCase option) { 3087 auto it = 3088 lower_bound(options, option, [](auto optionPair, LoopOptionCase option) { 3089 return optionPair.first < option; 3090 }); 3091 if (it == options.end()) 3092 return {}; 3093 return static_cast<T>(it->second); 3094 } 3095 3096 Optional<bool> LoopOptionsAttr::disableUnroll() { 3097 return getOption<bool>(getOptions(), LoopOptionCase::disable_unroll); 3098 } 3099 3100 Optional<bool> LoopOptionsAttr::disableLICM() { 3101 return getOption<bool>(getOptions(), LoopOptionCase::disable_licm); 3102 } 3103 3104 Optional<int64_t> LoopOptionsAttr::interleaveCount() { 3105 return getOption<int64_t>(getOptions(), LoopOptionCase::interleave_count); 3106 } 3107 3108 /// Build the LoopOptions Attribute from a sorted array 
of individual options. 3109 LoopOptionsAttr LoopOptionsAttr::get( 3110 MLIRContext *context, 3111 ArrayRef<std::pair<LoopOptionCase, int64_t>> sortedOptions) { 3112 assert(llvm::is_sorted(sortedOptions, llvm::less_first()) && 3113 "LoopOptionsAttr ctor expects a sorted options array"); 3114 return Base::get(context, sortedOptions); 3115 } 3116 3117 /// Build the LoopOptions Attribute from a sorted array of individual options. 3118 LoopOptionsAttr LoopOptionsAttr::get(MLIRContext *context, 3119 LoopOptionsAttrBuilder &optionBuilders) { 3120 llvm::sort(optionBuilders.options, llvm::less_first()); 3121 return Base::get(context, optionBuilders.options); 3122 } 3123 3124 void LoopOptionsAttr::print(AsmPrinter &printer) const { 3125 printer << "<"; 3126 llvm::interleaveComma(getOptions(), printer, [&](auto option) { 3127 printer << stringifyEnum(option.first) << " = "; 3128 switch (option.first) { 3129 case LoopOptionCase::disable_licm: 3130 case LoopOptionCase::disable_unroll: 3131 case LoopOptionCase::disable_pipeline: 3132 printer << (option.second ? 
"true" : "false"); 3133 break; 3134 case LoopOptionCase::interleave_count: 3135 case LoopOptionCase::pipeline_initiation_interval: 3136 printer << option.second; 3137 break; 3138 } 3139 }); 3140 printer << ">"; 3141 } 3142 3143 Attribute LoopOptionsAttr::parse(AsmParser &parser, Type type) { 3144 if (failed(parser.parseLess())) 3145 return {}; 3146 3147 SmallVector<std::pair<LoopOptionCase, int64_t>> options; 3148 llvm::SmallDenseSet<LoopOptionCase> seenOptions; 3149 auto parseLoopOptions = [&]() -> ParseResult { 3150 StringRef optionName; 3151 if (parser.parseKeyword(&optionName)) 3152 return failure(); 3153 3154 auto option = symbolizeLoopOptionCase(optionName); 3155 if (!option) 3156 return parser.emitError(parser.getNameLoc(), "unknown loop option: ") 3157 << optionName; 3158 if (!seenOptions.insert(*option).second) 3159 return parser.emitError(parser.getNameLoc(), "loop option present twice"); 3160 if (failed(parser.parseEqual())) 3161 return failure(); 3162 3163 int64_t value; 3164 switch (*option) { 3165 case LoopOptionCase::disable_licm: 3166 case LoopOptionCase::disable_unroll: 3167 case LoopOptionCase::disable_pipeline: 3168 if (succeeded(parser.parseOptionalKeyword("true"))) 3169 value = 1; 3170 else if (succeeded(parser.parseOptionalKeyword("false"))) 3171 value = 0; 3172 else { 3173 return parser.emitError(parser.getNameLoc(), 3174 "expected boolean value 'true' or 'false'"); 3175 } 3176 break; 3177 case LoopOptionCase::interleave_count: 3178 case LoopOptionCase::pipeline_initiation_interval: 3179 if (failed(parser.parseInteger(value))) 3180 return parser.emitError(parser.getNameLoc(), "expected integer value"); 3181 break; 3182 } 3183 options.push_back(std::make_pair(*option, value)); 3184 return success(); 3185 }; 3186 if (parser.parseCommaSeparatedList(parseLoopOptions) || parser.parseGreater()) 3187 return {}; 3188 3189 llvm::sort(options, llvm::less_first()); 3190 return get(parser.getContext(), options); 3191 } 3192