1 //===- LLVMDialect.cpp - LLVM IR Ops and Dialect registration -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the types and operation details for the LLVM IR dialect in
10 // MLIR, and the LLVM IR dialect.  It also registers the dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14 #include "TypeDetail.h"
15 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/DialectImplementation.h"
20 #include "mlir/IR/FunctionImplementation.h"
21 #include "mlir/IR/MLIRContext.h"
22 #include "mlir/IR/Matchers.h"
23 
24 #include "llvm/ADT/StringSwitch.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 #include "llvm/AsmParser/Parser.h"
27 #include "llvm/Bitcode/BitcodeReader.h"
28 #include "llvm/Bitcode/BitcodeWriter.h"
29 #include "llvm/IR/Attributes.h"
30 #include "llvm/IR/Function.h"
31 #include "llvm/IR/Type.h"
32 #include "llvm/Support/Error.h"
33 #include "llvm/Support/Mutex.h"
34 #include "llvm/Support/SourceMgr.h"
35 
36 #include <numeric>
37 
38 using namespace mlir;
39 using namespace mlir::LLVM;
40 using mlir::LLVM::cconv::getMaxEnumValForCConv;
41 using mlir::LLVM::linkage::getMaxEnumValForLinkage;
42 
43 #include "mlir/Dialect/LLVMIR/LLVMOpsDialect.cpp.inc"
44 
// Attribute names spelled out once so that the hand-written parsers, printers
// and verifiers in this file agree on their spelling. `elem_type` is the
// attribute that carries the pointee type when an opaque pointer is used.
static constexpr char kVolatileAttrName[] = "volatile_";
static constexpr char kNonTemporalAttrName[] = "nontemporal";
static constexpr char kElemTypeAttrName[] = "elem_type";
48 
49 #include "mlir/Dialect/LLVMIR/LLVMOpsEnums.cpp.inc"
50 #include "mlir/Dialect/LLVMIR/LLVMOpsInterfaces.cpp.inc"
51 #define GET_ATTRDEF_CLASSES
52 #include "mlir/Dialect/LLVMIR/LLVMOpsAttrDefs.cpp.inc"
53 
processFMFAttr(ArrayRef<NamedAttribute> attrs)54 static auto processFMFAttr(ArrayRef<NamedAttribute> attrs) {
55   SmallVector<NamedAttribute, 8> filteredAttrs(
56       llvm::make_filter_range(attrs, [&](NamedAttribute attr) {
57         if (attr.getName() == "fastmathFlags") {
58           auto defAttr = FMFAttr::get(attr.getValue().getContext(), {});
59           return defAttr != attr.getValue();
60         }
61         return true;
62       }));
63   return filteredAttrs;
64 }
65 
parseLLVMOpAttrs(OpAsmParser & parser,NamedAttrList & result)66 static ParseResult parseLLVMOpAttrs(OpAsmParser &parser,
67                                     NamedAttrList &result) {
68   return parser.parseOptionalAttrDict(result);
69 }
70 
printLLVMOpAttrs(OpAsmPrinter & printer,Operation * op,DictionaryAttr attrs)71 static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op,
72                              DictionaryAttr attrs) {
73   printer.printOptionalAttrDict(processFMFAttr(attrs.getValue()));
74 }
75 
76 /// Verifies `symbol`'s use in `op` to ensure the symbol is a valid and
77 /// fully defined llvm.func.
verifySymbolAttrUse(FlatSymbolRefAttr symbol,Operation * op,SymbolTableCollection & symbolTable)78 static LogicalResult verifySymbolAttrUse(FlatSymbolRefAttr symbol,
79                                          Operation *op,
80                                          SymbolTableCollection &symbolTable) {
81   StringRef name = symbol.getValue();
82   auto func =
83       symbolTable.lookupNearestSymbolFrom<LLVMFuncOp>(op, symbol.getAttr());
84   if (!func)
85     return op->emitOpError("'")
86            << name << "' does not reference a valid LLVM function";
87   if (func.isExternal())
88     return op->emitOpError("'") << name << "' does not have a definition";
89   return success();
90 }
91 
92 //===----------------------------------------------------------------------===//
93 // Printing, parsing and builder for LLVM::CmpOp.
94 //===----------------------------------------------------------------------===//
95 
build(OpBuilder & builder,OperationState & result,ICmpPredicate predicate,Value lhs,Value rhs)96 void ICmpOp::build(OpBuilder &builder, OperationState &result,
97                    ICmpPredicate predicate, Value lhs, Value rhs) {
98   auto boolType = IntegerType::get(lhs.getType().getContext(), 1);
99   if (LLVM::isCompatibleVectorType(lhs.getType()) ||
100       LLVM::isCompatibleVectorType(rhs.getType())) {
101     int64_t numLHSElements = 1, numRHSElements = 1;
102     if (LLVM::isCompatibleVectorType(lhs.getType()))
103       numLHSElements =
104           LLVM::getVectorNumElements(lhs.getType()).getFixedValue();
105     if (LLVM::isCompatibleVectorType(rhs.getType()))
106       numRHSElements =
107           LLVM::getVectorNumElements(rhs.getType()).getFixedValue();
108     build(builder, result,
109           VectorType::get({std::max(numLHSElements, numRHSElements)}, boolType),
110           predicate, lhs, rhs);
111   } else {
112     build(builder, result, boolType, predicate, lhs, rhs);
113   }
114 }
115 
print(OpAsmPrinter & p)116 void ICmpOp::print(OpAsmPrinter &p) {
117   p << " \"" << stringifyICmpPredicate(getPredicate()) << "\" " << getOperand(0)
118     << ", " << getOperand(1);
119   p.printOptionalAttrDict((*this)->getAttrs(), {"predicate"});
120   p << " : " << getLhs().getType();
121 }
122 
print(OpAsmPrinter & p)123 void FCmpOp::print(OpAsmPrinter &p) {
124   p << " \"" << stringifyFCmpPredicate(getPredicate()) << "\" " << getOperand(0)
125     << ", " << getOperand(1);
126   p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()), {"predicate"});
127   p << " : " << getLhs().getType();
128 }
129 
130 // <operation> ::= `llvm.icmp` string-literal ssa-use `,` ssa-use
131 //                 attribute-dict? `:` type
132 // <operation> ::= `llvm.fcmp` string-literal ssa-use `,` ssa-use
133 //                 attribute-dict? `:` type
134 template <typename CmpPredicateType>
parseCmpOp(OpAsmParser & parser,OperationState & result)135 static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) {
136   Builder &builder = parser.getBuilder();
137 
138   StringAttr predicateAttr;
139   OpAsmParser::UnresolvedOperand lhs, rhs;
140   Type type;
141   SMLoc predicateLoc, trailingTypeLoc;
142   if (parser.getCurrentLocation(&predicateLoc) ||
143       parser.parseAttribute(predicateAttr, "predicate", result.attributes) ||
144       parser.parseOperand(lhs) || parser.parseComma() ||
145       parser.parseOperand(rhs) ||
146       parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
147       parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) ||
148       parser.resolveOperand(lhs, type, result.operands) ||
149       parser.resolveOperand(rhs, type, result.operands))
150     return failure();
151 
152   // Replace the string attribute `predicate` with an integer attribute.
153   int64_t predicateValue = 0;
154   if (std::is_same<CmpPredicateType, ICmpPredicate>()) {
155     Optional<ICmpPredicate> predicate =
156         symbolizeICmpPredicate(predicateAttr.getValue());
157     if (!predicate)
158       return parser.emitError(predicateLoc)
159              << "'" << predicateAttr.getValue()
160              << "' is an incorrect value of the 'predicate' attribute";
161     predicateValue = static_cast<int64_t>(*predicate);
162   } else {
163     Optional<FCmpPredicate> predicate =
164         symbolizeFCmpPredicate(predicateAttr.getValue());
165     if (!predicate)
166       return parser.emitError(predicateLoc)
167              << "'" << predicateAttr.getValue()
168              << "' is an incorrect value of the 'predicate' attribute";
169     predicateValue = static_cast<int64_t>(*predicate);
170   }
171 
172   result.attributes.set("predicate",
173                         parser.getBuilder().getI64IntegerAttr(predicateValue));
174 
175   // The result type is either i1 or a vector type <? x i1> if the inputs are
176   // vectors.
177   Type resultType = IntegerType::get(builder.getContext(), 1);
178   if (!isCompatibleType(type))
179     return parser.emitError(trailingTypeLoc,
180                             "expected LLVM dialect-compatible type");
181   if (LLVM::isCompatibleVectorType(type)) {
182     if (LLVM::isScalableVectorType(type)) {
183       resultType = LLVM::getVectorType(
184           resultType, LLVM::getVectorNumElements(type).getKnownMinValue(),
185           /*isScalable=*/true);
186     } else {
187       resultType = LLVM::getVectorType(
188           resultType, LLVM::getVectorNumElements(type).getFixedValue(),
189           /*isScalable=*/false);
190     }
191   }
192 
193   result.addTypes({resultType});
194   return success();
195 }
196 
parse(OpAsmParser & parser,OperationState & result)197 ParseResult ICmpOp::parse(OpAsmParser &parser, OperationState &result) {
198   return parseCmpOp<ICmpPredicate>(parser, result);
199 }
200 
parse(OpAsmParser & parser,OperationState & result)201 ParseResult FCmpOp::parse(OpAsmParser &parser, OperationState &result) {
202   return parseCmpOp<FCmpPredicate>(parser, result);
203 }
204 
205 //===----------------------------------------------------------------------===//
206 // Printing, parsing and verification for LLVM::AllocaOp.
207 //===----------------------------------------------------------------------===//
208 
print(OpAsmPrinter & p)209 void AllocaOp::print(OpAsmPrinter &p) {
210   Type elemTy = getType().cast<LLVM::LLVMPointerType>().getElementType();
211   if (!elemTy)
212     elemTy = *getElemType();
213 
214   auto funcTy =
215       FunctionType::get(getContext(), {getArraySize().getType()}, {getType()});
216 
217   p << ' ' << getArraySize() << " x " << elemTy;
218   if (getAlignment() && *getAlignment() != 0)
219     p.printOptionalAttrDict((*this)->getAttrs(), {kElemTypeAttrName});
220   else
221     p.printOptionalAttrDict((*this)->getAttrs(),
222                             {"alignment", kElemTypeAttrName});
223   p << " : " << funcTy;
224 }
225 
226 // <operation> ::= `llvm.alloca` ssa-use `x` type attribute-dict?
227 //                 `:` type `,` type
parse(OpAsmParser & parser,OperationState & result)228 ParseResult AllocaOp::parse(OpAsmParser &parser, OperationState &result) {
229   OpAsmParser::UnresolvedOperand arraySize;
230   Type type, elemType;
231   SMLoc trailingTypeLoc;
232   if (parser.parseOperand(arraySize) || parser.parseKeyword("x") ||
233       parser.parseType(elemType) ||
234       parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
235       parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type))
236     return failure();
237 
238   Optional<NamedAttribute> alignmentAttr =
239       result.attributes.getNamed("alignment");
240   if (alignmentAttr.has_value()) {
241     auto alignmentInt =
242         alignmentAttr.value().getValue().dyn_cast<IntegerAttr>();
243     if (!alignmentInt)
244       return parser.emitError(parser.getNameLoc(),
245                               "expected integer alignment");
246     if (alignmentInt.getValue().isNullValue())
247       result.attributes.erase("alignment");
248   }
249 
250   // Extract the result type from the trailing function type.
251   auto funcType = type.dyn_cast<FunctionType>();
252   if (!funcType || funcType.getNumInputs() != 1 ||
253       funcType.getNumResults() != 1)
254     return parser.emitError(
255         trailingTypeLoc,
256         "expected trailing function type with one argument and one result");
257 
258   if (parser.resolveOperand(arraySize, funcType.getInput(0), result.operands))
259     return failure();
260 
261   Type resultType = funcType.getResult(0);
262   if (auto ptrResultType = resultType.dyn_cast<LLVMPointerType>()) {
263     if (ptrResultType.isOpaque())
264       result.addAttribute(kElemTypeAttrName, TypeAttr::get(elemType));
265   }
266 
267   result.addTypes({funcType.getResult(0)});
268   return success();
269 }
270 
271 /// Checks that the elemental type is present in either the pointer type or
272 /// the attribute, but not both.
verifyOpaquePtr(Operation * op,LLVMPointerType ptrType,Optional<Type> ptrElementType)273 static LogicalResult verifyOpaquePtr(Operation *op, LLVMPointerType ptrType,
274                                      Optional<Type> ptrElementType) {
275   if (ptrType.isOpaque() && !ptrElementType.has_value()) {
276     return op->emitOpError() << "expected '" << kElemTypeAttrName
277                              << "' attribute if opaque pointer type is used";
278   }
279   if (!ptrType.isOpaque() && ptrElementType.has_value()) {
280     return op->emitOpError()
281            << "unexpected '" << kElemTypeAttrName
282            << "' attribute when non-opaque pointer type is used";
283   }
284   return success();
285 }
286 
verify()287 LogicalResult AllocaOp::verify() {
288   return verifyOpaquePtr(getOperation(), getType().cast<LLVMPointerType>(),
289                          getElemType());
290 }
291 
292 //===----------------------------------------------------------------------===//
293 // LLVM::BrOp
294 //===----------------------------------------------------------------------===//
295 
getSuccessorOperands(unsigned index)296 SuccessorOperands BrOp::getSuccessorOperands(unsigned index) {
297   assert(index == 0 && "invalid successor index");
298   return SuccessorOperands(getDestOperandsMutable());
299 }
300 
301 //===----------------------------------------------------------------------===//
302 // LLVM::CondBrOp
303 //===----------------------------------------------------------------------===//
304 
getSuccessorOperands(unsigned index)305 SuccessorOperands CondBrOp::getSuccessorOperands(unsigned index) {
306   assert(index < getNumSuccessors() && "invalid successor index");
307   return SuccessorOperands(index == 0 ? getTrueDestOperandsMutable()
308                                       : getFalseDestOperandsMutable());
309 }
310 
311 //===----------------------------------------------------------------------===//
312 // LLVM::SwitchOp
313 //===----------------------------------------------------------------------===//
314 
build(OpBuilder & builder,OperationState & result,Value value,Block * defaultDestination,ValueRange defaultOperands,ArrayRef<int32_t> caseValues,BlockRange caseDestinations,ArrayRef<ValueRange> caseOperands,ArrayRef<int32_t> branchWeights)315 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
316                      Block *defaultDestination, ValueRange defaultOperands,
317                      ArrayRef<int32_t> caseValues, BlockRange caseDestinations,
318                      ArrayRef<ValueRange> caseOperands,
319                      ArrayRef<int32_t> branchWeights) {
320   ElementsAttr caseValuesAttr;
321   if (!caseValues.empty())
322     caseValuesAttr = builder.getI32VectorAttr(caseValues);
323 
324   ElementsAttr weightsAttr;
325   if (!branchWeights.empty())
326     weightsAttr = builder.getI32VectorAttr(llvm::to_vector<4>(branchWeights));
327 
328   build(builder, result, value, defaultOperands, caseOperands, caseValuesAttr,
329         weightsAttr, defaultDestination, caseDestinations);
330 }
331 
332 /// <cases> ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)?
333 ///             ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )?
parseSwitchOpCases(OpAsmParser & parser,Type flagType,ElementsAttr & caseValues,SmallVectorImpl<Block * > & caseDestinations,SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> & caseOperands,SmallVectorImpl<SmallVector<Type>> & caseOperandTypes)334 static ParseResult parseSwitchOpCases(
335     OpAsmParser &parser, Type flagType, ElementsAttr &caseValues,
336     SmallVectorImpl<Block *> &caseDestinations,
337     SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> &caseOperands,
338     SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
339   SmallVector<APInt> values;
340   unsigned bitWidth = flagType.getIntOrFloatBitWidth();
341   do {
342     int64_t value = 0;
343     OptionalParseResult integerParseResult = parser.parseOptionalInteger(value);
344     if (values.empty() && !integerParseResult.hasValue())
345       return success();
346 
347     if (!integerParseResult.hasValue() || integerParseResult.getValue())
348       return failure();
349     values.push_back(APInt(bitWidth, value));
350 
351     Block *destination;
352     SmallVector<OpAsmParser::UnresolvedOperand> operands;
353     SmallVector<Type> operandTypes;
354     if (parser.parseColon() || parser.parseSuccessor(destination))
355       return failure();
356     if (!parser.parseOptionalLParen()) {
357       if (parser.parseOperandList(operands, OpAsmParser::Delimiter::None,
358                                   /*allowResultNumber=*/false) ||
359           parser.parseColonTypeList(operandTypes) || parser.parseRParen())
360         return failure();
361     }
362     caseDestinations.push_back(destination);
363     caseOperands.emplace_back(operands);
364     caseOperandTypes.emplace_back(operandTypes);
365   } while (!parser.parseOptionalComma());
366 
367   ShapedType caseValueType =
368       VectorType::get(static_cast<int64_t>(values.size()), flagType);
369   caseValues = DenseIntElementsAttr::get(caseValueType, values);
370   return success();
371 }
372 
printSwitchOpCases(OpAsmPrinter & p,SwitchOp op,Type flagType,ElementsAttr caseValues,SuccessorRange caseDestinations,OperandRangeRange caseOperands,const TypeRangeRange & caseOperandTypes)373 static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type flagType,
374                                ElementsAttr caseValues,
375                                SuccessorRange caseDestinations,
376                                OperandRangeRange caseOperands,
377                                const TypeRangeRange &caseOperandTypes) {
378   if (!caseValues)
379     return;
380 
381   size_t index = 0;
382   llvm::interleave(
383       llvm::zip(caseValues.cast<DenseIntElementsAttr>(), caseDestinations),
384       [&](auto i) {
385         p << "  ";
386         p << std::get<0>(i).getLimitedValue();
387         p << ": ";
388         p.printSuccessorAndUseList(std::get<1>(i), caseOperands[index++]);
389       },
390       [&] {
391         p << ',';
392         p.printNewline();
393       });
394   p.printNewline();
395 }
396 
verify()397 LogicalResult SwitchOp::verify() {
398   if ((!getCaseValues() && !getCaseDestinations().empty()) ||
399       (getCaseValues() &&
400        getCaseValues()->size() !=
401            static_cast<int64_t>(getCaseDestinations().size())))
402     return emitOpError("expects number of case values to match number of "
403                        "case destinations");
404   if (getBranchWeights() && getBranchWeights()->size() != getNumSuccessors())
405     return emitError("expects number of branch weights to match number of "
406                      "successors: ")
407            << getBranchWeights()->size() << " vs " << getNumSuccessors();
408   return success();
409 }
410 
getSuccessorOperands(unsigned index)411 SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) {
412   assert(index < getNumSuccessors() && "invalid successor index");
413   return SuccessorOperands(index == 0 ? getDefaultOperandsMutable()
414                                       : getCaseOperandsMutable(index - 1));
415 }
416 
417 //===----------------------------------------------------------------------===//
418 // Code for LLVM::GEPOp.
419 //===----------------------------------------------------------------------===//
420 
421 constexpr int GEPOp::kDynamicIndex;
422 
423 namespace {
424 /// Base class for llvm::Error related to GEP index.
425 class GEPIndexError : public llvm::ErrorInfo<GEPIndexError> {
426 protected:
427   unsigned indexPos;
428 
429 public:
430   static char ID;
431 
convertToErrorCode() const432   std::error_code convertToErrorCode() const override {
433     return llvm::inconvertibleErrorCode();
434   }
435 
GEPIndexError(unsigned pos)436   explicit GEPIndexError(unsigned pos) : indexPos(pos) {}
437 };
438 
439 /// llvm::Error for out-of-bound GEP index.
440 struct GEPIndexOutOfBoundError
441     : public llvm::ErrorInfo<GEPIndexOutOfBoundError, GEPIndexError> {
442   static char ID;
443 
444   using ErrorInfo::ErrorInfo;
445 
log__anoncc58d73e0411::GEPIndexOutOfBoundError446   void log(llvm::raw_ostream &os) const override {
447     os << "index " << indexPos << " indexing a struct is out of bounds";
448   }
449 };
450 
451 /// llvm::Error for non-static GEP index indexing a struct.
452 struct GEPStaticIndexError
453     : public llvm::ErrorInfo<GEPStaticIndexError, GEPIndexError> {
454   static char ID;
455 
456   using ErrorInfo::ErrorInfo;
457 
log__anoncc58d73e0411::GEPStaticIndexError458   void log(llvm::raw_ostream &os) const override {
459     os << "expected index " << indexPos << " indexing a struct "
460        << "to be constant";
461   }
462 };
463 } // end anonymous namespace
464 
465 char GEPIndexError::ID = 0;
466 char GEPIndexOutOfBoundError::ID = 0;
467 char GEPStaticIndexError::ID = 0;
468 
469 /// For the given `structIndices` and `indices`, check if they're complied
470 /// with `baseGEPType`, especially check against LLVMStructTypes nested within,
471 /// and refine/promote struct index from `indices` to `updatedStructIndices`
472 /// if the latter argument is not null.
473 static llvm::Error
recordStructIndices(Type baseGEPType,unsigned indexPos,ArrayRef<int32_t> structIndices,ValueRange indices,SmallVectorImpl<int32_t> * updatedStructIndices,SmallVectorImpl<Value> * remainingIndices)474 recordStructIndices(Type baseGEPType, unsigned indexPos,
475                     ArrayRef<int32_t> structIndices, ValueRange indices,
476                     SmallVectorImpl<int32_t> *updatedStructIndices,
477                     SmallVectorImpl<Value> *remainingIndices) {
478   if (indexPos >= structIndices.size())
479     // Stop searching
480     return llvm::Error::success();
481 
482   int32_t gepIndex = structIndices[indexPos];
483   bool isStaticIndex = gepIndex != GEPOp::kDynamicIndex;
484 
485   unsigned dynamicIndexPos = indexPos;
486   if (!isStaticIndex)
487     dynamicIndexPos = llvm::count(structIndices.take_front(indexPos + 1),
488                                   LLVM::GEPOp::kDynamicIndex) -
489                       1;
490 
491   return llvm::TypeSwitch<Type, llvm::Error>(baseGEPType)
492       .Case<LLVMStructType>([&](LLVMStructType structType) -> llvm::Error {
493         // We don't always want to refine the index (e.g. when performing
494         // verification), so we only refine when updatedStructIndices is not
495         // null.
496         if (!isStaticIndex && updatedStructIndices) {
497           // Try to refine.
498           APInt staticIndexValue;
499           isStaticIndex = matchPattern(indices[dynamicIndexPos],
500                                        m_ConstantInt(&staticIndexValue));
501           if (isStaticIndex) {
502             assert(staticIndexValue.getBitWidth() <= 64 &&
503                    llvm::isInt<32>(staticIndexValue.getLimitedValue()) &&
504                    "struct index can't fit within int32_t");
505             gepIndex = static_cast<int32_t>(staticIndexValue.getSExtValue());
506           }
507         }
508         if (!isStaticIndex)
509           return llvm::make_error<GEPStaticIndexError>(indexPos);
510 
511         ArrayRef<Type> elementTypes = structType.getBody();
512         if (gepIndex < 0 ||
513             static_cast<size_t>(gepIndex) >= elementTypes.size())
514           return llvm::make_error<GEPIndexOutOfBoundError>(indexPos);
515 
516         if (updatedStructIndices)
517           (*updatedStructIndices)[indexPos] = gepIndex;
518 
519         // Instead of recusively going into every children types, we only
520         // dive into the one indexed by gepIndex.
521         return recordStructIndices(elementTypes[gepIndex], indexPos + 1,
522                                    structIndices, indices, updatedStructIndices,
523                                    remainingIndices);
524       })
525       .Case<VectorType, LLVMScalableVectorType, LLVMFixedVectorType,
526             LLVMArrayType>([&](auto containerType) -> llvm::Error {
527         // Currently we don't refine non-struct index even if it's static.
528         if (remainingIndices)
529           remainingIndices->push_back(indices[dynamicIndexPos]);
530         return recordStructIndices(containerType.getElementType(), indexPos + 1,
531                                    structIndices, indices, updatedStructIndices,
532                                    remainingIndices);
533       })
534       .Default(
535           [](auto otherType) -> llvm::Error { return llvm::Error::success(); });
536 }
537 
538 /// Driver function around `recordStructIndices`. Note that we always check
539 /// from the second GEP index since the first one is always dynamic.
540 static llvm::Error
findStructIndices(Type baseGEPType,ArrayRef<int32_t> structIndices,ValueRange indices,SmallVectorImpl<int32_t> * updatedStructIndices=nullptr,SmallVectorImpl<Value> * remainingIndices=nullptr)541 findStructIndices(Type baseGEPType, ArrayRef<int32_t> structIndices,
542                   ValueRange indices,
543                   SmallVectorImpl<int32_t> *updatedStructIndices = nullptr,
544                   SmallVectorImpl<Value> *remainingIndices = nullptr) {
545   if (remainingIndices)
546     // The first GEP index is always dynamic.
547     remainingIndices->push_back(indices[0]);
548   return recordStructIndices(baseGEPType, /*indexPos=*/1, structIndices,
549                              indices, updatedStructIndices, remainingIndices);
550 }
551 
build(OpBuilder & builder,OperationState & result,Type resultType,Value basePtr,ValueRange operands,ArrayRef<NamedAttribute> attributes)552 void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
553                   Value basePtr, ValueRange operands,
554                   ArrayRef<NamedAttribute> attributes) {
555   build(builder, result, resultType, basePtr, operands,
556         SmallVector<int32_t>(operands.size(), kDynamicIndex), attributes);
557 }
558 
559 /// Returns the elemental type of any LLVM-compatible vector type or self.
extractVectorElementType(Type type)560 static Type extractVectorElementType(Type type) {
561   if (auto vectorType = type.dyn_cast<VectorType>())
562     return vectorType.getElementType();
563   if (auto scalableVectorType = type.dyn_cast<LLVMScalableVectorType>())
564     return scalableVectorType.getElementType();
565   if (auto fixedVectorType = type.dyn_cast<LLVMFixedVectorType>())
566     return fixedVectorType.getElementType();
567   return type;
568 }
569 
build(OpBuilder & builder,OperationState & result,Type resultType,Type elementType,Value basePtr,ValueRange indices,ArrayRef<NamedAttribute> attributes)570 void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
571                   Type elementType, Value basePtr, ValueRange indices,
572                   ArrayRef<NamedAttribute> attributes) {
573   build(builder, result, resultType, elementType, basePtr, indices,
574         SmallVector<int32_t>(indices.size(), kDynamicIndex), attributes);
575 }
576 
build(OpBuilder & builder,OperationState & result,Type resultType,Value basePtr,ValueRange indices,ArrayRef<int32_t> structIndices,ArrayRef<NamedAttribute> attributes)577 void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
578                   Value basePtr, ValueRange indices,
579                   ArrayRef<int32_t> structIndices,
580                   ArrayRef<NamedAttribute> attributes) {
581   auto ptrType =
582       extractVectorElementType(basePtr.getType()).cast<LLVMPointerType>();
583   assert(!ptrType.isOpaque() &&
584          "expected non-opaque pointer, provide elementType explicitly when "
585          "opaque pointers are used");
586   build(builder, result, resultType, ptrType.getElementType(), basePtr, indices,
587         structIndices, attributes);
588 }
589 
build(OpBuilder & builder,OperationState & result,Type resultType,Type elementType,Value basePtr,ValueRange indices,ArrayRef<int32_t> structIndices,ArrayRef<NamedAttribute> attributes)590 void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
591                   Type elementType, Value basePtr, ValueRange indices,
592                   ArrayRef<int32_t> structIndices,
593                   ArrayRef<NamedAttribute> attributes) {
594   SmallVector<Value> remainingIndices;
595   SmallVector<int32_t> updatedStructIndices(structIndices.begin(),
596                                             structIndices.end());
597   if (llvm::Error err =
598           findStructIndices(elementType, structIndices, indices,
599                             &updatedStructIndices, &remainingIndices))
600     llvm::report_fatal_error(StringRef(llvm::toString(std::move(err))));
601 
602   assert(remainingIndices.size() == static_cast<size_t>(llvm::count(
603                                         updatedStructIndices, kDynamicIndex)) &&
604          "expected as many index operands as dynamic index attr elements");
605 
606   result.addTypes(resultType);
607   result.addAttributes(attributes);
608   result.addAttribute("structIndices",
609                       builder.getI32TensorAttr(updatedStructIndices));
610   if (extractVectorElementType(basePtr.getType())
611           .cast<LLVMPointerType>()
612           .isOpaque())
613     result.addAttribute(kElemTypeAttrName, TypeAttr::get(elementType));
614   result.addOperands(basePtr);
615   result.addOperands(remainingIndices);
616 }
617 
618 static ParseResult
parseGEPIndices(OpAsmParser & parser,SmallVectorImpl<OpAsmParser::UnresolvedOperand> & indices,DenseIntElementsAttr & structIndices)619 parseGEPIndices(OpAsmParser &parser,
620                 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &indices,
621                 DenseIntElementsAttr &structIndices) {
622   SmallVector<int32_t> constantIndices;
623 
624   auto idxParser = [&]() -> ParseResult {
625     int32_t constantIndex;
626     OptionalParseResult parsedInteger =
627         parser.parseOptionalInteger(constantIndex);
628     if (parsedInteger.hasValue()) {
629       if (failed(parsedInteger.getValue()))
630         return failure();
631       constantIndices.push_back(constantIndex);
632       return success();
633     }
634 
635     constantIndices.push_back(LLVM::GEPOp::kDynamicIndex);
636     return parser.parseOperand(indices.emplace_back());
637   };
638   if (parser.parseCommaSeparatedList(idxParser))
639     return failure();
640 
641   structIndices = parser.getBuilder().getI32TensorAttr(constantIndices);
642   return success();
643 }
644 
printGEPIndices(OpAsmPrinter & printer,LLVM::GEPOp gepOp,OperandRange indices,DenseIntElementsAttr structIndices)645 static void printGEPIndices(OpAsmPrinter &printer, LLVM::GEPOp gepOp,
646                             OperandRange indices,
647                             DenseIntElementsAttr structIndices) {
648   unsigned operandIdx = 0;
649   llvm::interleaveComma(structIndices.getValues<int32_t>(), printer,
650                         [&](int32_t cst) {
651                           if (cst == LLVM::GEPOp::kDynamicIndex)
652                             printer.printOperand(indices[operandIdx++]);
653                           else
654                             printer << cst;
655                         });
656 }
657 
verify()658 LogicalResult LLVM::GEPOp::verify() {
659   if (failed(verifyOpaquePtr(
660           getOperation(),
661           extractVectorElementType(getType()).cast<LLVMPointerType>(),
662           getElemType())))
663     return failure();
664 
665   auto structIndexRange = getStructIndices().getValues<int32_t>();
666   // structIndexRange is a kind of iterator, which cannot be converted
667   // to ArrayRef directly.
668   SmallVector<int32_t> structIndices(structIndexRange.size());
669   for (unsigned i : llvm::seq<unsigned>(0, structIndexRange.size()))
670     structIndices[i] = structIndexRange[i];
671   if (llvm::Error err = findStructIndices(getSourceElementType(), structIndices,
672                                           getIndices()))
673     return emitOpError() << llvm::toString(std::move(err));
674 
675   return success();
676 }
677 
getSourceElementType()678 Type LLVM::GEPOp::getSourceElementType() {
679   if (Optional<Type> elemType = getElemType())
680     return *elemType;
681 
682   return extractVectorElementType(getBase().getType())
683       .cast<LLVMPointerType>()
684       .getElementType();
685 }
686 
687 //===----------------------------------------------------------------------===//
// Builder, printer and parser for LLVM::LoadOp.
689 //===----------------------------------------------------------------------===//
690 
verifySymbolAttribute(Operation * op,StringRef attributeName,llvm::function_ref<LogicalResult (Operation *,SymbolRefAttr)> verifySymbolType)691 LogicalResult verifySymbolAttribute(
692     Operation *op, StringRef attributeName,
693     llvm::function_ref<LogicalResult(Operation *, SymbolRefAttr)>
694         verifySymbolType) {
695   if (Attribute attribute = op->getAttr(attributeName)) {
696     // The attribute is already verified to be a symbol ref array attribute via
697     // a constraint in the operation definition.
698     for (SymbolRefAttr symbolRef :
699          attribute.cast<ArrayAttr>().getAsRange<SymbolRefAttr>()) {
700       StringAttr metadataName = symbolRef.getRootReference();
701       StringAttr symbolName = symbolRef.getLeafReference();
702       // We want @metadata::@symbol, not just @symbol
703       if (metadataName == symbolName) {
704         return op->emitOpError() << "expected '" << symbolRef
705                                  << "' to specify a fully qualified reference";
706       }
707       auto metadataOp = SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>(
708           op->getParentOp(), metadataName);
709       if (!metadataOp)
710         return op->emitOpError()
711                << "expected '" << symbolRef << "' to reference a metadata op";
712       Operation *symbolOp =
713           SymbolTable::lookupNearestSymbolFrom(metadataOp, symbolName);
714       if (!symbolOp)
715         return op->emitOpError()
716                << "expected '" << symbolRef << "' to be a valid reference";
717       if (failed(verifySymbolType(symbolOp, symbolRef))) {
718         return failure();
719       }
720     }
721   }
722   return success();
723 }
724 
725 // Verifies that metadata ops are wired up properly.
726 template <typename OpTy>
verifyOpMetadata(Operation * op,StringRef attributeName)727 static LogicalResult verifyOpMetadata(Operation *op, StringRef attributeName) {
728   auto verifySymbolType = [op](Operation *symbolOp,
729                                SymbolRefAttr symbolRef) -> LogicalResult {
730     if (!isa<OpTy>(symbolOp)) {
731       return op->emitOpError()
732              << "expected '" << symbolRef << "' to resolve to a "
733              << OpTy::getOperationName();
734     }
735     return success();
736   };
737 
738   return verifySymbolAttribute(op, attributeName, verifySymbolType);
739 }
740 
verifyMemoryOpMetadata(Operation * op)741 static LogicalResult verifyMemoryOpMetadata(Operation *op) {
742   // access_groups
743   if (failed(verifyOpMetadata<LLVM::AccessGroupMetadataOp>(
744           op, LLVMDialect::getAccessGroupsAttrName())))
745     return failure();
746 
747   // alias_scopes
748   if (failed(verifyOpMetadata<LLVM::AliasScopeMetadataOp>(
749           op, LLVMDialect::getAliasScopesAttrName())))
750     return failure();
751 
752   // noalias_scopes
753   if (failed(verifyOpMetadata<LLVM::AliasScopeMetadataOp>(
754           op, LLVMDialect::getNoAliasScopesAttrName())))
755     return failure();
756 
757   return success();
758 }
759 
verify()760 LogicalResult LoadOp::verify() { return verifyMemoryOpMetadata(*this); }
761 
build(OpBuilder & builder,OperationState & result,Type t,Value addr,unsigned alignment,bool isVolatile,bool isNonTemporal)762 void LoadOp::build(OpBuilder &builder, OperationState &result, Type t,
763                    Value addr, unsigned alignment, bool isVolatile,
764                    bool isNonTemporal) {
765   result.addOperands(addr);
766   result.addTypes(t);
767   if (isVolatile)
768     result.addAttribute(kVolatileAttrName, builder.getUnitAttr());
769   if (isNonTemporal)
770     result.addAttribute(kNonTemporalAttrName, builder.getUnitAttr());
771   if (alignment != 0)
772     result.addAttribute("alignment", builder.getI64IntegerAttr(alignment));
773 }
774 
print(OpAsmPrinter & p)775 void LoadOp::print(OpAsmPrinter &p) {
776   p << ' ';
777   if (getVolatile_())
778     p << "volatile ";
779   p << getAddr();
780   p.printOptionalAttrDict((*this)->getAttrs(),
781                           {kVolatileAttrName, kElemTypeAttrName});
782   p << " : " << getAddr().getType();
783   if (getAddr().getType().cast<LLVMPointerType>().isOpaque())
784     p << " -> " << getType();
785 }
786 
787 // Extract the pointee type from the LLVM pointer type wrapped in MLIR. Return
788 // the resulting type if any, null type if opaque pointers are used, and None
789 // if the given type is not the pointer type.
getLoadStoreElementType(OpAsmParser & parser,Type type,SMLoc trailingTypeLoc)790 static Optional<Type> getLoadStoreElementType(OpAsmParser &parser, Type type,
791                                               SMLoc trailingTypeLoc) {
792   auto llvmTy = type.dyn_cast<LLVM::LLVMPointerType>();
793   if (!llvmTy) {
794     parser.emitError(trailingTypeLoc, "expected LLVM pointer type");
795     return llvm::None;
796   }
797   return llvmTy.getElementType();
798 }
799 
800 // <operation> ::= `llvm.load` `volatile` ssa-use attribute-dict? `:` type
801 //                 (`->` type)?
parse(OpAsmParser & parser,OperationState & result)802 ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) {
803   OpAsmParser::UnresolvedOperand addr;
804   Type type;
805   SMLoc trailingTypeLoc;
806 
807   if (succeeded(parser.parseOptionalKeyword("volatile")))
808     result.addAttribute(kVolatileAttrName, parser.getBuilder().getUnitAttr());
809 
810   if (parser.parseOperand(addr) ||
811       parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
812       parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) ||
813       parser.resolveOperand(addr, type, result.operands))
814     return failure();
815 
816   Optional<Type> elemTy =
817       getLoadStoreElementType(parser, type, trailingTypeLoc);
818   if (!elemTy)
819     return failure();
820   if (*elemTy) {
821     result.addTypes(*elemTy);
822     return success();
823   }
824 
825   Type trailingType;
826   if (parser.parseArrow() || parser.parseType(trailingType))
827     return failure();
828   result.addTypes(trailingType);
829   return success();
830 }
831 
832 //===----------------------------------------------------------------------===//
833 // Builder, printer and parser for LLVM::StoreOp.
834 //===----------------------------------------------------------------------===//
835 
verify()836 LogicalResult StoreOp::verify() { return verifyMemoryOpMetadata(*this); }
837 
build(OpBuilder & builder,OperationState & result,Value value,Value addr,unsigned alignment,bool isVolatile,bool isNonTemporal)838 void StoreOp::build(OpBuilder &builder, OperationState &result, Value value,
839                     Value addr, unsigned alignment, bool isVolatile,
840                     bool isNonTemporal) {
841   result.addOperands({value, addr});
842   result.addTypes({});
843   if (isVolatile)
844     result.addAttribute(kVolatileAttrName, builder.getUnitAttr());
845   if (isNonTemporal)
846     result.addAttribute(kNonTemporalAttrName, builder.getUnitAttr());
847   if (alignment != 0)
848     result.addAttribute("alignment", builder.getI64IntegerAttr(alignment));
849 }
850 
print(OpAsmPrinter & p)851 void StoreOp::print(OpAsmPrinter &p) {
852   p << ' ';
853   if (getVolatile_())
854     p << "volatile ";
855   p << getValue() << ", " << getAddr();
856   p.printOptionalAttrDict((*this)->getAttrs(), {kVolatileAttrName});
857   p << " : ";
858   if (getAddr().getType().cast<LLVMPointerType>().isOpaque())
859     p << getValue().getType() << ", ";
860   p << getAddr().getType();
861 }
862 
863 // <operation> ::= `llvm.store` `volatile` ssa-use `,` ssa-use
864 //                 attribute-dict? `:` type (`,` type)?
parse(OpAsmParser & parser,OperationState & result)865 ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) {
866   OpAsmParser::UnresolvedOperand addr, value;
867   Type type;
868   SMLoc trailingTypeLoc;
869 
870   if (succeeded(parser.parseOptionalKeyword("volatile")))
871     result.addAttribute(kVolatileAttrName, parser.getBuilder().getUnitAttr());
872 
873   if (parser.parseOperand(value) || parser.parseComma() ||
874       parser.parseOperand(addr) ||
875       parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
876       parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type))
877     return failure();
878 
879   Type operandType;
880   if (succeeded(parser.parseOptionalComma())) {
881     operandType = type;
882     if (parser.parseType(type))
883       return failure();
884   } else {
885     Optional<Type> maybeOperandType =
886         getLoadStoreElementType(parser, type, trailingTypeLoc);
887     if (!maybeOperandType)
888       return failure();
889     operandType = *maybeOperandType;
890   }
891 
892   if (parser.resolveOperand(value, operandType, result.operands) ||
893       parser.resolveOperand(addr, type, result.operands))
894     return failure();
895 
896   return success();
897 }
898 
899 ///===---------------------------------------------------------------------===//
900 /// LLVM::InvokeOp
901 ///===---------------------------------------------------------------------===//
902 
getSuccessorOperands(unsigned index)903 SuccessorOperands InvokeOp::getSuccessorOperands(unsigned index) {
904   assert(index < getNumSuccessors() && "invalid successor index");
905   return SuccessorOperands(index == 0 ? getNormalDestOperandsMutable()
906                                       : getUnwindDestOperandsMutable());
907 }
908 
getCallableForCallee()909 CallInterfaceCallable InvokeOp::getCallableForCallee() {
910   // Direct call.
911   if (FlatSymbolRefAttr calleeAttr = getCalleeAttr())
912     return calleeAttr;
913   // Indirect call, callee Value is the first operand.
914   return getOperand(0);
915 }
916 
getArgOperands()917 Operation::operand_range InvokeOp::getArgOperands() {
918   return getOperands().drop_front(getCallee().has_value() ? 0 : 1);
919 }
920 
verify()921 LogicalResult InvokeOp::verify() {
922   if (getNumResults() > 1)
923     return emitOpError("must have 0 or 1 result");
924 
925   Block *unwindDest = getUnwindDest();
926   if (unwindDest->empty())
927     return emitError("must have at least one operation in unwind destination");
928 
929   // In unwind destination, first operation must be LandingpadOp
930   if (!isa<LandingpadOp>(unwindDest->front()))
931     return emitError("first operation in unwind destination should be a "
932                      "llvm.landingpad operation");
933 
934   return success();
935 }
936 
print(OpAsmPrinter & p)937 void InvokeOp::print(OpAsmPrinter &p) {
938   auto callee = getCallee();
939   bool isDirect = callee.has_value();
940 
941   p << ' ';
942 
943   // Either function name or pointer
944   if (isDirect)
945     p.printSymbolName(callee.value());
946   else
947     p << getOperand(0);
948 
949   p << '(' << getOperands().drop_front(isDirect ? 0 : 1) << ')';
950   p << " to ";
951   p.printSuccessorAndUseList(getNormalDest(), getNormalDestOperands());
952   p << " unwind ";
953   p.printSuccessorAndUseList(getUnwindDest(), getUnwindDestOperands());
954 
955   p.printOptionalAttrDict((*this)->getAttrs(),
956                           {InvokeOp::getOperandSegmentSizeAttr(), "callee"});
957   p << " : ";
958   p.printFunctionalType(llvm::drop_begin(getOperandTypes(), isDirect ? 0 : 1),
959                         getResultTypes());
960 }
961 
962 /// <operation> ::= `llvm.invoke` (function-id | ssa-use) `(` ssa-use-list `)`
963 ///                  `to` bb-id (`[` ssa-use-and-type-list `]`)?
964 ///                  `unwind` bb-id (`[` ssa-use-and-type-list `]`)?
965 ///                  attribute-dict? `:` function-type
parse(OpAsmParser & parser,OperationState & result)966 ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
967   SmallVector<OpAsmParser::UnresolvedOperand, 8> operands;
968   FunctionType funcType;
969   SymbolRefAttr funcAttr;
970   SMLoc trailingTypeLoc;
971   Block *normalDest, *unwindDest;
972   SmallVector<Value, 4> normalOperands, unwindOperands;
973   Builder &builder = parser.getBuilder();
974 
975   // Parse an operand list that will, in practice, contain 0 or 1 operand.  In
976   // case of an indirect call, there will be 1 operand before `(`.  In case of a
977   // direct call, there will be no operands and the parser will stop at the
978   // function identifier without complaining.
979   if (parser.parseOperandList(operands))
980     return failure();
981   bool isDirect = operands.empty();
982 
983   // Optionally parse a function identifier.
984   if (isDirect && parser.parseAttribute(funcAttr, "callee", result.attributes))
985     return failure();
986 
987   if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
988       parser.parseKeyword("to") ||
989       parser.parseSuccessorAndUseList(normalDest, normalOperands) ||
990       parser.parseKeyword("unwind") ||
991       parser.parseSuccessorAndUseList(unwindDest, unwindOperands) ||
992       parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
993       parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(funcType))
994     return failure();
995 
996   if (isDirect) {
997     // Make sure types match.
998     if (parser.resolveOperands(operands, funcType.getInputs(),
999                                parser.getNameLoc(), result.operands))
1000       return failure();
1001     result.addTypes(funcType.getResults());
1002   } else {
1003     // Construct the LLVM IR Dialect function type that the first operand
1004     // should match.
1005     if (funcType.getNumResults() > 1)
1006       return parser.emitError(trailingTypeLoc,
1007                               "expected function with 0 or 1 result");
1008 
1009     Type llvmResultType;
1010     if (funcType.getNumResults() == 0) {
1011       llvmResultType = LLVM::LLVMVoidType::get(builder.getContext());
1012     } else {
1013       llvmResultType = funcType.getResult(0);
1014       if (!isCompatibleType(llvmResultType))
1015         return parser.emitError(trailingTypeLoc,
1016                                 "expected result to have LLVM type");
1017     }
1018 
1019     SmallVector<Type, 8> argTypes;
1020     argTypes.reserve(funcType.getNumInputs());
1021     for (Type ty : funcType.getInputs()) {
1022       if (isCompatibleType(ty))
1023         argTypes.push_back(ty);
1024       else
1025         return parser.emitError(trailingTypeLoc,
1026                                 "expected LLVM types as inputs");
1027     }
1028 
1029     auto llvmFuncType = LLVM::LLVMFunctionType::get(llvmResultType, argTypes);
1030     auto wrappedFuncType = LLVM::LLVMPointerType::get(llvmFuncType);
1031 
1032     auto funcArguments = llvm::makeArrayRef(operands).drop_front();
1033 
1034     // Make sure that the first operand (indirect callee) matches the wrapped
1035     // LLVM IR function type, and that the types of the other call operands
1036     // match the types of the function arguments.
1037     if (parser.resolveOperand(operands[0], wrappedFuncType, result.operands) ||
1038         parser.resolveOperands(funcArguments, funcType.getInputs(),
1039                                parser.getNameLoc(), result.operands))
1040       return failure();
1041 
1042     result.addTypes(llvmResultType);
1043   }
1044   result.addSuccessors({normalDest, unwindDest});
1045   result.addOperands(normalOperands);
1046   result.addOperands(unwindOperands);
1047 
1048   result.addAttribute(
1049       InvokeOp::getOperandSegmentSizeAttr(),
1050       builder.getI32VectorAttr({static_cast<int32_t>(operands.size()),
1051                                 static_cast<int32_t>(normalOperands.size()),
1052                                 static_cast<int32_t>(unwindOperands.size())}));
1053   return success();
1054 }
1055 
1056 ///===----------------------------------------------------------------------===//
1057 /// Verifying/Printing/Parsing for LLVM::LandingpadOp.
1058 ///===----------------------------------------------------------------------===//
1059 
verify()1060 LogicalResult LandingpadOp::verify() {
1061   Value value;
1062   if (LLVMFuncOp func = (*this)->getParentOfType<LLVMFuncOp>()) {
1063     if (!func.getPersonality())
1064       return emitError(
1065           "llvm.landingpad needs to be in a function with a personality");
1066   }
1067 
1068   if (!getCleanup() && getOperands().empty())
1069     return emitError("landingpad instruction expects at least one clause or "
1070                      "cleanup attribute");
1071 
1072   for (unsigned idx = 0, ie = getNumOperands(); idx < ie; idx++) {
1073     value = getOperand(idx);
1074     bool isFilter = value.getType().isa<LLVMArrayType>();
1075     if (isFilter) {
1076       // FIXME: Verify filter clauses when arrays are appropriately handled
1077     } else {
1078       // catch - global addresses only.
1079       // Bitcast ops should have global addresses as their args.
1080       if (auto bcOp = value.getDefiningOp<BitcastOp>()) {
1081         if (auto addrOp = bcOp.getArg().getDefiningOp<AddressOfOp>())
1082           continue;
1083         return emitError("constant clauses expected").attachNote(bcOp.getLoc())
1084                << "global addresses expected as operand to "
1085                   "bitcast used in clauses for landingpad";
1086       }
1087       // NullOp and AddressOfOp allowed
1088       if (value.getDefiningOp<NullOp>())
1089         continue;
1090       if (value.getDefiningOp<AddressOfOp>())
1091         continue;
1092       return emitError("clause #")
1093              << idx << " is not a known constant - null, addressof, bitcast";
1094     }
1095   }
1096   return success();
1097 }
1098 
print(OpAsmPrinter & p)1099 void LandingpadOp::print(OpAsmPrinter &p) {
1100   p << (getCleanup() ? " cleanup " : " ");
1101 
1102   // Clauses
1103   for (auto value : getOperands()) {
1104     // Similar to llvm - if clause is an array type then it is filter
1105     // clause else catch clause
1106     bool isArrayTy = value.getType().isa<LLVMArrayType>();
1107     p << '(' << (isArrayTy ? "filter " : "catch ") << value << " : "
1108       << value.getType() << ") ";
1109   }
1110 
1111   p.printOptionalAttrDict((*this)->getAttrs(), {"cleanup"});
1112 
1113   p << ": " << getType();
1114 }
1115 
1116 /// <operation> ::= `llvm.landingpad` `cleanup`?
1117 ///                 ((`catch` | `filter`) operand-type ssa-use)* attribute-dict?
parse(OpAsmParser & parser,OperationState & result)1118 ParseResult LandingpadOp::parse(OpAsmParser &parser, OperationState &result) {
1119   // Check for cleanup
1120   if (succeeded(parser.parseOptionalKeyword("cleanup")))
1121     result.addAttribute("cleanup", parser.getBuilder().getUnitAttr());
1122 
1123   // Parse clauses with types
1124   while (succeeded(parser.parseOptionalLParen()) &&
1125          (succeeded(parser.parseOptionalKeyword("filter")) ||
1126           succeeded(parser.parseOptionalKeyword("catch")))) {
1127     OpAsmParser::UnresolvedOperand operand;
1128     Type ty;
1129     if (parser.parseOperand(operand) || parser.parseColon() ||
1130         parser.parseType(ty) ||
1131         parser.resolveOperand(operand, ty, result.operands) ||
1132         parser.parseRParen())
1133       return failure();
1134   }
1135 
1136   Type type;
1137   if (parser.parseColon() || parser.parseType(type))
1138     return failure();
1139 
1140   result.addTypes(type);
1141   return success();
1142 }
1143 
1144 //===----------------------------------------------------------------------===//
1145 // Verifying/Printing/parsing for LLVM::CallOp.
1146 //===----------------------------------------------------------------------===//
1147 
getCallableForCallee()1148 CallInterfaceCallable CallOp::getCallableForCallee() {
1149   // Direct call.
1150   if (FlatSymbolRefAttr calleeAttr = getCalleeAttr())
1151     return calleeAttr;
1152   // Indirect call, callee Value is the first operand.
1153   return getOperand(0);
1154 }
1155 
getArgOperands()1156 Operation::operand_range CallOp::getArgOperands() {
1157   return getOperands().drop_front(getCallee().has_value() ? 0 : 1);
1158 }
1159 
verify()1160 LogicalResult CallOp::verify() {
1161   if (getNumResults() > 1)
1162     return emitOpError("must have 0 or 1 result");
1163 
1164   // Type for the callee, we'll get it differently depending if it is a direct
1165   // or indirect call.
1166   Type fnType;
1167 
1168   bool isIndirect = false;
1169 
1170   // If this is an indirect call, the callee attribute is missing.
1171   FlatSymbolRefAttr calleeName = getCalleeAttr();
1172   if (!calleeName) {
1173     isIndirect = true;
1174     if (!getNumOperands())
1175       return emitOpError(
1176           "must have either a `callee` attribute or at least an operand");
1177     auto ptrType = getOperand(0).getType().dyn_cast<LLVMPointerType>();
1178     if (!ptrType)
1179       return emitOpError("indirect call expects a pointer as callee: ")
1180              << ptrType;
1181     fnType = ptrType.getElementType();
1182   } else {
1183     Operation *callee =
1184         SymbolTable::lookupNearestSymbolFrom(*this, calleeName.getAttr());
1185     if (!callee)
1186       return emitOpError()
1187              << "'" << calleeName.getValue()
1188              << "' does not reference a symbol in the current scope";
1189     auto fn = dyn_cast<LLVMFuncOp>(callee);
1190     if (!fn)
1191       return emitOpError() << "'" << calleeName.getValue()
1192                            << "' does not reference a valid LLVM function";
1193 
1194     fnType = fn.getFunctionType();
1195   }
1196 
1197   LLVMFunctionType funcType = fnType.dyn_cast<LLVMFunctionType>();
1198   if (!funcType)
1199     return emitOpError("callee does not have a functional type: ") << fnType;
1200 
1201   // Verify that the operand and result types match the callee.
1202 
1203   if (!funcType.isVarArg() &&
1204       funcType.getNumParams() != (getNumOperands() - isIndirect))
1205     return emitOpError() << "incorrect number of operands ("
1206                          << (getNumOperands() - isIndirect)
1207                          << ") for callee (expecting: "
1208                          << funcType.getNumParams() << ")";
1209 
1210   if (funcType.getNumParams() > (getNumOperands() - isIndirect))
1211     return emitOpError() << "incorrect number of operands ("
1212                          << (getNumOperands() - isIndirect)
1213                          << ") for varargs callee (expecting at least: "
1214                          << funcType.getNumParams() << ")";
1215 
1216   for (unsigned i = 0, e = funcType.getNumParams(); i != e; ++i)
1217     if (getOperand(i + isIndirect).getType() != funcType.getParamType(i))
1218       return emitOpError() << "operand type mismatch for operand " << i << ": "
1219                            << getOperand(i + isIndirect).getType()
1220                            << " != " << funcType.getParamType(i);
1221 
1222   if (getNumResults() == 0 &&
1223       !funcType.getReturnType().isa<LLVM::LLVMVoidType>())
1224     return emitOpError() << "expected function call to produce a value";
1225 
1226   if (getNumResults() != 0 &&
1227       funcType.getReturnType().isa<LLVM::LLVMVoidType>())
1228     return emitOpError()
1229            << "calling function with void result must not produce values";
1230 
1231   if (getNumResults() > 1)
1232     return emitOpError()
1233            << "expected LLVM function call to produce 0 or 1 result";
1234 
1235   if (getNumResults() && getResult(0).getType() != funcType.getReturnType())
1236     return emitOpError() << "result type mismatch: " << getResult(0).getType()
1237                          << " != " << funcType.getReturnType();
1238 
1239   return success();
1240 }
1241 
print(OpAsmPrinter & p)1242 void CallOp::print(OpAsmPrinter &p) {
1243   auto callee = getCallee();
1244   bool isDirect = callee.has_value();
1245 
1246   // Print the direct callee if present as a function attribute, or an indirect
1247   // callee (first operand) otherwise.
1248   p << ' ';
1249   if (isDirect)
1250     p.printSymbolName(callee.value());
1251   else
1252     p << getOperand(0);
1253 
1254   auto args = getOperands().drop_front(isDirect ? 0 : 1);
1255   p << '(' << args << ')';
1256   p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()), {"callee"});
1257 
1258   // Reconstruct the function MLIR function type from operand and result types.
1259   p << " : ";
1260   p.printFunctionalType(args.getTypes(), getResultTypes());
1261 }
1262 
1263 // <operation> ::= `llvm.call` (function-id | ssa-use) `(` ssa-use-list `)`
1264 //                 attribute-dict? `:` function-type
parse(OpAsmParser & parser,OperationState & result)1265 ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
1266   SmallVector<OpAsmParser::UnresolvedOperand, 8> operands;
1267   Type type;
1268   SymbolRefAttr funcAttr;
1269   SMLoc trailingTypeLoc;
1270 
1271   // Parse an operand list that will, in practice, contain 0 or 1 operand.  In
1272   // case of an indirect call, there will be 1 operand before `(`.  In case of a
1273   // direct call, there will be no operands and the parser will stop at the
1274   // function identifier without complaining.
1275   if (parser.parseOperandList(operands))
1276     return failure();
1277   bool isDirect = operands.empty();
1278 
1279   // Optionally parse a function identifier.
1280   if (isDirect)
1281     if (parser.parseAttribute(funcAttr, "callee", result.attributes))
1282       return failure();
1283 
1284   if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
1285       parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
1286       parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type))
1287     return failure();
1288 
1289   auto funcType = type.dyn_cast<FunctionType>();
1290   if (!funcType)
1291     return parser.emitError(trailingTypeLoc, "expected function type");
1292   if (funcType.getNumResults() > 1)
1293     return parser.emitError(trailingTypeLoc,
1294                             "expected function with 0 or 1 result");
1295   if (isDirect) {
1296     // Make sure types match.
1297     if (parser.resolveOperands(operands, funcType.getInputs(),
1298                                parser.getNameLoc(), result.operands))
1299       return failure();
1300     if (funcType.getNumResults() != 0 &&
1301         !funcType.getResult(0).isa<LLVM::LLVMVoidType>())
1302       result.addTypes(funcType.getResults());
1303   } else {
1304     Builder &builder = parser.getBuilder();
1305     Type llvmResultType;
1306     if (funcType.getNumResults() == 0) {
1307       llvmResultType = LLVM::LLVMVoidType::get(builder.getContext());
1308     } else {
1309       llvmResultType = funcType.getResult(0);
1310       if (!isCompatibleType(llvmResultType))
1311         return parser.emitError(trailingTypeLoc,
1312                                 "expected result to have LLVM type");
1313     }
1314 
1315     SmallVector<Type, 8> argTypes;
1316     argTypes.reserve(funcType.getNumInputs());
1317     for (int i = 0, e = funcType.getNumInputs(); i < e; ++i) {
1318       auto argType = funcType.getInput(i);
1319       if (!isCompatibleType(argType))
1320         return parser.emitError(trailingTypeLoc,
1321                                 "expected LLVM types as inputs");
1322       argTypes.push_back(argType);
1323     }
1324     auto llvmFuncType = LLVM::LLVMFunctionType::get(llvmResultType, argTypes);
1325     auto wrappedFuncType = LLVM::LLVMPointerType::get(llvmFuncType);
1326 
1327     auto funcArguments =
1328         ArrayRef<OpAsmParser::UnresolvedOperand>(operands).drop_front();
1329 
1330     // Make sure that the first operand (indirect callee) matches the wrapped
1331     // LLVM IR function type, and that the types of the other call operands
1332     // match the types of the function arguments.
1333     if (parser.resolveOperand(operands[0], wrappedFuncType, result.operands) ||
1334         parser.resolveOperands(funcArguments, funcType.getInputs(),
1335                                parser.getNameLoc(), result.operands))
1336       return failure();
1337 
1338     if (!llvmResultType.isa<LLVM::LLVMVoidType>())
1339       result.addTypes(llvmResultType);
1340   }
1341 
1342   return success();
1343 }
1344 
1345 //===----------------------------------------------------------------------===//
1346 // Printing/parsing for LLVM::ExtractElementOp.
1347 //===----------------------------------------------------------------------===//
1348 // Expects vector to be of wrapped LLVM vector type and position to be of
1349 // wrapped LLVM i32 type.
build(OpBuilder & b,OperationState & result,Value vector,Value position,ArrayRef<NamedAttribute> attrs)1350 void LLVM::ExtractElementOp::build(OpBuilder &b, OperationState &result,
1351                                    Value vector, Value position,
1352                                    ArrayRef<NamedAttribute> attrs) {
1353   auto vectorType = vector.getType();
1354   auto llvmType = LLVM::getVectorElementType(vectorType);
1355   build(b, result, llvmType, vector, position);
1356   result.addAttributes(attrs);
1357 }
1358 
print(OpAsmPrinter & p)1359 void ExtractElementOp::print(OpAsmPrinter &p) {
1360   p << ' ' << getVector() << "[" << getPosition() << " : "
1361     << getPosition().getType() << "]";
1362   p.printOptionalAttrDict((*this)->getAttrs());
1363   p << " : " << getVector().getType();
1364 }
1365 
1366 // <operation> ::= `llvm.extractelement` ssa-use `, ` ssa-use
1367 //                 attribute-dict? `:` type
parse(OpAsmParser & parser,OperationState & result)1368 ParseResult ExtractElementOp::parse(OpAsmParser &parser,
1369                                     OperationState &result) {
1370   SMLoc loc;
1371   OpAsmParser::UnresolvedOperand vector, position;
1372   Type type, positionType;
1373   if (parser.getCurrentLocation(&loc) || parser.parseOperand(vector) ||
1374       parser.parseLSquare() || parser.parseOperand(position) ||
1375       parser.parseColonType(positionType) || parser.parseRSquare() ||
1376       parser.parseOptionalAttrDict(result.attributes) ||
1377       parser.parseColonType(type) ||
1378       parser.resolveOperand(vector, type, result.operands) ||
1379       parser.resolveOperand(position, positionType, result.operands))
1380     return failure();
1381   if (!LLVM::isCompatibleVectorType(type))
1382     return parser.emitError(
1383         loc, "expected LLVM dialect-compatible vector type for operand #1");
1384   result.addTypes(LLVM::getVectorElementType(type));
1385   return success();
1386 }
1387 
verify()1388 LogicalResult ExtractElementOp::verify() {
1389   Type vectorType = getVector().getType();
1390   if (!LLVM::isCompatibleVectorType(vectorType))
1391     return emitOpError("expected LLVM dialect-compatible vector type for "
1392                        "operand #1, got")
1393            << vectorType;
1394   Type valueType = LLVM::getVectorElementType(vectorType);
1395   if (valueType != getRes().getType())
1396     return emitOpError() << "Type mismatch: extracting from " << vectorType
1397                          << " should produce " << valueType
1398                          << " but this op returns " << getRes().getType();
1399   return success();
1400 }
1401 
1402 //===----------------------------------------------------------------------===//
1403 // Printing/parsing for LLVM::ExtractValueOp.
1404 //===----------------------------------------------------------------------===//
1405 
print(OpAsmPrinter & p)1406 void ExtractValueOp::print(OpAsmPrinter &p) {
1407   p << ' ' << getContainer() << getPosition();
1408   p.printOptionalAttrDict((*this)->getAttrs(), {"position"});
1409   p << " : " << getContainer().getType();
1410 }
1411 
1412 // Extract the type at `position` in the wrapped LLVM IR aggregate type
1413 // `containerType`.  Position is an integer array attribute where each value
1414 // is a zero-based position of the element in the aggregate type.  Return the
1415 // resulting type wrapped in MLIR, or nullptr on error.
getInsertExtractValueElementType(OpAsmParser & parser,Type containerType,ArrayAttr positionAttr,SMLoc attributeLoc,SMLoc typeLoc)1416 static Type getInsertExtractValueElementType(OpAsmParser &parser,
1417                                              Type containerType,
1418                                              ArrayAttr positionAttr,
1419                                              SMLoc attributeLoc,
1420                                              SMLoc typeLoc) {
1421   Type llvmType = containerType;
1422   if (!isCompatibleType(containerType))
1423     return parser.emitError(typeLoc, "expected LLVM IR Dialect type"), nullptr;
1424 
1425   // Infer the element type from the structure type: iteratively step inside the
1426   // type by taking the element type, indexed by the position attribute for
1427   // structures.  Check the position index before accessing, it is supposed to
1428   // be in bounds.
1429   for (Attribute subAttr : positionAttr) {
1430     auto positionElementAttr = subAttr.dyn_cast<IntegerAttr>();
1431     if (!positionElementAttr)
1432       return parser.emitError(attributeLoc,
1433                               "expected an array of integer literals"),
1434              nullptr;
1435     int position = positionElementAttr.getInt();
1436     if (auto arrayType = llvmType.dyn_cast<LLVMArrayType>()) {
1437       if (position < 0 ||
1438           static_cast<unsigned>(position) >= arrayType.getNumElements())
1439         return parser.emitError(attributeLoc, "position out of bounds"),
1440                nullptr;
1441       llvmType = arrayType.getElementType();
1442     } else if (auto structType = llvmType.dyn_cast<LLVMStructType>()) {
1443       if (position < 0 ||
1444           static_cast<unsigned>(position) >= structType.getBody().size())
1445         return parser.emitError(attributeLoc, "position out of bounds"),
1446                nullptr;
1447       llvmType = structType.getBody()[position];
1448     } else {
1449       return parser.emitError(typeLoc, "expected LLVM IR structure/array type"),
1450              nullptr;
1451     }
1452   }
1453   return llvmType;
1454 }
1455 
1456 // Extract the type at `position` in the wrapped LLVM IR aggregate type
1457 // `containerType`. Returns null on failure.
getInsertExtractValueElementType(Type containerType,ArrayAttr positionAttr,Operation * op)1458 static Type getInsertExtractValueElementType(Type containerType,
1459                                              ArrayAttr positionAttr,
1460                                              Operation *op) {
1461   Type llvmType = containerType;
1462   if (!isCompatibleType(containerType)) {
1463     op->emitError("expected LLVM IR Dialect type, got ") << containerType;
1464     return {};
1465   }
1466 
1467   // Infer the element type from the structure type: iteratively step inside the
1468   // type by taking the element type, indexed by the position attribute for
1469   // structures.  Check the position index before accessing, it is supposed to
1470   // be in bounds.
1471   for (Attribute subAttr : positionAttr) {
1472     auto positionElementAttr = subAttr.dyn_cast<IntegerAttr>();
1473     if (!positionElementAttr) {
1474       op->emitOpError("expected an array of integer literals, got: ")
1475           << subAttr;
1476       return {};
1477     }
1478     int position = positionElementAttr.getInt();
1479     if (auto arrayType = llvmType.dyn_cast<LLVMArrayType>()) {
1480       if (position < 0 ||
1481           static_cast<unsigned>(position) >= arrayType.getNumElements()) {
1482         op->emitOpError("position out of bounds: ") << position;
1483         return {};
1484       }
1485       llvmType = arrayType.getElementType();
1486     } else if (auto structType = llvmType.dyn_cast<LLVMStructType>()) {
1487       if (position < 0 ||
1488           static_cast<unsigned>(position) >= structType.getBody().size()) {
1489         op->emitOpError("position out of bounds") << position;
1490         return {};
1491       }
1492       llvmType = structType.getBody()[position];
1493     } else {
1494       op->emitOpError("expected LLVM IR structure/array type, got: ")
1495           << llvmType;
1496       return {};
1497     }
1498   }
1499   return llvmType;
1500 }
1501 
1502 // <operation> ::= `llvm.extractvalue` ssa-use
1503 //                 `[` integer-literal (`,` integer-literal)* `]`
1504 //                 attribute-dict? `:` type
parse(OpAsmParser & parser,OperationState & result)1505 ParseResult ExtractValueOp::parse(OpAsmParser &parser, OperationState &result) {
1506   OpAsmParser::UnresolvedOperand container;
1507   Type containerType;
1508   ArrayAttr positionAttr;
1509   SMLoc attributeLoc, trailingTypeLoc;
1510 
1511   if (parser.parseOperand(container) ||
1512       parser.getCurrentLocation(&attributeLoc) ||
1513       parser.parseAttribute(positionAttr, "position", result.attributes) ||
1514       parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
1515       parser.getCurrentLocation(&trailingTypeLoc) ||
1516       parser.parseType(containerType) ||
1517       parser.resolveOperand(container, containerType, result.operands))
1518     return failure();
1519 
1520   auto elementType = getInsertExtractValueElementType(
1521       parser, containerType, positionAttr, attributeLoc, trailingTypeLoc);
1522   if (!elementType)
1523     return failure();
1524 
1525   result.addTypes(elementType);
1526   return success();
1527 }
1528 
fold(ArrayRef<Attribute> operands)1529 OpFoldResult LLVM::ExtractValueOp::fold(ArrayRef<Attribute> operands) {
1530   auto insertValueOp = getContainer().getDefiningOp<InsertValueOp>();
1531   OpFoldResult result = {};
1532   while (insertValueOp) {
1533     if (getPosition() == insertValueOp.getPosition())
1534       return insertValueOp.getValue();
1535     unsigned min =
1536         std::min(getPosition().size(), insertValueOp.getPosition().size());
1537     // If one is fully prefix of the other, stop propagating back as it will
1538     // miss dependencies. For instance, %3 should not fold to %f0 in the
1539     // following example:
1540     // ```
1541     //   %1 = llvm.insertvalue %f0, %0[0, 0] :
1542     //     !llvm.array<4 x !llvm.array<4xf32>>
1543     //   %2 = llvm.insertvalue %arr, %1[0] :
1544     //     !llvm.array<4 x !llvm.array<4xf32>>
1545     //   %3 = llvm.extractvalue %2[0, 0] : !llvm.array<4 x !llvm.array<4xf32>>
1546     // ```
1547     if (getPosition().getValue().take_front(min) ==
1548         insertValueOp.getPosition().getValue().take_front(min))
1549       return result;
1550 
1551     // If neither a prefix, nor the exact position, we can extract out of the
1552     // value being inserted into. Moreover, we can try again if that operand
1553     // is itself an insertvalue expression.
1554     getContainerMutable().assign(insertValueOp.getContainer());
1555     result = getResult();
1556     insertValueOp = insertValueOp.getContainer().getDefiningOp<InsertValueOp>();
1557   }
1558   return result;
1559 }
1560 
verify()1561 LogicalResult ExtractValueOp::verify() {
1562   Type valueType = getInsertExtractValueElementType(getContainer().getType(),
1563                                                     getPositionAttr(), *this);
1564   if (!valueType)
1565     return failure();
1566 
1567   if (getRes().getType() != valueType)
1568     return emitOpError() << "Type mismatch: extracting from "
1569                          << getContainer().getType() << " should produce "
1570                          << valueType << " but this op returns "
1571                          << getRes().getType();
1572   return success();
1573 }
1574 
1575 //===----------------------------------------------------------------------===//
1576 // Printing/parsing for LLVM::InsertElementOp.
1577 //===----------------------------------------------------------------------===//
1578 
print(OpAsmPrinter & p)1579 void InsertElementOp::print(OpAsmPrinter &p) {
1580   p << ' ' << getValue() << ", " << getVector() << "[" << getPosition() << " : "
1581     << getPosition().getType() << "]";
1582   p.printOptionalAttrDict((*this)->getAttrs());
1583   p << " : " << getVector().getType();
1584 }
1585 
1586 // <operation> ::= `llvm.insertelement` ssa-use `,` ssa-use `,` ssa-use
1587 //                 attribute-dict? `:` type
parse(OpAsmParser & parser,OperationState & result)1588 ParseResult InsertElementOp::parse(OpAsmParser &parser,
1589                                    OperationState &result) {
1590   SMLoc loc;
1591   OpAsmParser::UnresolvedOperand vector, value, position;
1592   Type vectorType, positionType;
1593   if (parser.getCurrentLocation(&loc) || parser.parseOperand(value) ||
1594       parser.parseComma() || parser.parseOperand(vector) ||
1595       parser.parseLSquare() || parser.parseOperand(position) ||
1596       parser.parseColonType(positionType) || parser.parseRSquare() ||
1597       parser.parseOptionalAttrDict(result.attributes) ||
1598       parser.parseColonType(vectorType))
1599     return failure();
1600 
1601   if (!LLVM::isCompatibleVectorType(vectorType))
1602     return parser.emitError(
1603         loc, "expected LLVM dialect-compatible vector type for operand #1");
1604   Type valueType = LLVM::getVectorElementType(vectorType);
1605   if (!valueType)
1606     return failure();
1607 
1608   if (parser.resolveOperand(vector, vectorType, result.operands) ||
1609       parser.resolveOperand(value, valueType, result.operands) ||
1610       parser.resolveOperand(position, positionType, result.operands))
1611     return failure();
1612 
1613   result.addTypes(vectorType);
1614   return success();
1615 }
1616 
verify()1617 LogicalResult InsertElementOp::verify() {
1618   Type valueType = LLVM::getVectorElementType(getVector().getType());
1619   if (valueType != getValue().getType())
1620     return emitOpError() << "Type mismatch: cannot insert "
1621                          << getValue().getType() << " into "
1622                          << getVector().getType();
1623   return success();
1624 }
1625 
1626 //===----------------------------------------------------------------------===//
1627 // Printing/parsing for LLVM::InsertValueOp.
1628 //===----------------------------------------------------------------------===//
1629 
print(OpAsmPrinter & p)1630 void InsertValueOp::print(OpAsmPrinter &p) {
1631   p << ' ' << getValue() << ", " << getContainer() << getPosition();
1632   p.printOptionalAttrDict((*this)->getAttrs(), {"position"});
1633   p << " : " << getContainer().getType();
1634 }
1635 
1636 // <operation> ::= `llvm.insertvaluevalue` ssa-use `,` ssa-use
1637 //                 `[` integer-literal (`,` integer-literal)* `]`
1638 //                 attribute-dict? `:` type
parse(OpAsmParser & parser,OperationState & result)1639 ParseResult InsertValueOp::parse(OpAsmParser &parser, OperationState &result) {
1640   OpAsmParser::UnresolvedOperand container, value;
1641   Type containerType;
1642   ArrayAttr positionAttr;
1643   SMLoc attributeLoc, trailingTypeLoc;
1644 
1645   if (parser.parseOperand(value) || parser.parseComma() ||
1646       parser.parseOperand(container) ||
1647       parser.getCurrentLocation(&attributeLoc) ||
1648       parser.parseAttribute(positionAttr, "position", result.attributes) ||
1649       parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
1650       parser.getCurrentLocation(&trailingTypeLoc) ||
1651       parser.parseType(containerType))
1652     return failure();
1653 
1654   auto valueType = getInsertExtractValueElementType(
1655       parser, containerType, positionAttr, attributeLoc, trailingTypeLoc);
1656   if (!valueType)
1657     return failure();
1658 
1659   if (parser.resolveOperand(container, containerType, result.operands) ||
1660       parser.resolveOperand(value, valueType, result.operands))
1661     return failure();
1662 
1663   result.addTypes(containerType);
1664   return success();
1665 }
1666 
verify()1667 LogicalResult InsertValueOp::verify() {
1668   Type valueType = getInsertExtractValueElementType(getContainer().getType(),
1669                                                     getPositionAttr(), *this);
1670   if (!valueType)
1671     return failure();
1672 
1673   if (getValue().getType() != valueType)
1674     return emitOpError() << "Type mismatch: cannot insert "
1675                          << getValue().getType() << " into "
1676                          << getContainer().getType();
1677 
1678   return success();
1679 }
1680 
1681 //===----------------------------------------------------------------------===//
1682 // Printing, parsing and verification for LLVM::ReturnOp.
1683 //===----------------------------------------------------------------------===//
1684 
verify()1685 LogicalResult ReturnOp::verify() {
1686   if (getNumOperands() > 1)
1687     return emitOpError("expected at most 1 operand");
1688 
1689   if (auto parent = (*this)->getParentOfType<LLVMFuncOp>()) {
1690     Type expectedType = parent.getFunctionType().getReturnType();
1691     if (expectedType.isa<LLVMVoidType>()) {
1692       if (getNumOperands() == 0)
1693         return success();
1694       InFlightDiagnostic diag = emitOpError("expected no operands");
1695       diag.attachNote(parent->getLoc()) << "when returning from function";
1696       return diag;
1697     }
1698     if (getNumOperands() == 0) {
1699       if (expectedType.isa<LLVMVoidType>())
1700         return success();
1701       InFlightDiagnostic diag = emitOpError("expected 1 operand");
1702       diag.attachNote(parent->getLoc()) << "when returning from function";
1703       return diag;
1704     }
1705     if (expectedType != getOperand(0).getType()) {
1706       InFlightDiagnostic diag = emitOpError("mismatching result types");
1707       diag.attachNote(parent->getLoc()) << "when returning from function";
1708       return diag;
1709     }
1710   }
1711   return success();
1712 }
1713 
1714 //===----------------------------------------------------------------------===//
1715 // ResumeOp
1716 //===----------------------------------------------------------------------===//
1717 
verify()1718 LogicalResult ResumeOp::verify() {
1719   if (!getValue().getDefiningOp<LandingpadOp>())
1720     return emitOpError("expects landingpad value as operand");
1721   // No check for personality of function - landingpad op verifies it.
1722   return success();
1723 }
1724 
1725 //===----------------------------------------------------------------------===//
1726 // Verifier for LLVM::AddressOfOp.
1727 //===----------------------------------------------------------------------===//
1728 
1729 template <typename OpTy>
lookupSymbolInModule(Operation * parent,StringRef name)1730 static OpTy lookupSymbolInModule(Operation *parent, StringRef name) {
1731   Operation *module = parent;
1732   while (module && !satisfiesLLVMModule(module))
1733     module = module->getParentOp();
1734   assert(module && "unexpected operation outside of a module");
1735   return dyn_cast_or_null<OpTy>(
1736       mlir::SymbolTable::lookupSymbolIn(module, name));
1737 }
1738 
getGlobal()1739 GlobalOp AddressOfOp::getGlobal() {
1740   return lookupSymbolInModule<LLVM::GlobalOp>((*this)->getParentOp(),
1741                                               getGlobalName());
1742 }
1743 
getFunction()1744 LLVMFuncOp AddressOfOp::getFunction() {
1745   return lookupSymbolInModule<LLVM::LLVMFuncOp>((*this)->getParentOp(),
1746                                                 getGlobalName());
1747 }
1748 
verify()1749 LogicalResult AddressOfOp::verify() {
1750   auto global = getGlobal();
1751   auto function = getFunction();
1752   if (!global && !function)
1753     return emitOpError(
1754         "must reference a global defined by 'llvm.mlir.global' or 'llvm.func'");
1755 
1756   LLVMPointerType type = getType();
1757   if (global && global.getAddrSpace() != type.getAddressSpace())
1758     return emitOpError("pointer address space must match address space of the "
1759                        "referenced global");
1760 
1761   if (type.isOpaque())
1762     return success();
1763 
1764   if (global && type.getElementType() != global.getType())
1765     return emitOpError(
1766         "the type must be a pointer to the type of the referenced global");
1767 
1768   if (function && type.getElementType() != function.getFunctionType())
1769     return emitOpError(
1770         "the type must be a pointer to the type of the referenced function");
1771 
1772   return success();
1773 }
1774 
1775 //===----------------------------------------------------------------------===//
1776 // Builder, printer and verifier for LLVM::GlobalOp.
1777 //===----------------------------------------------------------------------===//
1778 
build(OpBuilder & builder,OperationState & result,Type type,bool isConstant,Linkage linkage,StringRef name,Attribute value,uint64_t alignment,unsigned addrSpace,bool dsoLocal,bool threadLocal,ArrayRef<NamedAttribute> attrs)1779 void GlobalOp::build(OpBuilder &builder, OperationState &result, Type type,
1780                      bool isConstant, Linkage linkage, StringRef name,
1781                      Attribute value, uint64_t alignment, unsigned addrSpace,
1782                      bool dsoLocal, bool threadLocal,
1783                      ArrayRef<NamedAttribute> attrs) {
1784   result.addAttribute(getSymNameAttrName(result.name),
1785                       builder.getStringAttr(name));
1786   result.addAttribute(getGlobalTypeAttrName(result.name), TypeAttr::get(type));
1787   if (isConstant)
1788     result.addAttribute(getConstantAttrName(result.name),
1789                         builder.getUnitAttr());
1790   if (value)
1791     result.addAttribute(getValueAttrName(result.name), value);
1792   if (dsoLocal)
1793     result.addAttribute(getDsoLocalAttrName(result.name),
1794                         builder.getUnitAttr());
1795   if (threadLocal)
1796     result.addAttribute(getThreadLocal_AttrName(result.name),
1797                         builder.getUnitAttr());
1798 
1799   // Only add an alignment attribute if the "alignment" input
1800   // is different from 0. The value must also be a power of two, but
1801   // this is tested in GlobalOp::verify, not here.
1802   if (alignment != 0)
1803     result.addAttribute(getAlignmentAttrName(result.name),
1804                         builder.getI64IntegerAttr(alignment));
1805 
1806   result.addAttribute(getLinkageAttrName(result.name),
1807                       LinkageAttr::get(builder.getContext(), linkage));
1808   if (addrSpace != 0)
1809     result.addAttribute(getAddrSpaceAttrName(result.name),
1810                         builder.getI32IntegerAttr(addrSpace));
1811   result.attributes.append(attrs.begin(), attrs.end());
1812   result.addRegion();
1813 }
1814 
print(OpAsmPrinter & p)1815 void GlobalOp::print(OpAsmPrinter &p) {
1816   p << ' ' << stringifyLinkage(getLinkage()) << ' ';
1817   if (auto unnamedAddr = getUnnamedAddr()) {
1818     StringRef str = stringifyUnnamedAddr(*unnamedAddr);
1819     if (!str.empty())
1820       p << str << ' ';
1821   }
1822   if (getThreadLocal_())
1823     p << "thread_local ";
1824   if (getConstant())
1825     p << "constant ";
1826   p.printSymbolName(getSymName());
1827   p << '(';
1828   if (auto value = getValueOrNull())
1829     p.printAttribute(value);
1830   p << ')';
1831   // Note that the alignment attribute is printed using the
1832   // default syntax here, even though it is an inherent attribute
1833   // (as defined in https://mlir.llvm.org/docs/LangRef/#attributes)
1834   p.printOptionalAttrDict(
1835       (*this)->getAttrs(),
1836       {SymbolTable::getSymbolAttrName(), getGlobalTypeAttrName(),
1837        getConstantAttrName(), getValueAttrName(), getLinkageAttrName(),
1838        getUnnamedAddrAttrName(), getThreadLocal_AttrName()});
1839 
1840   // Print the trailing type unless it's a string global.
1841   if (getValueOrNull().dyn_cast_or_null<StringAttr>())
1842     return;
1843   p << " : " << getType();
1844 
1845   Region &initializer = getInitializerRegion();
1846   if (!initializer.empty()) {
1847     p << ' ';
1848     p.printRegion(initializer, /*printEntryBlockArgs=*/false);
1849   }
1850 }
1851 
1852 // Parses one of the keywords provided in the list `keywords` and returns the
1853 // position of the parsed keyword in the list. If none of the keywords from the
1854 // list is parsed, returns -1.
parseOptionalKeywordAlternative(OpAsmParser & parser,ArrayRef<StringRef> keywords)1855 static int parseOptionalKeywordAlternative(OpAsmParser &parser,
1856                                            ArrayRef<StringRef> keywords) {
1857   for (const auto &en : llvm::enumerate(keywords)) {
1858     if (succeeded(parser.parseOptionalKeyword(en.value())))
1859       return en.index();
1860   }
1861   return -1;
1862 }
1863 
1864 namespace {
1865 template <typename Ty>
1866 struct EnumTraits {};
1867 
1868 #define REGISTER_ENUM_TYPE(Ty)                                                 \
1869   template <>                                                                  \
1870   struct EnumTraits<Ty> {                                                      \
1871     static StringRef stringify(Ty value) { return stringify##Ty(value); }      \
1872     static unsigned getMaxEnumVal() { return getMaxEnumValFor##Ty(); }         \
1873   }
1874 
1875 REGISTER_ENUM_TYPE(Linkage);
1876 REGISTER_ENUM_TYPE(UnnamedAddr);
1877 REGISTER_ENUM_TYPE(CConv);
1878 } // namespace
1879 
1880 /// Parse an enum from the keyword, or default to the provided default value.
1881 /// The return type is the enum type by default, unless overriden with the
1882 /// second template argument.
1883 template <typename EnumTy, typename RetTy = EnumTy>
parseOptionalLLVMKeyword(OpAsmParser & parser,OperationState & result,EnumTy defaultValue)1884 static RetTy parseOptionalLLVMKeyword(OpAsmParser &parser,
1885                                       OperationState &result,
1886                                       EnumTy defaultValue) {
1887   SmallVector<StringRef, 10> names;
1888   for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i)
1889     names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i)));
1890 
1891   int index = parseOptionalKeywordAlternative(parser, names);
1892   if (index == -1)
1893     return static_cast<RetTy>(defaultValue);
1894   return static_cast<RetTy>(index);
1895 }
1896 
1897 // operation ::= `llvm.mlir.global` linkage? `constant`? `@` identifier
1898 //               `(` attribute? `)` align? attribute-list? (`:` type)? region?
1899 // align     ::= `align` `=` UINT64
1900 //
1901 // The type can be omitted for string attributes, in which case it will be
1902 // inferred from the value of the string as [strlen(value) x i8].
parse(OpAsmParser & parser,OperationState & result)1903 ParseResult GlobalOp::parse(OpAsmParser &parser, OperationState &result) {
1904   MLIRContext *ctx = parser.getContext();
1905   // Parse optional linkage, default to External.
1906   result.addAttribute(getLinkageAttrName(result.name),
1907                       LLVM::LinkageAttr::get(
1908                           ctx, parseOptionalLLVMKeyword<Linkage>(
1909                                    parser, result, LLVM::Linkage::External)));
1910 
1911   if (succeeded(parser.parseOptionalKeyword("thread_local")))
1912     result.addAttribute(getThreadLocal_AttrName(result.name),
1913                         parser.getBuilder().getUnitAttr());
1914 
1915   // Parse optional UnnamedAddr, default to None.
1916   result.addAttribute(getUnnamedAddrAttrName(result.name),
1917                       parser.getBuilder().getI64IntegerAttr(
1918                           parseOptionalLLVMKeyword<UnnamedAddr, int64_t>(
1919                               parser, result, LLVM::UnnamedAddr::None)));
1920 
1921   if (succeeded(parser.parseOptionalKeyword("constant")))
1922     result.addAttribute(getConstantAttrName(result.name),
1923                         parser.getBuilder().getUnitAttr());
1924 
1925   StringAttr name;
1926   if (parser.parseSymbolName(name, getSymNameAttrName(result.name),
1927                              result.attributes) ||
1928       parser.parseLParen())
1929     return failure();
1930 
1931   Attribute value;
1932   if (parser.parseOptionalRParen()) {
1933     if (parser.parseAttribute(value, getValueAttrName(result.name),
1934                               result.attributes) ||
1935         parser.parseRParen())
1936       return failure();
1937   }
1938 
1939   SmallVector<Type, 1> types;
1940   if (parser.parseOptionalAttrDict(result.attributes) ||
1941       parser.parseOptionalColonTypeList(types))
1942     return failure();
1943 
1944   if (types.size() > 1)
1945     return parser.emitError(parser.getNameLoc(), "expected zero or one type");
1946 
1947   Region &initRegion = *result.addRegion();
1948   if (types.empty()) {
1949     if (auto strAttr = value.dyn_cast_or_null<StringAttr>()) {
1950       MLIRContext *context = parser.getContext();
1951       auto arrayType = LLVM::LLVMArrayType::get(IntegerType::get(context, 8),
1952                                                 strAttr.getValue().size());
1953       types.push_back(arrayType);
1954     } else {
1955       return parser.emitError(parser.getNameLoc(),
1956                               "type can only be omitted for string globals");
1957     }
1958   } else {
1959     OptionalParseResult parseResult =
1960         parser.parseOptionalRegion(initRegion, /*arguments=*/{},
1961                                    /*argTypes=*/{});
1962     if (parseResult.hasValue() && failed(*parseResult))
1963       return failure();
1964   }
1965 
1966   result.addAttribute(getGlobalTypeAttrName(result.name),
1967                       TypeAttr::get(types[0]));
1968   return success();
1969 }
1970 
isZeroAttribute(Attribute value)1971 static bool isZeroAttribute(Attribute value) {
1972   if (auto intValue = value.dyn_cast<IntegerAttr>())
1973     return intValue.getValue().isNullValue();
1974   if (auto fpValue = value.dyn_cast<FloatAttr>())
1975     return fpValue.getValue().isZero();
1976   if (auto splatValue = value.dyn_cast<SplatElementsAttr>())
1977     return isZeroAttribute(splatValue.getSplatValue<Attribute>());
1978   if (auto elementsValue = value.dyn_cast<ElementsAttr>())
1979     return llvm::all_of(elementsValue.getValues<Attribute>(), isZeroAttribute);
1980   if (auto arrayValue = value.dyn_cast<ArrayAttr>())
1981     return llvm::all_of(arrayValue.getValue(), isZeroAttribute);
1982   return false;
1983 }
1984 
verify()1985 LogicalResult GlobalOp::verify() {
1986   if (!LLVMPointerType::isValidElementType(getType()))
1987     return emitOpError(
1988         "expects type to be a valid element type for an LLVM pointer");
1989   if ((*this)->getParentOp() && !satisfiesLLVMModule((*this)->getParentOp()))
1990     return emitOpError("must appear at the module level");
1991 
1992   if (auto strAttr = getValueOrNull().dyn_cast_or_null<StringAttr>()) {
1993     auto type = getType().dyn_cast<LLVMArrayType>();
1994     IntegerType elementType =
1995         type ? type.getElementType().dyn_cast<IntegerType>() : nullptr;
1996     if (!elementType || elementType.getWidth() != 8 ||
1997         type.getNumElements() != strAttr.getValue().size())
1998       return emitOpError(
1999           "requires an i8 array type of the length equal to that of the string "
2000           "attribute");
2001   }
2002 
2003   if (getLinkage() == Linkage::Common) {
2004     if (Attribute value = getValueOrNull()) {
2005       if (!isZeroAttribute(value)) {
2006         return emitOpError()
2007                << "expected zero value for '"
2008                << stringifyLinkage(Linkage::Common) << "' linkage";
2009       }
2010     }
2011   }
2012 
2013   if (getLinkage() == Linkage::Appending) {
2014     if (!getType().isa<LLVMArrayType>()) {
2015       return emitOpError() << "expected array type for '"
2016                            << stringifyLinkage(Linkage::Appending)
2017                            << "' linkage";
2018     }
2019   }
2020 
2021   Optional<uint64_t> alignAttr = getAlignment();
2022   if (alignAttr.has_value()) {
2023     uint64_t value = alignAttr.value();
2024     if (!llvm::isPowerOf2_64(value))
2025       return emitError() << "alignment attribute is not a power of 2";
2026   }
2027 
2028   return success();
2029 }
2030 
verifyRegions()2031 LogicalResult GlobalOp::verifyRegions() {
2032   if (Block *b = getInitializerBlock()) {
2033     ReturnOp ret = cast<ReturnOp>(b->getTerminator());
2034     if (ret.operand_type_begin() == ret.operand_type_end())
2035       return emitOpError("initializer region cannot return void");
2036     if (*ret.operand_type_begin() != getType())
2037       return emitOpError("initializer region type ")
2038              << *ret.operand_type_begin() << " does not match global type "
2039              << getType();
2040 
2041     for (Operation &op : *b) {
2042       auto iface = dyn_cast<MemoryEffectOpInterface>(op);
2043       if (!iface || !iface.hasNoEffect())
2044         return op.emitError()
2045                << "ops with side effects not allowed in global initializers";
2046     }
2047 
2048     if (getValueOrNull())
2049       return emitOpError("cannot have both initializer value and region");
2050   }
2051 
2052   return success();
2053 }
2054 
2055 //===----------------------------------------------------------------------===//
2056 // LLVM::GlobalCtorsOp
2057 //===----------------------------------------------------------------------===//
2058 
2059 LogicalResult
verifySymbolUses(SymbolTableCollection & symbolTable)2060 GlobalCtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2061   for (Attribute ctor : getCtors()) {
2062     if (failed(verifySymbolAttrUse(ctor.cast<FlatSymbolRefAttr>(), *this,
2063                                    symbolTable)))
2064       return failure();
2065   }
2066   return success();
2067 }
2068 
verify()2069 LogicalResult GlobalCtorsOp::verify() {
2070   if (getCtors().size() != getPriorities().size())
2071     return emitError(
2072         "mismatch between the number of ctors and the number of priorities");
2073   return success();
2074 }
2075 
2076 //===----------------------------------------------------------------------===//
2077 // LLVM::GlobalDtorsOp
2078 //===----------------------------------------------------------------------===//
2079 
2080 LogicalResult
verifySymbolUses(SymbolTableCollection & symbolTable)2081 GlobalDtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2082   for (Attribute dtor : getDtors()) {
2083     if (failed(verifySymbolAttrUse(dtor.cast<FlatSymbolRefAttr>(), *this,
2084                                    symbolTable)))
2085       return failure();
2086   }
2087   return success();
2088 }
2089 
verify()2090 LogicalResult GlobalDtorsOp::verify() {
2091   if (getDtors().size() != getPriorities().size())
2092     return emitError(
2093         "mismatch between the number of dtors and the number of priorities");
2094   return success();
2095 }
2096 
2097 //===----------------------------------------------------------------------===//
2098 // Printing/parsing for LLVM::ShuffleVectorOp.
2099 //===----------------------------------------------------------------------===//
2100 // Expects vector to be of wrapped LLVM vector type and position to be of
2101 // wrapped LLVM i32 type.
build(OpBuilder & b,OperationState & result,Value v1,Value v2,ArrayAttr mask,ArrayRef<NamedAttribute> attrs)2102 void LLVM::ShuffleVectorOp::build(OpBuilder &b, OperationState &result,
2103                                   Value v1, Value v2, ArrayAttr mask,
2104                                   ArrayRef<NamedAttribute> attrs) {
2105   auto containerType = v1.getType();
2106   auto vType = LLVM::getVectorType(LLVM::getVectorElementType(containerType),
2107                                    mask.size(),
2108                                    LLVM::isScalableVectorType(containerType));
2109   build(b, result, vType, v1, v2, mask);
2110   result.addAttributes(attrs);
2111 }
2112 
print(OpAsmPrinter & p)2113 void ShuffleVectorOp::print(OpAsmPrinter &p) {
2114   p << ' ' << getV1() << ", " << getV2() << " " << getMask();
2115   p.printOptionalAttrDict((*this)->getAttrs(), {"mask"});
2116   p << " : " << getV1().getType() << ", " << getV2().getType();
2117 }
2118 
2119 // <operation> ::= `llvm.shufflevector` ssa-use `, ` ssa-use
2120 //                 `[` integer-literal (`,` integer-literal)* `]`
2121 //                 attribute-dict? `:` type
parse(OpAsmParser & parser,OperationState & result)2122 ParseResult ShuffleVectorOp::parse(OpAsmParser &parser,
2123                                    OperationState &result) {
2124   SMLoc loc;
2125   OpAsmParser::UnresolvedOperand v1, v2;
2126   ArrayAttr maskAttr;
2127   Type typeV1, typeV2;
2128   if (parser.getCurrentLocation(&loc) || parser.parseOperand(v1) ||
2129       parser.parseComma() || parser.parseOperand(v2) ||
2130       parser.parseAttribute(maskAttr, "mask", result.attributes) ||
2131       parser.parseOptionalAttrDict(result.attributes) ||
2132       parser.parseColonType(typeV1) || parser.parseComma() ||
2133       parser.parseType(typeV2) ||
2134       parser.resolveOperand(v1, typeV1, result.operands) ||
2135       parser.resolveOperand(v2, typeV2, result.operands))
2136     return failure();
2137   if (!LLVM::isCompatibleVectorType(typeV1))
2138     return parser.emitError(
2139         loc, "expected LLVM IR dialect vector type for operand #1");
2140   auto vType =
2141       LLVM::getVectorType(LLVM::getVectorElementType(typeV1), maskAttr.size(),
2142                           LLVM::isScalableVectorType(typeV1));
2143   result.addTypes(vType);
2144   return success();
2145 }
2146 
verify()2147 LogicalResult ShuffleVectorOp::verify() {
2148   Type type1 = getV1().getType();
2149   Type type2 = getV2().getType();
2150   if (LLVM::getVectorElementType(type1) != LLVM::getVectorElementType(type2))
2151     return emitOpError("expected matching LLVM IR Dialect element types");
2152   if (LLVM::isScalableVectorType(type1))
2153     if (llvm::any_of(getMask(), [](Attribute attr) {
2154           return attr.cast<IntegerAttr>().getInt() != 0;
2155         }))
2156       return emitOpError("expected a splat operation for scalable vectors");
2157   return success();
2158 }
2159 
2160 //===----------------------------------------------------------------------===//
2161 // Implementations for LLVM::LLVMFuncOp.
2162 //===----------------------------------------------------------------------===//
2163 
2164 // Add the entry block to the function.
addEntryBlock()2165 Block *LLVMFuncOp::addEntryBlock() {
2166   assert(empty() && "function already has an entry block");
2167 
2168   auto *entry = new Block;
2169   push_back(entry);
2170 
2171   // FIXME: Allow passing in proper locations for the entry arguments.
2172   LLVMFunctionType type = getFunctionType();
2173   for (unsigned i = 0, e = type.getNumParams(); i < e; ++i)
2174     entry->addArgument(type.getParamType(i), getLoc());
2175   return entry;
2176 }
2177 
build(OpBuilder & builder,OperationState & result,StringRef name,Type type,LLVM::Linkage linkage,bool dsoLocal,CConv cconv,ArrayRef<NamedAttribute> attrs,ArrayRef<DictionaryAttr> argAttrs)2178 void LLVMFuncOp::build(OpBuilder &builder, OperationState &result,
2179                        StringRef name, Type type, LLVM::Linkage linkage,
2180                        bool dsoLocal, CConv cconv,
2181                        ArrayRef<NamedAttribute> attrs,
2182                        ArrayRef<DictionaryAttr> argAttrs) {
2183   result.addRegion();
2184   result.addAttribute(SymbolTable::getSymbolAttrName(),
2185                       builder.getStringAttr(name));
2186   result.addAttribute(getFunctionTypeAttrName(result.name),
2187                       TypeAttr::get(type));
2188   result.addAttribute(getLinkageAttrName(result.name),
2189                       LinkageAttr::get(builder.getContext(), linkage));
2190   result.addAttribute(getCConvAttrName(result.name),
2191                       CConvAttr::get(builder.getContext(), cconv));
2192   result.attributes.append(attrs.begin(), attrs.end());
2193   if (dsoLocal)
2194     result.addAttribute("dso_local", builder.getUnitAttr());
2195   if (argAttrs.empty())
2196     return;
2197 
2198   assert(type.cast<LLVMFunctionType>().getNumParams() == argAttrs.size() &&
2199          "expected as many argument attribute lists as arguments");
2200   function_interface_impl::addArgAndResultAttrs(builder, result, argAttrs,
2201                                                 /*resultAttrs=*/llvm::None);
2202 }
2203 
2204 // Builds an LLVM function type from the given lists of input and output types.
2205 // Returns a null type if any of the types provided are non-LLVM types, or if
2206 // there is more than one output type.
2207 static Type
buildLLVMFunctionType(OpAsmParser & parser,SMLoc loc,ArrayRef<Type> inputs,ArrayRef<Type> outputs,function_interface_impl::VariadicFlag variadicFlag)2208 buildLLVMFunctionType(OpAsmParser &parser, SMLoc loc, ArrayRef<Type> inputs,
2209                       ArrayRef<Type> outputs,
2210                       function_interface_impl::VariadicFlag variadicFlag) {
2211   Builder &b = parser.getBuilder();
2212   if (outputs.size() > 1) {
2213     parser.emitError(loc, "failed to construct function type: expected zero or "
2214                           "one function result");
2215     return {};
2216   }
2217 
2218   // Convert inputs to LLVM types, exit early on error.
2219   SmallVector<Type, 4> llvmInputs;
2220   for (auto t : inputs) {
2221     if (!isCompatibleType(t)) {
2222       parser.emitError(loc, "failed to construct function type: expected LLVM "
2223                             "type for function arguments");
2224       return {};
2225     }
2226     llvmInputs.push_back(t);
2227   }
2228 
2229   // No output is denoted as "void" in LLVM type system.
2230   Type llvmOutput =
2231       outputs.empty() ? LLVMVoidType::get(b.getContext()) : outputs.front();
2232   if (!isCompatibleType(llvmOutput)) {
2233     parser.emitError(loc, "failed to construct function type: expected LLVM "
2234                           "type for function results")
2235         << llvmOutput;
2236     return {};
2237   }
2238   return LLVMFunctionType::get(llvmOutput, llvmInputs,
2239                                variadicFlag.isVariadic());
2240 }
2241 
2242 // Parses an LLVM function.
2243 //
2244 // operation ::= `llvm.func` linkage? cconv? function-signature
2245 // function-attributes?
2246 //               function-body
2247 //
parse(OpAsmParser & parser,OperationState & result)2248 ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) {
2249   // Default to external linkage if no keyword is provided.
2250   result.addAttribute(
2251       getLinkageAttrName(result.name),
2252       LinkageAttr::get(parser.getContext(),
2253                        parseOptionalLLVMKeyword<Linkage>(
2254                            parser, result, LLVM::Linkage::External)));
2255 
2256   // Default to C Calling Convention if no keyword is provided.
2257   result.addAttribute(
2258       getCConvAttrName(result.name),
2259       CConvAttr::get(parser.getContext(), parseOptionalLLVMKeyword<CConv>(
2260                                               parser, result, LLVM::CConv::C)));
2261 
2262   StringAttr nameAttr;
2263   SmallVector<OpAsmParser::Argument> entryArgs;
2264   SmallVector<DictionaryAttr> resultAttrs;
2265   SmallVector<Type> resultTypes;
2266   bool isVariadic;
2267 
2268   auto signatureLocation = parser.getCurrentLocation();
2269   if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
2270                              result.attributes) ||
2271       function_interface_impl::parseFunctionSignature(
2272           parser, /*allowVariadic=*/true, entryArgs, isVariadic, resultTypes,
2273           resultAttrs))
2274     return failure();
2275 
2276   SmallVector<Type> argTypes;
2277   for (auto &arg : entryArgs)
2278     argTypes.push_back(arg.type);
2279   auto type =
2280       buildLLVMFunctionType(parser, signatureLocation, argTypes, resultTypes,
2281                             function_interface_impl::VariadicFlag(isVariadic));
2282   if (!type)
2283     return failure();
2284   result.addAttribute(FunctionOpInterface::getTypeAttrName(),
2285                       TypeAttr::get(type));
2286 
2287   if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
2288     return failure();
2289   function_interface_impl::addArgAndResultAttrs(parser.getBuilder(), result,
2290                                                 entryArgs, resultAttrs);
2291 
2292   auto *body = result.addRegion();
2293   OptionalParseResult parseResult =
2294       parser.parseOptionalRegion(*body, entryArgs);
2295   return failure(parseResult.hasValue() && failed(*parseResult));
2296 }
2297 
2298 // Print the LLVMFuncOp. Collects argument and result types and passes them to
2299 // helper functions. Drops "void" result since it cannot be parsed back. Skips
2300 // the external linkage since it is the default value.
print(OpAsmPrinter & p)2301 void LLVMFuncOp::print(OpAsmPrinter &p) {
2302   p << ' ';
2303   if (getLinkage() != LLVM::Linkage::External)
2304     p << stringifyLinkage(getLinkage()) << ' ';
2305   if (getCConv() != LLVM::CConv::C)
2306     p << stringifyCConv(getCConv()) << ' ';
2307 
2308   p.printSymbolName(getName());
2309 
2310   LLVMFunctionType fnType = getFunctionType();
2311   SmallVector<Type, 8> argTypes;
2312   SmallVector<Type, 1> resTypes;
2313   argTypes.reserve(fnType.getNumParams());
2314   for (unsigned i = 0, e = fnType.getNumParams(); i < e; ++i)
2315     argTypes.push_back(fnType.getParamType(i));
2316 
2317   Type returnType = fnType.getReturnType();
2318   if (!returnType.isa<LLVMVoidType>())
2319     resTypes.push_back(returnType);
2320 
2321   function_interface_impl::printFunctionSignature(p, *this, argTypes,
2322                                                   isVarArg(), resTypes);
2323   function_interface_impl::printFunctionAttributes(
2324       p, *this, argTypes.size(), resTypes.size(),
2325       {getLinkageAttrName(), getCConvAttrName()});
2326 
2327   // Print the body if this is not an external function.
2328   Region &body = getBody();
2329   if (!body.empty()) {
2330     p << ' ';
2331     p.printRegion(body, /*printEntryBlockArgs=*/false,
2332                   /*printBlockTerminators=*/true);
2333   }
2334 }
2335 
2336 // Verifies LLVM- and implementation-specific properties of the LLVM func Op:
2337 // - functions don't have 'common' linkage
2338 // - external functions have 'external' or 'extern_weak' linkage;
2339 // - vararg is (currently) only supported for external functions;
verify()2340 LogicalResult LLVMFuncOp::verify() {
2341   if (getLinkage() == LLVM::Linkage::Common)
2342     return emitOpError() << "functions cannot have '"
2343                          << stringifyLinkage(LLVM::Linkage::Common)
2344                          << "' linkage";
2345 
2346   // Check to see if this function has a void return with a result attribute to
2347   // it. It isn't clear what semantics we would assign to that.
2348   if (getFunctionType().getReturnType().isa<LLVMVoidType>() &&
2349       !getResultAttrs(0).empty()) {
2350     return emitOpError()
2351            << "cannot attach result attributes to functions with a void return";
2352   }
2353 
2354   if (isExternal()) {
2355     if (getLinkage() != LLVM::Linkage::External &&
2356         getLinkage() != LLVM::Linkage::ExternWeak)
2357       return emitOpError() << "external functions must have '"
2358                            << stringifyLinkage(LLVM::Linkage::External)
2359                            << "' or '"
2360                            << stringifyLinkage(LLVM::Linkage::ExternWeak)
2361                            << "' linkage";
2362     return success();
2363   }
2364 
2365   return success();
2366 }
2367 
2368 /// Verifies LLVM- and implementation-specific properties of the LLVM func Op:
2369 /// - entry block arguments are of LLVM types.
verifyRegions()2370 LogicalResult LLVMFuncOp::verifyRegions() {
2371   if (isExternal())
2372     return success();
2373 
2374   unsigned numArguments = getFunctionType().getNumParams();
2375   Block &entryBlock = front();
2376   for (unsigned i = 0; i < numArguments; ++i) {
2377     Type argType = entryBlock.getArgument(i).getType();
2378     if (!isCompatibleType(argType))
2379       return emitOpError("entry block argument #")
2380              << i << " is not of LLVM type";
2381   }
2382 
2383   return success();
2384 }
2385 
2386 //===----------------------------------------------------------------------===//
2387 // Verification for LLVM::ConstantOp.
2388 //===----------------------------------------------------------------------===//
2389 
verify()2390 LogicalResult LLVM::ConstantOp::verify() {
2391   if (StringAttr sAttr = getValue().dyn_cast<StringAttr>()) {
2392     auto arrayType = getType().dyn_cast<LLVMArrayType>();
2393     if (!arrayType || arrayType.getNumElements() != sAttr.getValue().size() ||
2394         !arrayType.getElementType().isInteger(8)) {
2395       return emitOpError() << "expected array type of "
2396                            << sAttr.getValue().size()
2397                            << " i8 elements for the string constant";
2398     }
2399     return success();
2400   }
2401   if (auto structType = getType().dyn_cast<LLVMStructType>()) {
2402     if (structType.getBody().size() != 2 ||
2403         structType.getBody()[0] != structType.getBody()[1]) {
2404       return emitError() << "expected struct type with two elements of the "
2405                             "same type, the type of a complex constant";
2406     }
2407 
2408     auto arrayAttr = getValue().dyn_cast<ArrayAttr>();
2409     if (!arrayAttr || arrayAttr.size() != 2 ||
2410         arrayAttr[0].getType() != arrayAttr[1].getType()) {
2411       return emitOpError() << "expected array attribute with two elements, "
2412                               "representing a complex constant";
2413     }
2414 
2415     Type elementType = structType.getBody()[0];
2416     if (!elementType
2417              .isa<IntegerType, Float16Type, Float32Type, Float64Type>()) {
2418       return emitError()
2419              << "expected struct element types to be floating point type or "
2420                 "integer type";
2421     }
2422     return success();
2423   }
2424   if (!getValue().isa<IntegerAttr, ArrayAttr, FloatAttr, ElementsAttr>())
2425     return emitOpError()
2426            << "only supports integer, float, string or elements attributes";
2427   return success();
2428 }
2429 
2430 // Constant op constant-folds to its value.
fold(ArrayRef<Attribute>)2431 OpFoldResult LLVM::ConstantOp::fold(ArrayRef<Attribute>) { return getValue(); }
2432 
2433 //===----------------------------------------------------------------------===//
2434 // Utility functions for parsing atomic ops
2435 //===----------------------------------------------------------------------===//
2436 
2437 // Helper function to parse a keyword into the specified attribute named by
2438 // `attrName`. The keyword must match one of the string values defined by the
2439 // AtomicBinOp enum. The resulting I64 attribute is added to the `result`
2440 // state.
parseAtomicBinOp(OpAsmParser & parser,OperationState & result,StringRef attrName)2441 static ParseResult parseAtomicBinOp(OpAsmParser &parser, OperationState &result,
2442                                     StringRef attrName) {
2443   SMLoc loc;
2444   StringRef keyword;
2445   if (parser.getCurrentLocation(&loc) || parser.parseKeyword(&keyword))
2446     return failure();
2447 
2448   // Replace the keyword `keyword` with an integer attribute.
2449   auto kind = symbolizeAtomicBinOp(keyword);
2450   if (!kind) {
2451     return parser.emitError(loc)
2452            << "'" << keyword << "' is an incorrect value of the '" << attrName
2453            << "' attribute";
2454   }
2455 
2456   auto value = static_cast<int64_t>(*kind);
2457   auto attr = parser.getBuilder().getI64IntegerAttr(value);
2458   result.addAttribute(attrName, attr);
2459 
2460   return success();
2461 }
2462 
2463 // Helper function to parse a keyword into the specified attribute named by
2464 // `attrName`. The keyword must match one of the string values defined by the
2465 // AtomicOrdering enum. The resulting I64 attribute is added to the `result`
2466 // state.
parseAtomicOrdering(OpAsmParser & parser,OperationState & result,StringRef attrName)2467 static ParseResult parseAtomicOrdering(OpAsmParser &parser,
2468                                        OperationState &result,
2469                                        StringRef attrName) {
2470   SMLoc loc;
2471   StringRef ordering;
2472   if (parser.getCurrentLocation(&loc) || parser.parseKeyword(&ordering))
2473     return failure();
2474 
2475   // Replace the keyword `ordering` with an integer attribute.
2476   auto kind = symbolizeAtomicOrdering(ordering);
2477   if (!kind) {
2478     return parser.emitError(loc)
2479            << "'" << ordering << "' is an incorrect value of the '" << attrName
2480            << "' attribute";
2481   }
2482 
2483   auto value = static_cast<int64_t>(*kind);
2484   auto attr = parser.getBuilder().getI64IntegerAttr(value);
2485   result.addAttribute(attrName, attr);
2486 
2487   return success();
2488 }
2489 
2490 //===----------------------------------------------------------------------===//
2491 // Printer, parser and verifier for LLVM::AtomicRMWOp.
2492 //===----------------------------------------------------------------------===//
2493 
print(OpAsmPrinter & p)2494 void AtomicRMWOp::print(OpAsmPrinter &p) {
2495   p << ' ' << stringifyAtomicBinOp(getBinOp()) << ' ' << getPtr() << ", "
2496     << getVal() << ' ' << stringifyAtomicOrdering(getOrdering()) << ' ';
2497   p.printOptionalAttrDict((*this)->getAttrs(), {"bin_op", "ordering"});
2498   p << " : " << getRes().getType();
2499 }
2500 
2501 // <operation> ::= `llvm.atomicrmw` keyword ssa-use `,` ssa-use keyword
2502 //                 attribute-dict? `:` type
parse(OpAsmParser & parser,OperationState & result)2503 ParseResult AtomicRMWOp::parse(OpAsmParser &parser, OperationState &result) {
2504   Type type;
2505   OpAsmParser::UnresolvedOperand ptr, val;
2506   if (parseAtomicBinOp(parser, result, "bin_op") || parser.parseOperand(ptr) ||
2507       parser.parseComma() || parser.parseOperand(val) ||
2508       parseAtomicOrdering(parser, result, "ordering") ||
2509       parser.parseOptionalAttrDict(result.attributes) ||
2510       parser.parseColonType(type) ||
2511       parser.resolveOperand(ptr, LLVM::LLVMPointerType::get(type),
2512                             result.operands) ||
2513       parser.resolveOperand(val, type, result.operands))
2514     return failure();
2515 
2516   result.addTypes(type);
2517   return success();
2518 }
2519 
verify()2520 LogicalResult AtomicRMWOp::verify() {
2521   auto ptrType = getPtr().getType().cast<LLVM::LLVMPointerType>();
2522   auto valType = getVal().getType();
2523   if (valType != ptrType.getElementType())
2524     return emitOpError("expected LLVM IR element type for operand #0 to "
2525                        "match type for operand #1");
2526   auto resType = getRes().getType();
2527   if (resType != valType)
2528     return emitOpError(
2529         "expected LLVM IR result type to match type for operand #1");
2530   if (getBinOp() == AtomicBinOp::fadd || getBinOp() == AtomicBinOp::fsub) {
2531     if (!mlir::LLVM::isCompatibleFloatingPointType(valType))
2532       return emitOpError("expected LLVM IR floating point type");
2533   } else if (getBinOp() == AtomicBinOp::xchg) {
2534     auto intType = valType.dyn_cast<IntegerType>();
2535     unsigned intBitWidth = intType ? intType.getWidth() : 0;
2536     if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 &&
2537         intBitWidth != 64 && !valType.isa<BFloat16Type>() &&
2538         !valType.isa<Float16Type>() && !valType.isa<Float32Type>() &&
2539         !valType.isa<Float64Type>())
2540       return emitOpError("unexpected LLVM IR type for 'xchg' bin_op");
2541   } else {
2542     auto intType = valType.dyn_cast<IntegerType>();
2543     unsigned intBitWidth = intType ? intType.getWidth() : 0;
2544     if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 &&
2545         intBitWidth != 64)
2546       return emitOpError("expected LLVM IR integer type");
2547   }
2548 
2549   if (static_cast<unsigned>(getOrdering()) <
2550       static_cast<unsigned>(AtomicOrdering::monotonic))
2551     return emitOpError() << "expected at least '"
2552                          << stringifyAtomicOrdering(AtomicOrdering::monotonic)
2553                          << "' ordering";
2554 
2555   return success();
2556 }
2557 
2558 //===----------------------------------------------------------------------===//
2559 // Printer, parser and verifier for LLVM::AtomicCmpXchgOp.
2560 //===----------------------------------------------------------------------===//
2561 
print(OpAsmPrinter & p)2562 void AtomicCmpXchgOp::print(OpAsmPrinter &p) {
2563   p << ' ' << getPtr() << ", " << getCmp() << ", " << getVal() << ' '
2564     << stringifyAtomicOrdering(getSuccessOrdering()) << ' '
2565     << stringifyAtomicOrdering(getFailureOrdering());
2566   p.printOptionalAttrDict((*this)->getAttrs(),
2567                           {"success_ordering", "failure_ordering"});
2568   p << " : " << getVal().getType();
2569 }
2570 
2571 // <operation> ::= `llvm.cmpxchg` ssa-use `,` ssa-use `,` ssa-use
2572 //                 keyword keyword attribute-dict? `:` type
parse(OpAsmParser & parser,OperationState & result)2573 ParseResult AtomicCmpXchgOp::parse(OpAsmParser &parser,
2574                                    OperationState &result) {
2575   auto &builder = parser.getBuilder();
2576   Type type;
2577   OpAsmParser::UnresolvedOperand ptr, cmp, val;
2578   if (parser.parseOperand(ptr) || parser.parseComma() ||
2579       parser.parseOperand(cmp) || parser.parseComma() ||
2580       parser.parseOperand(val) ||
2581       parseAtomicOrdering(parser, result, "success_ordering") ||
2582       parseAtomicOrdering(parser, result, "failure_ordering") ||
2583       parser.parseOptionalAttrDict(result.attributes) ||
2584       parser.parseColonType(type) ||
2585       parser.resolveOperand(ptr, LLVM::LLVMPointerType::get(type),
2586                             result.operands) ||
2587       parser.resolveOperand(cmp, type, result.operands) ||
2588       parser.resolveOperand(val, type, result.operands))
2589     return failure();
2590 
2591   auto boolType = IntegerType::get(builder.getContext(), 1);
2592   auto resultType =
2593       LLVMStructType::getLiteral(builder.getContext(), {type, boolType});
2594   result.addTypes(resultType);
2595 
2596   return success();
2597 }
2598 
verify()2599 LogicalResult AtomicCmpXchgOp::verify() {
2600   auto ptrType = getPtr().getType().cast<LLVM::LLVMPointerType>();
2601   if (!ptrType)
2602     return emitOpError("expected LLVM IR pointer type for operand #0");
2603   auto cmpType = getCmp().getType();
2604   auto valType = getVal().getType();
2605   if (cmpType != ptrType.getElementType() || cmpType != valType)
2606     return emitOpError("expected LLVM IR element type for operand #0 to "
2607                        "match type for all other operands");
2608   auto intType = valType.dyn_cast<IntegerType>();
2609   unsigned intBitWidth = intType ? intType.getWidth() : 0;
2610   if (!valType.isa<LLVMPointerType>() && intBitWidth != 8 &&
2611       intBitWidth != 16 && intBitWidth != 32 && intBitWidth != 64 &&
2612       !valType.isa<BFloat16Type>() && !valType.isa<Float16Type>() &&
2613       !valType.isa<Float32Type>() && !valType.isa<Float64Type>())
2614     return emitOpError("unexpected LLVM IR type");
2615   if (getSuccessOrdering() < AtomicOrdering::monotonic ||
2616       getFailureOrdering() < AtomicOrdering::monotonic)
2617     return emitOpError("ordering must be at least 'monotonic'");
2618   if (getFailureOrdering() == AtomicOrdering::release ||
2619       getFailureOrdering() == AtomicOrdering::acq_rel)
2620     return emitOpError("failure ordering cannot be 'release' or 'acq_rel'");
2621   return success();
2622 }
2623 
2624 //===----------------------------------------------------------------------===//
2625 // Printer, parser and verifier for LLVM::FenceOp.
2626 //===----------------------------------------------------------------------===//
2627 
2628 // <operation> ::= `llvm.fence` (`syncscope(`strAttr`)`)? keyword
2629 // attribute-dict?
parse(OpAsmParser & parser,OperationState & result)2630 ParseResult FenceOp::parse(OpAsmParser &parser, OperationState &result) {
2631   StringAttr sScope;
2632   StringRef syncscopeKeyword = "syncscope";
2633   if (!failed(parser.parseOptionalKeyword(syncscopeKeyword))) {
2634     if (parser.parseLParen() ||
2635         parser.parseAttribute(sScope, syncscopeKeyword, result.attributes) ||
2636         parser.parseRParen())
2637       return failure();
2638   } else {
2639     result.addAttribute(syncscopeKeyword,
2640                         parser.getBuilder().getStringAttr(""));
2641   }
2642   if (parseAtomicOrdering(parser, result, "ordering") ||
2643       parser.parseOptionalAttrDict(result.attributes))
2644     return failure();
2645   return success();
2646 }
2647 
print(OpAsmPrinter & p)2648 void FenceOp::print(OpAsmPrinter &p) {
2649   StringRef syncscopeKeyword = "syncscope";
2650   p << ' ';
2651   if (!(*this)->getAttr(syncscopeKeyword).cast<StringAttr>().getValue().empty())
2652     p << "syncscope(" << (*this)->getAttr(syncscopeKeyword) << ") ";
2653   p << stringifyAtomicOrdering(getOrdering());
2654 }
2655 
verify()2656 LogicalResult FenceOp::verify() {
2657   if (getOrdering() == AtomicOrdering::not_atomic ||
2658       getOrdering() == AtomicOrdering::unordered ||
2659       getOrdering() == AtomicOrdering::monotonic)
2660     return emitOpError("can be given only acquire, release, acq_rel, "
2661                        "and seq_cst orderings");
2662   return success();
2663 }
2664 
2665 //===----------------------------------------------------------------------===//
2666 // Folder for LLVM::BitcastOp
2667 //===----------------------------------------------------------------------===//
2668 
fold(ArrayRef<Attribute> operands)2669 OpFoldResult LLVM::BitcastOp::fold(ArrayRef<Attribute> operands) {
2670   // bitcast(x : T0, T0) -> x
2671   if (getArg().getType() == getType())
2672     return getArg();
2673   // bitcast(bitcast(x : T0, T1), T0) -> x
2674   if (auto prev = getArg().getDefiningOp<BitcastOp>())
2675     if (prev.getArg().getType() == getType())
2676       return prev.getArg();
2677   return {};
2678 }
2679 
2680 //===----------------------------------------------------------------------===//
2681 // Folder for LLVM::AddrSpaceCastOp
2682 //===----------------------------------------------------------------------===//
2683 
fold(ArrayRef<Attribute> operands)2684 OpFoldResult LLVM::AddrSpaceCastOp::fold(ArrayRef<Attribute> operands) {
2685   // addrcast(x : T0, T0) -> x
2686   if (getArg().getType() == getType())
2687     return getArg();
2688   // addrcast(addrcast(x : T0, T1), T0) -> x
2689   if (auto prev = getArg().getDefiningOp<AddrSpaceCastOp>())
2690     if (prev.getArg().getType() == getType())
2691       return prev.getArg();
2692   return {};
2693 }
2694 
2695 //===----------------------------------------------------------------------===//
2696 // Folder for LLVM::GEPOp
2697 //===----------------------------------------------------------------------===//
2698 
fold(ArrayRef<Attribute> operands)2699 OpFoldResult LLVM::GEPOp::fold(ArrayRef<Attribute> operands) {
2700   // gep %x:T, 0 -> %x
2701   if (getBase().getType() == getType() && getIndices().size() == 1 &&
2702       matchPattern(getIndices()[0], m_Zero()))
2703     return getBase();
2704   return {};
2705 }
2706 
2707 //===----------------------------------------------------------------------===//
2708 // LLVMDialect initialization, type parsing, and registration.
2709 //===----------------------------------------------------------------------===//
2710 
/// Registers the attributes, types and operations of the dialect, and allows
/// operations that are not registered with it.
void LLVMDialect::initialize() {
  addAttributes<FMFAttr, LinkageAttr, CConvAttr, LoopOptionsAttr>();

  // clang-format off
  addTypes<LLVMVoidType,
           LLVMPPCFP128Type,
           LLVMX86MMXType,
           LLVMTokenType,
           LLVMLabelType,
           LLVMMetadataType,
           LLVMFunctionType,
           LLVMPointerType,
           LLVMFixedVectorType,
           LLVMScalableVectorType,
           LLVMArrayType,
           LLVMStructType>();
  // clang-format on
  // The template argument list is spliced together from two generated files,
  // each included with GET_OP_LIST defined: the regular ops followed by the
  // intrinsic ops.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc"
      ,
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/LLVMIntrinsicOps.cpp.inc"
      >();

  // Support unknown operations because not all LLVM operations are registered.
  allowUnknownOperations();
}
2739 
2740 #define GET_OP_CLASSES
2741 #include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc"
2742 
2743 /// Parse a type registered to this dialect.
parseType(DialectAsmParser & parser) const2744 Type LLVMDialect::parseType(DialectAsmParser &parser) const {
2745   return detail::parseType(parser);
2746 }
2747 
2748 /// Print a type registered to this dialect.
printType(Type type,DialectAsmPrinter & os) const2749 void LLVMDialect::printType(Type type, DialectAsmPrinter &os) const {
2750   return detail::printType(type, os);
2751 }
2752 
verifyDataLayoutString(StringRef descr,llvm::function_ref<void (const Twine &)> reportError)2753 LogicalResult LLVMDialect::verifyDataLayoutString(
2754     StringRef descr, llvm::function_ref<void(const Twine &)> reportError) {
2755   llvm::Expected<llvm::DataLayout> maybeDataLayout =
2756       llvm::DataLayout::parse(descr);
2757   if (maybeDataLayout)
2758     return success();
2759 
2760   std::string message;
2761   llvm::raw_string_ostream messageStream(message);
2762   llvm::logAllUnhandledErrors(maybeDataLayout.takeError(), messageStream);
2763   reportError("invalid data layout descriptor: " + messageStream.str());
2764   return failure();
2765 }
2766 
2767 /// Verify LLVM dialect attributes.
verifyOperationAttribute(Operation * op,NamedAttribute attr)2768 LogicalResult LLVMDialect::verifyOperationAttribute(Operation *op,
2769                                                     NamedAttribute attr) {
2770   // If the `llvm.loop` attribute is present, enforce the following structure,
2771   // which the module translation can assume.
2772   if (attr.getName() == LLVMDialect::getLoopAttrName()) {
2773     auto loopAttr = attr.getValue().dyn_cast<DictionaryAttr>();
2774     if (!loopAttr)
2775       return op->emitOpError() << "expected '" << LLVMDialect::getLoopAttrName()
2776                                << "' to be a dictionary attribute";
2777     Optional<NamedAttribute> parallelAccessGroup =
2778         loopAttr.getNamed(LLVMDialect::getParallelAccessAttrName());
2779     if (parallelAccessGroup) {
2780       auto accessGroups = parallelAccessGroup->getValue().dyn_cast<ArrayAttr>();
2781       if (!accessGroups)
2782         return op->emitOpError()
2783                << "expected '" << LLVMDialect::getParallelAccessAttrName()
2784                << "' to be an array attribute";
2785       for (Attribute attr : accessGroups) {
2786         auto accessGroupRef = attr.dyn_cast<SymbolRefAttr>();
2787         if (!accessGroupRef)
2788           return op->emitOpError()
2789                  << "expected '" << attr << "' to be a symbol reference";
2790         StringAttr metadataName = accessGroupRef.getRootReference();
2791         auto metadataOp =
2792             SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>(
2793                 op->getParentOp(), metadataName);
2794         if (!metadataOp)
2795           return op->emitOpError()
2796                  << "expected '" << attr << "' to reference a metadata op";
2797         StringAttr accessGroupName = accessGroupRef.getLeafReference();
2798         Operation *accessGroupOp =
2799             SymbolTable::lookupNearestSymbolFrom(metadataOp, accessGroupName);
2800         if (!accessGroupOp)
2801           return op->emitOpError()
2802                  << "expected '" << attr << "' to reference an access_group op";
2803       }
2804     }
2805 
2806     Optional<NamedAttribute> loopOptions =
2807         loopAttr.getNamed(LLVMDialect::getLoopOptionsAttrName());
2808     if (loopOptions && !loopOptions->getValue().isa<LoopOptionsAttr>())
2809       return op->emitOpError()
2810              << "expected '" << LLVMDialect::getLoopOptionsAttrName()
2811              << "' to be a `loopopts` attribute";
2812   }
2813 
2814   if (attr.getName() == LLVMDialect::getStructAttrsAttrName()) {
2815     return op->emitOpError()
2816            << "'" << LLVM::LLVMDialect::getStructAttrsAttrName()
2817            << "' is permitted only in argument or result attributes";
2818   }
2819 
2820   // If the data layout attribute is present, it must use the LLVM data layout
2821   // syntax. Try parsing it and report errors in case of failure. Users of this
2822   // attribute may assume it is well-formed and can pass it to the (asserting)
2823   // llvm::DataLayout constructor.
2824   if (attr.getName() != LLVM::LLVMDialect::getDataLayoutAttrName())
2825     return success();
2826   if (auto stringAttr = attr.getValue().dyn_cast<StringAttr>())
2827     return verifyDataLayoutString(
2828         stringAttr.getValue(),
2829         [op](const Twine &message) { op->emitOpError() << message.str(); });
2830 
2831   return op->emitOpError() << "expected '"
2832                            << LLVM::LLVMDialect::getDataLayoutAttrName()
2833                            << "' to be a string attributes";
2834 }
2835 
verifyStructAttr(Operation * op,Attribute attr,Type annotatedType)2836 LogicalResult LLVMDialect::verifyStructAttr(Operation *op, Attribute attr,
2837                                             Type annotatedType) {
2838   auto structType = annotatedType.dyn_cast<LLVMStructType>();
2839   if (!structType) {
2840     const auto emitIncorrectAnnotatedType = [&op]() {
2841       return op->emitError()
2842              << "expected '" << LLVMDialect::getStructAttrsAttrName()
2843              << "' to annotate '!llvm.struct' or '!llvm.ptr<struct<...>>'";
2844     };
2845     const auto ptrType = annotatedType.dyn_cast<LLVMPointerType>();
2846     if (!ptrType)
2847       return emitIncorrectAnnotatedType();
2848     structType = ptrType.getElementType().dyn_cast<LLVMStructType>();
2849     if (!structType)
2850       return emitIncorrectAnnotatedType();
2851   }
2852 
2853   const auto arrAttrs = attr.dyn_cast<ArrayAttr>();
2854   if (!arrAttrs)
2855     return op->emitError() << "expected '"
2856                            << LLVMDialect::getStructAttrsAttrName()
2857                            << "' to be an array attribute";
2858 
2859   if (structType.getBody().size() != arrAttrs.size())
2860     return op->emitError()
2861            << "size of '" << LLVMDialect::getStructAttrsAttrName()
2862            << "' must match the size of the annotated '!llvm.struct'";
2863   return success();
2864 }
2865 
verifyFuncOpInterfaceStructAttr(Operation * op,Attribute attr,const std::function<Type (FunctionOpInterface)> & getAnnotatedType)2866 static LogicalResult verifyFuncOpInterfaceStructAttr(
2867     Operation *op, Attribute attr,
2868     const std::function<Type(FunctionOpInterface)> &getAnnotatedType) {
2869   if (auto funcOp = dyn_cast<FunctionOpInterface>(op))
2870     return LLVMDialect::verifyStructAttr(op, attr, getAnnotatedType(funcOp));
2871   return op->emitError() << "expected '"
2872                          << LLVMDialect::getStructAttrsAttrName()
2873                          << "' to be used on function-like operations";
2874 }
2875 
2876 /// Verify LLVMIR function argument attributes.
verifyRegionArgAttribute(Operation * op,unsigned regionIdx,unsigned argIdx,NamedAttribute argAttr)2877 LogicalResult LLVMDialect::verifyRegionArgAttribute(Operation *op,
2878                                                     unsigned regionIdx,
2879                                                     unsigned argIdx,
2880                                                     NamedAttribute argAttr) {
2881   // Check that llvm.noalias is a unit attribute.
2882   if (argAttr.getName() == LLVMDialect::getNoAliasAttrName() &&
2883       !argAttr.getValue().isa<UnitAttr>())
2884     return op->emitError()
2885            << "expected llvm.noalias argument attribute to be a unit attribute";
2886   // Check that llvm.align is an integer attribute.
2887   if (argAttr.getName() == LLVMDialect::getAlignAttrName() &&
2888       !argAttr.getValue().isa<IntegerAttr>())
2889     return op->emitError()
2890            << "llvm.align argument attribute of non integer type";
2891   if (argAttr.getName() == LLVMDialect::getStructAttrsAttrName()) {
2892     return verifyFuncOpInterfaceStructAttr(
2893         op, argAttr.getValue(), [argIdx](FunctionOpInterface funcOp) {
2894           return funcOp.getArgumentTypes()[argIdx];
2895         });
2896   }
2897   return success();
2898 }
2899 
verifyRegionResultAttribute(Operation * op,unsigned regionIdx,unsigned resIdx,NamedAttribute resAttr)2900 LogicalResult LLVMDialect::verifyRegionResultAttribute(Operation *op,
2901                                                        unsigned regionIdx,
2902                                                        unsigned resIdx,
2903                                                        NamedAttribute resAttr) {
2904   if (resAttr.getName() == LLVMDialect::getStructAttrsAttrName()) {
2905     return verifyFuncOpInterfaceStructAttr(
2906         op, resAttr.getValue(), [resIdx](FunctionOpInterface funcOp) {
2907           return funcOp.getResultTypes()[resIdx];
2908         });
2909   }
2910   return success();
2911 }
2912 
2913 //===----------------------------------------------------------------------===//
2914 // Utility functions.
2915 //===----------------------------------------------------------------------===//
2916 
createGlobalString(Location loc,OpBuilder & builder,StringRef name,StringRef value,LLVM::Linkage linkage)2917 Value mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder,
2918                                      StringRef name, StringRef value,
2919                                      LLVM::Linkage linkage) {
2920   assert(builder.getInsertionBlock() &&
2921          builder.getInsertionBlock()->getParentOp() &&
2922          "expected builder to point to a block constrained in an op");
2923   auto module =
2924       builder.getInsertionBlock()->getParentOp()->getParentOfType<ModuleOp>();
2925   assert(module && "builder points to an op outside of a module");
2926 
2927   // Create the global at the entry of the module.
2928   OpBuilder moduleBuilder(module.getBodyRegion(), builder.getListener());
2929   MLIRContext *ctx = builder.getContext();
2930   auto type = LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), value.size());
2931   auto global = moduleBuilder.create<LLVM::GlobalOp>(
2932       loc, type, /*isConstant=*/true, linkage, name,
2933       builder.getStringAttr(value), /*alignment=*/0);
2934 
2935   // Get the pointer to the first character in the global string.
2936   Value globalPtr = builder.create<LLVM::AddressOfOp>(loc, global);
2937   Value cst0 = builder.create<LLVM::ConstantOp>(
2938       loc, IntegerType::get(ctx, 64),
2939       builder.getIntegerAttr(builder.getIndexType(), 0));
2940   return builder.create<LLVM::GEPOp>(
2941       loc, LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)), globalPtr,
2942       ValueRange{cst0, cst0});
2943 }
2944 
satisfiesLLVMModule(Operation * op)2945 bool mlir::LLVM::satisfiesLLVMModule(Operation *op) {
2946   return op->hasTrait<OpTrait::SymbolTable>() &&
2947          op->hasTrait<OpTrait::IsIsolatedFromAbove>();
2948 }
2949 
print(AsmPrinter & printer) const2950 void FMFAttr::print(AsmPrinter &printer) const {
2951   printer << "<";
2952   printer << stringifyFastmathFlags(this->getFlags());
2953   printer << ">";
2954 }
2955 
parse(AsmParser & parser,Type type)2956 Attribute FMFAttr::parse(AsmParser &parser, Type type) {
2957   if (failed(parser.parseLess()))
2958     return {};
2959 
2960   FastmathFlags flags = {};
2961   if (failed(parser.parseOptionalGreater())) {
2962     auto parseFlags = [&]() -> ParseResult {
2963       StringRef elemName;
2964       if (failed(parser.parseKeyword(&elemName)))
2965         return failure();
2966 
2967       auto elem = symbolizeFastmathFlags(elemName);
2968       if (!elem)
2969         return parser.emitError(parser.getNameLoc(), "Unknown fastmath flag: ")
2970                << elemName;
2971 
2972       flags = flags | *elem;
2973       return success();
2974     };
2975     if (failed(parser.parseCommaSeparatedList(parseFlags)) ||
2976         parser.parseGreater())
2977       return {};
2978   }
2979 
2980   return FMFAttr::get(parser.getContext(), flags);
2981 }
2982 
print(AsmPrinter & printer) const2983 void LinkageAttr::print(AsmPrinter &printer) const {
2984   printer << "<";
2985   if (static_cast<uint64_t>(getLinkage()) <= getMaxEnumValForLinkage())
2986     printer << stringifyEnum(getLinkage());
2987   else
2988     printer << static_cast<uint64_t>(getLinkage());
2989   printer << ">";
2990 }
2991 
parse(AsmParser & parser,Type type)2992 Attribute LinkageAttr::parse(AsmParser &parser, Type type) {
2993   StringRef elemName;
2994   if (parser.parseLess() || parser.parseKeyword(&elemName) ||
2995       parser.parseGreater())
2996     return {};
2997   auto elem = linkage::symbolizeLinkage(elemName);
2998   if (!elem) {
2999     parser.emitError(parser.getNameLoc(), "Unknown linkage: ") << elemName;
3000     return {};
3001   }
3002   Linkage linkage = *elem;
3003   return LinkageAttr::get(parser.getContext(), linkage);
3004 }
3005 
print(AsmPrinter & printer) const3006 void CConvAttr::print(AsmPrinter &printer) const {
3007   printer << "<";
3008   if (static_cast<uint64_t>(getCallingConv()) <= cconv::getMaxEnumValForCConv())
3009     printer << stringifyEnum(getCallingConv());
3010   else
3011     printer << "INVALID_cc_" << static_cast<uint64_t>(getCallingConv());
3012   printer << ">";
3013 }
3014 
parse(AsmParser & parser,Type type)3015 Attribute CConvAttr::parse(AsmParser &parser, Type type) {
3016   StringRef convName;
3017 
3018   if (parser.parseLess() || parser.parseKeyword(&convName) ||
3019       parser.parseGreater())
3020     return {};
3021   auto cconv = cconv::symbolizeCConv(convName);
3022   if (!cconv) {
3023     parser.emitError(parser.getNameLoc(), "unknown calling convention: ")
3024         << convName;
3025     return {};
3026   }
3027   CConv cconvVal = *cconv;
3028   return CConvAttr::get(parser.getContext(), cconvVal);
3029 }
3030 
LoopOptionsAttrBuilder(LoopOptionsAttr attr)3031 LoopOptionsAttrBuilder::LoopOptionsAttrBuilder(LoopOptionsAttr attr)
3032     : options(attr.getOptions().begin(), attr.getOptions().end()) {}
3033 
3034 template <typename T>
setOption(LoopOptionCase tag,Optional<T> value)3035 LoopOptionsAttrBuilder &LoopOptionsAttrBuilder::setOption(LoopOptionCase tag,
3036                                                           Optional<T> value) {
3037   auto option = llvm::find_if(
3038       options, [tag](auto option) { return option.first == tag; });
3039   if (option != options.end()) {
3040     if (value)
3041       option->second = *value;
3042     else
3043       options.erase(option);
3044   } else {
3045     options.push_back(LoopOptionsAttr::OptionValuePair(tag, *value));
3046   }
3047   return *this;
3048 }
3049 
3050 LoopOptionsAttrBuilder &
setDisableLICM(Optional<bool> value)3051 LoopOptionsAttrBuilder::setDisableLICM(Optional<bool> value) {
3052   return setOption(LoopOptionCase::disable_licm, value);
3053 }
3054 
3055 /// Set the `interleave_count` option to the provided value. If no value
3056 /// is provided the option is deleted.
3057 LoopOptionsAttrBuilder &
setInterleaveCount(Optional<uint64_t> count)3058 LoopOptionsAttrBuilder::setInterleaveCount(Optional<uint64_t> count) {
3059   return setOption(LoopOptionCase::interleave_count, count);
3060 }
3061 
3062 /// Set the `disable_unroll` option to the provided value. If no value
3063 /// is provided the option is deleted.
3064 LoopOptionsAttrBuilder &
setDisableUnroll(Optional<bool> value)3065 LoopOptionsAttrBuilder::setDisableUnroll(Optional<bool> value) {
3066   return setOption(LoopOptionCase::disable_unroll, value);
3067 }
3068 
3069 /// Set the `disable_pipeline` option to the provided value. If no value
3070 /// is provided the option is deleted.
3071 LoopOptionsAttrBuilder &
setDisablePipeline(Optional<bool> value)3072 LoopOptionsAttrBuilder::setDisablePipeline(Optional<bool> value) {
3073   return setOption(LoopOptionCase::disable_pipeline, value);
3074 }
3075 
3076 /// Set the `pipeline_initiation_interval` option to the provided value.
3077 /// If no value is provided the option is deleted.
setPipelineInitiationInterval(Optional<uint64_t> count)3078 LoopOptionsAttrBuilder &LoopOptionsAttrBuilder::setPipelineInitiationInterval(
3079     Optional<uint64_t> count) {
3080   return setOption(LoopOptionCase::pipeline_initiation_interval, count);
3081 }
3082 
3083 template <typename T>
3084 static Optional<T>
getOption(ArrayRef<std::pair<LoopOptionCase,int64_t>> options,LoopOptionCase option)3085 getOption(ArrayRef<std::pair<LoopOptionCase, int64_t>> options,
3086           LoopOptionCase option) {
3087   auto it =
3088       lower_bound(options, option, [](auto optionPair, LoopOptionCase option) {
3089         return optionPair.first < option;
3090       });
3091   if (it == options.end())
3092     return {};
3093   return static_cast<T>(it->second);
3094 }
3095 
disableUnroll()3096 Optional<bool> LoopOptionsAttr::disableUnroll() {
3097   return getOption<bool>(getOptions(), LoopOptionCase::disable_unroll);
3098 }
3099 
disableLICM()3100 Optional<bool> LoopOptionsAttr::disableLICM() {
3101   return getOption<bool>(getOptions(), LoopOptionCase::disable_licm);
3102 }
3103 
interleaveCount()3104 Optional<int64_t> LoopOptionsAttr::interleaveCount() {
3105   return getOption<int64_t>(getOptions(), LoopOptionCase::interleave_count);
3106 }
3107 
3108 /// Build the LoopOptions Attribute from a sorted array of individual options.
get(MLIRContext * context,ArrayRef<std::pair<LoopOptionCase,int64_t>> sortedOptions)3109 LoopOptionsAttr LoopOptionsAttr::get(
3110     MLIRContext *context,
3111     ArrayRef<std::pair<LoopOptionCase, int64_t>> sortedOptions) {
3112   assert(llvm::is_sorted(sortedOptions, llvm::less_first()) &&
3113          "LoopOptionsAttr ctor expects a sorted options array");
3114   return Base::get(context, sortedOptions);
3115 }
3116 
3117 /// Build the LoopOptions Attribute from a sorted array of individual options.
get(MLIRContext * context,LoopOptionsAttrBuilder & optionBuilders)3118 LoopOptionsAttr LoopOptionsAttr::get(MLIRContext *context,
3119                                      LoopOptionsAttrBuilder &optionBuilders) {
3120   llvm::sort(optionBuilders.options, llvm::less_first());
3121   return Base::get(context, optionBuilders.options);
3122 }
3123 
print(AsmPrinter & printer) const3124 void LoopOptionsAttr::print(AsmPrinter &printer) const {
3125   printer << "<";
3126   llvm::interleaveComma(getOptions(), printer, [&](auto option) {
3127     printer << stringifyEnum(option.first) << " = ";
3128     switch (option.first) {
3129     case LoopOptionCase::disable_licm:
3130     case LoopOptionCase::disable_unroll:
3131     case LoopOptionCase::disable_pipeline:
3132       printer << (option.second ? "true" : "false");
3133       break;
3134     case LoopOptionCase::interleave_count:
3135     case LoopOptionCase::pipeline_initiation_interval:
3136       printer << option.second;
3137       break;
3138     }
3139   });
3140   printer << ">";
3141 }
3142 
parse(AsmParser & parser,Type type)3143 Attribute LoopOptionsAttr::parse(AsmParser &parser, Type type) {
3144   if (failed(parser.parseLess()))
3145     return {};
3146 
3147   SmallVector<std::pair<LoopOptionCase, int64_t>> options;
3148   llvm::SmallDenseSet<LoopOptionCase> seenOptions;
3149   auto parseLoopOptions = [&]() -> ParseResult {
3150     StringRef optionName;
3151     if (parser.parseKeyword(&optionName))
3152       return failure();
3153 
3154     auto option = symbolizeLoopOptionCase(optionName);
3155     if (!option)
3156       return parser.emitError(parser.getNameLoc(), "unknown loop option: ")
3157              << optionName;
3158     if (!seenOptions.insert(*option).second)
3159       return parser.emitError(parser.getNameLoc(), "loop option present twice");
3160     if (failed(parser.parseEqual()))
3161       return failure();
3162 
3163     int64_t value;
3164     switch (*option) {
3165     case LoopOptionCase::disable_licm:
3166     case LoopOptionCase::disable_unroll:
3167     case LoopOptionCase::disable_pipeline:
3168       if (succeeded(parser.parseOptionalKeyword("true")))
3169         value = 1;
3170       else if (succeeded(parser.parseOptionalKeyword("false")))
3171         value = 0;
3172       else {
3173         return parser.emitError(parser.getNameLoc(),
3174                                 "expected boolean value 'true' or 'false'");
3175       }
3176       break;
3177     case LoopOptionCase::interleave_count:
3178     case LoopOptionCase::pipeline_initiation_interval:
3179       if (failed(parser.parseInteger(value)))
3180         return parser.emitError(parser.getNameLoc(), "expected integer value");
3181       break;
3182     }
3183     options.push_back(std::make_pair(*option, value));
3184     return success();
3185   };
3186   if (parser.parseCommaSeparatedList(parseLoopOptions) || parser.parseGreater())
3187     return {};
3188 
3189   llvm::sort(options, llvm::less_first());
3190   return get(parser.getContext(), options);
3191 }
3192