1 //===- LLVMDialect.cpp - LLVM IR Ops and Dialect registration -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the types and operation details for the LLVM IR dialect in
10 // MLIR, and the LLVM IR dialect.  It also registers the dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14 #include "TypeDetail.h"
15 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/DialectImplementation.h"
20 #include "mlir/IR/FunctionImplementation.h"
21 #include "mlir/IR/MLIRContext.h"
22 #include "mlir/IR/Matchers.h"
23 
24 #include "llvm/ADT/StringSwitch.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 #include "llvm/AsmParser/Parser.h"
27 #include "llvm/Bitcode/BitcodeReader.h"
28 #include "llvm/Bitcode/BitcodeWriter.h"
29 #include "llvm/IR/Attributes.h"
30 #include "llvm/IR/Function.h"
31 #include "llvm/IR/Type.h"
32 #include "llvm/Support/Mutex.h"
33 #include "llvm/Support/SourceMgr.h"
34 
35 #include <numeric>
36 
37 using namespace mlir;
38 using namespace mlir::LLVM;
39 using mlir::LLVM::linkage::getMaxEnumValForLinkage;
40 
41 #include "mlir/Dialect/LLVMIR/LLVMOpsDialect.cpp.inc"
42 
// Names of the unit attributes used by the load/store builders, printers and
// parsers below to flag volatile and non-temporal accesses.
static constexpr char kVolatileAttrName[] = "volatile_";
static constexpr char kNonTemporalAttrName[] = "nontemporal";
45 
46 #include "mlir/Dialect/LLVMIR/LLVMOpsEnums.cpp.inc"
47 #include "mlir/Dialect/LLVMIR/LLVMOpsInterfaces.cpp.inc"
48 #define GET_ATTRDEF_CLASSES
49 #include "mlir/Dialect/LLVMIR/LLVMOpsAttrDefs.cpp.inc"
50 
51 static auto processFMFAttr(ArrayRef<NamedAttribute> attrs) {
52   SmallVector<NamedAttribute, 8> filteredAttrs(
53       llvm::make_filter_range(attrs, [&](NamedAttribute attr) {
54         if (attr.getName() == "fastmathFlags") {
55           auto defAttr = FMFAttr::get(attr.getValue().getContext(), {});
56           return defAttr != attr.getValue();
57         }
58         return true;
59       }));
60   return filteredAttrs;
61 }
62 
63 static ParseResult parseLLVMOpAttrs(OpAsmParser &parser,
64                                     NamedAttrList &result) {
65   return parser.parseOptionalAttrDict(result);
66 }
67 
68 static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op,
69                              DictionaryAttr attrs) {
70   printer.printOptionalAttrDict(processFMFAttr(attrs.getValue()));
71 }
72 
73 /// Verifies `symbol`'s use in `op` to ensure the symbol is a valid and
74 /// fully defined llvm.func.
75 static LogicalResult verifySymbolAttrUse(FlatSymbolRefAttr symbol,
76                                          Operation *op,
77                                          SymbolTableCollection &symbolTable) {
78   StringRef name = symbol.getValue();
79   auto func =
80       symbolTable.lookupNearestSymbolFrom<LLVMFuncOp>(op, symbol.getAttr());
81   if (!func)
82     return op->emitOpError("'")
83            << name << "' does not reference a valid LLVM function";
84   if (func.isExternal())
85     return op->emitOpError("'") << name << "' does not have a definition";
86   return success();
87 }
88 
89 //===----------------------------------------------------------------------===//
90 // Printing/parsing for LLVM::CmpOp.
91 //===----------------------------------------------------------------------===//
92 
93 void ICmpOp::print(OpAsmPrinter &p) {
94   p << " \"" << stringifyICmpPredicate(getPredicate()) << "\" " << getOperand(0)
95     << ", " << getOperand(1);
96   p.printOptionalAttrDict((*this)->getAttrs(), {"predicate"});
97   p << " : " << getLhs().getType();
98 }
99 
100 void FCmpOp::print(OpAsmPrinter &p) {
101   p << " \"" << stringifyFCmpPredicate(getPredicate()) << "\" " << getOperand(0)
102     << ", " << getOperand(1);
103   p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()), {"predicate"});
104   p << " : " << getLhs().getType();
105 }
106 
107 // <operation> ::= `llvm.icmp` string-literal ssa-use `,` ssa-use
108 //                 attribute-dict? `:` type
109 // <operation> ::= `llvm.fcmp` string-literal ssa-use `,` ssa-use
110 //                 attribute-dict? `:` type
111 template <typename CmpPredicateType>
112 static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) {
113   Builder &builder = parser.getBuilder();
114 
115   StringAttr predicateAttr;
116   OpAsmParser::UnresolvedOperand lhs, rhs;
117   Type type;
118   SMLoc predicateLoc, trailingTypeLoc;
119   if (parser.getCurrentLocation(&predicateLoc) ||
120       parser.parseAttribute(predicateAttr, "predicate", result.attributes) ||
121       parser.parseOperand(lhs) || parser.parseComma() ||
122       parser.parseOperand(rhs) ||
123       parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
124       parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) ||
125       parser.resolveOperand(lhs, type, result.operands) ||
126       parser.resolveOperand(rhs, type, result.operands))
127     return failure();
128 
129   // Replace the string attribute `predicate` with an integer attribute.
130   int64_t predicateValue = 0;
131   if (std::is_same<CmpPredicateType, ICmpPredicate>()) {
132     Optional<ICmpPredicate> predicate =
133         symbolizeICmpPredicate(predicateAttr.getValue());
134     if (!predicate)
135       return parser.emitError(predicateLoc)
136              << "'" << predicateAttr.getValue()
137              << "' is an incorrect value of the 'predicate' attribute";
138     predicateValue = static_cast<int64_t>(predicate.getValue());
139   } else {
140     Optional<FCmpPredicate> predicate =
141         symbolizeFCmpPredicate(predicateAttr.getValue());
142     if (!predicate)
143       return parser.emitError(predicateLoc)
144              << "'" << predicateAttr.getValue()
145              << "' is an incorrect value of the 'predicate' attribute";
146     predicateValue = static_cast<int64_t>(predicate.getValue());
147   }
148 
149   result.attributes.set("predicate",
150                         parser.getBuilder().getI64IntegerAttr(predicateValue));
151 
152   // The result type is either i1 or a vector type <? x i1> if the inputs are
153   // vectors.
154   Type resultType = IntegerType::get(builder.getContext(), 1);
155   if (!isCompatibleType(type))
156     return parser.emitError(trailingTypeLoc,
157                             "expected LLVM dialect-compatible type");
158   if (LLVM::isCompatibleVectorType(type)) {
159     if (LLVM::isScalableVectorType(type)) {
160       resultType = LLVM::getVectorType(
161           resultType, LLVM::getVectorNumElements(type).getKnownMinValue(),
162           /*isScalable=*/true);
163     } else {
164       resultType = LLVM::getVectorType(
165           resultType, LLVM::getVectorNumElements(type).getFixedValue(),
166           /*isScalable=*/false);
167     }
168   }
169 
170   result.addTypes({resultType});
171   return success();
172 }
173 
174 ParseResult ICmpOp::parse(OpAsmParser &parser, OperationState &result) {
175   return parseCmpOp<ICmpPredicate>(parser, result);
176 }
177 
178 ParseResult FCmpOp::parse(OpAsmParser &parser, OperationState &result) {
179   return parseCmpOp<FCmpPredicate>(parser, result);
180 }
181 
182 //===----------------------------------------------------------------------===//
183 // Printing/parsing for LLVM::AllocaOp.
184 //===----------------------------------------------------------------------===//
185 
186 void AllocaOp::print(OpAsmPrinter &p) {
187   auto elemTy = getType().cast<LLVM::LLVMPointerType>().getElementType();
188 
189   auto funcTy =
190       FunctionType::get(getContext(), {getArraySize().getType()}, {getType()});
191 
192   p << ' ' << getArraySize() << " x " << elemTy;
193   if (getAlignment().hasValue() && *getAlignment() != 0)
194     p.printOptionalAttrDict((*this)->getAttrs());
195   else
196     p.printOptionalAttrDict((*this)->getAttrs(), {"alignment"});
197   p << " : " << funcTy;
198 }
199 
200 // <operation> ::= `llvm.alloca` ssa-use `x` type attribute-dict?
201 //                 `:` type `,` type
202 ParseResult AllocaOp::parse(OpAsmParser &parser, OperationState &result) {
203   OpAsmParser::UnresolvedOperand arraySize;
204   Type type, elemType;
205   SMLoc trailingTypeLoc;
206   if (parser.parseOperand(arraySize) || parser.parseKeyword("x") ||
207       parser.parseType(elemType) ||
208       parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
209       parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type))
210     return failure();
211 
212   Optional<NamedAttribute> alignmentAttr =
213       result.attributes.getNamed("alignment");
214   if (alignmentAttr.hasValue()) {
215     auto alignmentInt =
216         alignmentAttr.getValue().getValue().dyn_cast<IntegerAttr>();
217     if (!alignmentInt)
218       return parser.emitError(parser.getNameLoc(),
219                               "expected integer alignment");
220     if (alignmentInt.getValue().isNullValue())
221       result.attributes.erase("alignment");
222   }
223 
224   // Extract the result type from the trailing function type.
225   auto funcType = type.dyn_cast<FunctionType>();
226   if (!funcType || funcType.getNumInputs() != 1 ||
227       funcType.getNumResults() != 1)
228     return parser.emitError(
229         trailingTypeLoc,
230         "expected trailing function type with one argument and one result");
231 
232   if (parser.resolveOperand(arraySize, funcType.getInput(0), result.operands))
233     return failure();
234 
235   result.addTypes({funcType.getResult(0)});
236   return success();
237 }
238 
239 //===----------------------------------------------------------------------===//
240 // LLVM::BrOp
241 //===----------------------------------------------------------------------===//
242 
243 SuccessorOperands BrOp::getSuccessorOperands(unsigned index) {
244   assert(index == 0 && "invalid successor index");
245   return SuccessorOperands(getDestOperandsMutable());
246 }
247 
248 //===----------------------------------------------------------------------===//
249 // LLVM::CondBrOp
250 //===----------------------------------------------------------------------===//
251 
252 SuccessorOperands CondBrOp::getSuccessorOperands(unsigned index) {
253   assert(index < getNumSuccessors() && "invalid successor index");
254   return SuccessorOperands(index == 0 ? getTrueDestOperandsMutable()
255                                       : getFalseDestOperandsMutable());
256 }
257 
258 //===----------------------------------------------------------------------===//
259 // LLVM::SwitchOp
260 //===----------------------------------------------------------------------===//
261 
262 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
263                      Block *defaultDestination, ValueRange defaultOperands,
264                      ArrayRef<int32_t> caseValues, BlockRange caseDestinations,
265                      ArrayRef<ValueRange> caseOperands,
266                      ArrayRef<int32_t> branchWeights) {
267   ElementsAttr caseValuesAttr;
268   if (!caseValues.empty())
269     caseValuesAttr = builder.getI32VectorAttr(caseValues);
270 
271   ElementsAttr weightsAttr;
272   if (!branchWeights.empty())
273     weightsAttr = builder.getI32VectorAttr(llvm::to_vector<4>(branchWeights));
274 
275   build(builder, result, value, defaultOperands, caseOperands, caseValuesAttr,
276         weightsAttr, defaultDestination, caseDestinations);
277 }
278 
279 /// <cases> ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)?
280 ///             ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )?
281 static ParseResult parseSwitchOpCases(
282     OpAsmParser &parser, Type flagType, ElementsAttr &caseValues,
283     SmallVectorImpl<Block *> &caseDestinations,
284     SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> &caseOperands,
285     SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
286   SmallVector<APInt> values;
287   unsigned bitWidth = flagType.getIntOrFloatBitWidth();
288   do {
289     int64_t value = 0;
290     OptionalParseResult integerParseResult = parser.parseOptionalInteger(value);
291     if (values.empty() && !integerParseResult.hasValue())
292       return success();
293 
294     if (!integerParseResult.hasValue() || integerParseResult.getValue())
295       return failure();
296     values.push_back(APInt(bitWidth, value));
297 
298     Block *destination;
299     SmallVector<OpAsmParser::UnresolvedOperand> operands;
300     SmallVector<Type> operandTypes;
301     if (parser.parseColon() || parser.parseSuccessor(destination))
302       return failure();
303     if (!parser.parseOptionalLParen()) {
304       if (parser.parseRegionArgumentList(operands) ||
305           parser.parseColonTypeList(operandTypes) || parser.parseRParen())
306         return failure();
307     }
308     caseDestinations.push_back(destination);
309     caseOperands.emplace_back(operands);
310     caseOperandTypes.emplace_back(operandTypes);
311   } while (!parser.parseOptionalComma());
312 
313   ShapedType caseValueType =
314       VectorType::get(static_cast<int64_t>(values.size()), flagType);
315   caseValues = DenseIntElementsAttr::get(caseValueType, values);
316   return success();
317 }
318 
319 static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type flagType,
320                                ElementsAttr caseValues,
321                                SuccessorRange caseDestinations,
322                                OperandRangeRange caseOperands,
323                                const TypeRangeRange &caseOperandTypes) {
324   if (!caseValues)
325     return;
326 
327   size_t index = 0;
328   llvm::interleave(
329       llvm::zip(caseValues.cast<DenseIntElementsAttr>(), caseDestinations),
330       [&](auto i) {
331         p << "  ";
332         p << std::get<0>(i).getLimitedValue();
333         p << ": ";
334         p.printSuccessorAndUseList(std::get<1>(i), caseOperands[index++]);
335       },
336       [&] {
337         p << ',';
338         p.printNewline();
339       });
340   p.printNewline();
341 }
342 
343 LogicalResult SwitchOp::verify() {
344   if ((!getCaseValues() && !getCaseDestinations().empty()) ||
345       (getCaseValues() &&
346        getCaseValues()->size() !=
347            static_cast<int64_t>(getCaseDestinations().size())))
348     return emitOpError("expects number of case values to match number of "
349                        "case destinations");
350   if (getBranchWeights() && getBranchWeights()->size() != getNumSuccessors())
351     return emitError("expects number of branch weights to match number of "
352                      "successors: ")
353            << getBranchWeights()->size() << " vs " << getNumSuccessors();
354   return success();
355 }
356 
357 SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) {
358   assert(index < getNumSuccessors() && "invalid successor index");
359   return SuccessorOperands(index == 0 ? getDefaultOperandsMutable()
360                                       : getCaseOperandsMutable(index - 1));
361 }
362 
363 //===----------------------------------------------------------------------===//
364 // Code for LLVM::GEPOp.
365 //===----------------------------------------------------------------------===//
366 
367 constexpr int GEPOp::kDynamicIndex;
368 
369 /// Populates `indices` with positions of GEP indices that would correspond to
370 /// LLVMStructTypes potentially nested in the given type. The type currently
371 /// visited gets `currentIndex` and LLVM container types are visited
372 /// recursively. The recursion is bounded and takes care of recursive types by
373 /// means of the `visited` set.
374 static void recordStructIndices(Type type, unsigned currentIndex,
375                                 SmallVectorImpl<unsigned> &indices,
376                                 SmallVectorImpl<unsigned> *structSizes,
377                                 SmallPtrSet<Type, 4> &visited) {
378   if (visited.contains(type))
379     return;
380 
381   visited.insert(type);
382 
383   llvm::TypeSwitch<Type>(type)
384       .Case<LLVMStructType>([&](LLVMStructType structType) {
385         indices.push_back(currentIndex);
386         if (structSizes)
387           structSizes->push_back(structType.getBody().size());
388         for (Type elementType : structType.getBody())
389           recordStructIndices(elementType, currentIndex + 1, indices,
390                               structSizes, visited);
391       })
392       .Case<VectorType, LLVMScalableVectorType, LLVMFixedVectorType,
393             LLVMArrayType>([&](auto containerType) {
394         recordStructIndices(containerType.getElementType(), currentIndex + 1,
395                             indices, structSizes, visited);
396       });
397 }
398 
399 /// Populates `indices` with positions of GEP indices that correspond to
400 /// LLVMStructTypes potentially nested in the given `baseGEPType`, which must
401 /// be either an LLVMPointer type or a vector thereof. If `structSizes` is
402 /// provided, it is populated with sizes of the indexed structs for bounds
403 /// verification purposes.
404 static void
405 findKnownStructIndices(Type baseGEPType, SmallVectorImpl<unsigned> &indices,
406                        SmallVectorImpl<unsigned> *structSizes = nullptr) {
407   Type type = baseGEPType;
408   if (auto vectorType = type.dyn_cast<VectorType>())
409     type = vectorType.getElementType();
410   if (auto scalableVectorType = type.dyn_cast<LLVMScalableVectorType>())
411     type = scalableVectorType.getElementType();
412   if (auto fixedVectorType = type.dyn_cast<LLVMFixedVectorType>())
413     type = fixedVectorType.getElementType();
414 
415   Type pointeeType = type.cast<LLVMPointerType>().getElementType();
416   SmallPtrSet<Type, 4> visited;
417   recordStructIndices(pointeeType, /*currentIndex=*/1, indices, structSizes,
418                       visited);
419 }
420 
421 void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
422                   Value basePtr, ValueRange operands,
423                   ArrayRef<NamedAttribute> attributes) {
424   build(builder, result, resultType, basePtr, operands,
425         SmallVector<int32_t>(operands.size(), LLVM::GEPOp::kDynamicIndex),
426         attributes);
427 }
428 
429 void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
430                   Value basePtr, ValueRange indices,
431                   ArrayRef<int32_t> structIndices,
432                   ArrayRef<NamedAttribute> attributes) {
433   SmallVector<Value> remainingIndices;
434   SmallVector<int32_t> updatedStructIndices(structIndices.begin(),
435                                             structIndices.end());
436   SmallVector<unsigned> structRelatedPositions;
437   findKnownStructIndices(basePtr.getType(), structRelatedPositions);
438 
439   SmallVector<unsigned> operandsToErase;
440   for (unsigned pos : structRelatedPositions) {
441     // GEP may not be indexing as deep as some structs are located.
442     if (pos >= structIndices.size())
443       continue;
444 
445     // If the index is already static, it's fine.
446     if (structIndices[pos] != kDynamicIndex)
447       continue;
448 
449     // Find the corresponding operand.
450     unsigned operandPos =
451         std::count(structIndices.begin(), std::next(structIndices.begin(), pos),
452                    kDynamicIndex);
453 
454     // Extract the constant value from the operand and put it into the attribute
455     // instead.
456     APInt staticIndexValue;
457     bool matched =
458         matchPattern(indices[operandPos], m_ConstantInt(&staticIndexValue));
459     (void)matched;
460     assert(matched && "index into a struct must be a constant");
461     assert(staticIndexValue.sge(APInt::getSignedMinValue(/*numBits=*/32)) &&
462            "struct index underflows 32-bit integer");
463     assert(staticIndexValue.sle(APInt::getSignedMaxValue(/*numBits=*/32)) &&
464            "struct index overflows 32-bit integer");
465     auto staticIndex = static_cast<int32_t>(staticIndexValue.getSExtValue());
466     updatedStructIndices[pos] = staticIndex;
467     operandsToErase.push_back(operandPos);
468   }
469 
470   for (unsigned i = 0, e = indices.size(); i < e; ++i) {
471     if (!llvm::is_contained(operandsToErase, i))
472       remainingIndices.push_back(indices[i]);
473   }
474 
475   assert(remainingIndices.size() == static_cast<size_t>(llvm::count(
476                                         updatedStructIndices, kDynamicIndex)) &&
477          "expected as many index operands as dynamic index attr elements");
478 
479   result.addTypes(resultType);
480   result.addAttributes(attributes);
481   result.addAttribute("structIndices",
482                       builder.getI32TensorAttr(updatedStructIndices));
483   result.addOperands(basePtr);
484   result.addOperands(remainingIndices);
485 }
486 
487 static ParseResult
488 parseGEPIndices(OpAsmParser &parser,
489                 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &indices,
490                 DenseIntElementsAttr &structIndices) {
491   SmallVector<int32_t> constantIndices;
492   do {
493     int32_t constantIndex;
494     OptionalParseResult parsedInteger =
495         parser.parseOptionalInteger(constantIndex);
496     if (parsedInteger.hasValue()) {
497       if (failed(parsedInteger.getValue()))
498         return failure();
499       constantIndices.push_back(constantIndex);
500       continue;
501     }
502 
503     constantIndices.push_back(LLVM::GEPOp::kDynamicIndex);
504     if (failed(parser.parseOperand(indices.emplace_back())))
505       return failure();
506   } while (succeeded(parser.parseOptionalComma()));
507 
508   structIndices = parser.getBuilder().getI32TensorAttr(constantIndices);
509   return success();
510 }
511 
512 static void printGEPIndices(OpAsmPrinter &printer, LLVM::GEPOp gepOp,
513                             OperandRange indices,
514                             DenseIntElementsAttr structIndices) {
515   unsigned operandIdx = 0;
516   llvm::interleaveComma(structIndices.getValues<int32_t>(), printer,
517                         [&](int32_t cst) {
518                           if (cst == LLVM::GEPOp::kDynamicIndex)
519                             printer.printOperand(indices[operandIdx++]);
520                           else
521                             printer << cst;
522                         });
523 }
524 
525 LogicalResult LLVM::GEPOp::verify() {
526   SmallVector<unsigned> indices;
527   SmallVector<unsigned> structSizes;
528   findKnownStructIndices(getBase().getType(), indices, &structSizes);
529   DenseIntElementsAttr structIndices = getStructIndices();
530   for (unsigned i : llvm::seq<unsigned>(0, indices.size())) {
531     unsigned index = indices[i];
532     // GEP may not be indexing as deep as some structs nested in the type.
533     if (index >= structIndices.getNumElements())
534       continue;
535 
536     int32_t staticIndex = structIndices.getValues<int32_t>()[index];
537     if (staticIndex == LLVM::GEPOp::kDynamicIndex)
538       return emitOpError() << "expected index " << index
539                            << " indexing a struct to be constant";
540     if (staticIndex < 0 || static_cast<unsigned>(staticIndex) >= structSizes[i])
541       return emitOpError() << "index " << index
542                            << " indexing a struct is out of bounds";
543   }
544   return success();
545 }
546 
547 //===----------------------------------------------------------------------===//
// Builder, printer and parser for LLVM::LoadOp.
549 //===----------------------------------------------------------------------===//
550 
551 LogicalResult verifySymbolAttribute(
552     Operation *op, StringRef attributeName,
553     llvm::function_ref<LogicalResult(Operation *, SymbolRefAttr)>
554         verifySymbolType) {
555   if (Attribute attribute = op->getAttr(attributeName)) {
556     // The attribute is already verified to be a symbol ref array attribute via
557     // a constraint in the operation definition.
558     for (SymbolRefAttr symbolRef :
559          attribute.cast<ArrayAttr>().getAsRange<SymbolRefAttr>()) {
560       StringAttr metadataName = symbolRef.getRootReference();
561       StringAttr symbolName = symbolRef.getLeafReference();
562       // We want @metadata::@symbol, not just @symbol
563       if (metadataName == symbolName) {
564         return op->emitOpError() << "expected '" << symbolRef
565                                  << "' to specify a fully qualified reference";
566       }
567       auto metadataOp = SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>(
568           op->getParentOp(), metadataName);
569       if (!metadataOp)
570         return op->emitOpError()
571                << "expected '" << symbolRef << "' to reference a metadata op";
572       Operation *symbolOp =
573           SymbolTable::lookupNearestSymbolFrom(metadataOp, symbolName);
574       if (!symbolOp)
575         return op->emitOpError()
576                << "expected '" << symbolRef << "' to be a valid reference";
577       if (failed(verifySymbolType(symbolOp, symbolRef))) {
578         return failure();
579       }
580     }
581   }
582   return success();
583 }
584 
585 // Verifies that metadata ops are wired up properly.
586 template <typename OpTy>
587 static LogicalResult verifyOpMetadata(Operation *op, StringRef attributeName) {
588   auto verifySymbolType = [op](Operation *symbolOp,
589                                SymbolRefAttr symbolRef) -> LogicalResult {
590     if (!isa<OpTy>(symbolOp)) {
591       return op->emitOpError()
592              << "expected '" << symbolRef << "' to resolve to a "
593              << OpTy::getOperationName();
594     }
595     return success();
596   };
597 
598   return verifySymbolAttribute(op, attributeName, verifySymbolType);
599 }
600 
601 static LogicalResult verifyMemoryOpMetadata(Operation *op) {
602   // access_groups
603   if (failed(verifyOpMetadata<LLVM::AccessGroupMetadataOp>(
604           op, LLVMDialect::getAccessGroupsAttrName())))
605     return failure();
606 
607   // alias_scopes
608   if (failed(verifyOpMetadata<LLVM::AliasScopeMetadataOp>(
609           op, LLVMDialect::getAliasScopesAttrName())))
610     return failure();
611 
612   // noalias_scopes
613   if (failed(verifyOpMetadata<LLVM::AliasScopeMetadataOp>(
614           op, LLVMDialect::getNoAliasScopesAttrName())))
615     return failure();
616 
617   return success();
618 }
619 
620 LogicalResult LoadOp::verify() { return verifyMemoryOpMetadata(*this); }
621 
622 void LoadOp::build(OpBuilder &builder, OperationState &result, Type t,
623                    Value addr, unsigned alignment, bool isVolatile,
624                    bool isNonTemporal) {
625   result.addOperands(addr);
626   result.addTypes(t);
627   if (isVolatile)
628     result.addAttribute(kVolatileAttrName, builder.getUnitAttr());
629   if (isNonTemporal)
630     result.addAttribute(kNonTemporalAttrName, builder.getUnitAttr());
631   if (alignment != 0)
632     result.addAttribute("alignment", builder.getI64IntegerAttr(alignment));
633 }
634 
635 void LoadOp::print(OpAsmPrinter &p) {
636   p << ' ';
637   if (getVolatile_())
638     p << "volatile ";
639   p << getAddr();
640   p.printOptionalAttrDict((*this)->getAttrs(), {kVolatileAttrName});
641   p << " : " << getAddr().getType();
642 }
643 
644 // Extract the pointee type from the LLVM pointer type wrapped in MLIR.  Return
645 // the resulting type wrapped in MLIR, or nullptr on error.
646 static Type getLoadStoreElementType(OpAsmParser &parser, Type type,
647                                     SMLoc trailingTypeLoc) {
648   auto llvmTy = type.dyn_cast<LLVM::LLVMPointerType>();
649   if (!llvmTy)
650     return parser.emitError(trailingTypeLoc, "expected LLVM pointer type"),
651            nullptr;
652   return llvmTy.getElementType();
653 }
654 
655 // <operation> ::= `llvm.load` `volatile` ssa-use attribute-dict? `:` type
656 ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) {
657   OpAsmParser::UnresolvedOperand addr;
658   Type type;
659   SMLoc trailingTypeLoc;
660 
661   if (succeeded(parser.parseOptionalKeyword("volatile")))
662     result.addAttribute(kVolatileAttrName, parser.getBuilder().getUnitAttr());
663 
664   if (parser.parseOperand(addr) ||
665       parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
666       parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) ||
667       parser.resolveOperand(addr, type, result.operands))
668     return failure();
669 
670   Type elemTy = getLoadStoreElementType(parser, type, trailingTypeLoc);
671 
672   result.addTypes(elemTy);
673   return success();
674 }
675 
676 //===----------------------------------------------------------------------===//
677 // Builder, printer and parser for LLVM::StoreOp.
678 //===----------------------------------------------------------------------===//
679 
680 LogicalResult StoreOp::verify() { return verifyMemoryOpMetadata(*this); }
681 
682 void StoreOp::build(OpBuilder &builder, OperationState &result, Value value,
683                     Value addr, unsigned alignment, bool isVolatile,
684                     bool isNonTemporal) {
685   result.addOperands({value, addr});
686   result.addTypes({});
687   if (isVolatile)
688     result.addAttribute(kVolatileAttrName, builder.getUnitAttr());
689   if (isNonTemporal)
690     result.addAttribute(kNonTemporalAttrName, builder.getUnitAttr());
691   if (alignment != 0)
692     result.addAttribute("alignment", builder.getI64IntegerAttr(alignment));
693 }
694 
695 void StoreOp::print(OpAsmPrinter &p) {
696   p << ' ';
697   if (getVolatile_())
698     p << "volatile ";
699   p << getValue() << ", " << getAddr();
700   p.printOptionalAttrDict((*this)->getAttrs(), {kVolatileAttrName});
701   p << " : " << getAddr().getType();
702 }
703 
704 // <operation> ::= `llvm.store` `volatile` ssa-use `,` ssa-use
705 //                 attribute-dict? `:` type
706 ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) {
707   OpAsmParser::UnresolvedOperand addr, value;
708   Type type;
709   SMLoc trailingTypeLoc;
710 
711   if (succeeded(parser.parseOptionalKeyword("volatile")))
712     result.addAttribute(kVolatileAttrName, parser.getBuilder().getUnitAttr());
713 
714   if (parser.parseOperand(value) || parser.parseComma() ||
715       parser.parseOperand(addr) ||
716       parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
717       parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type))
718     return failure();
719 
720   Type elemTy = getLoadStoreElementType(parser, type, trailingTypeLoc);
721   if (!elemTy)
722     return failure();
723 
724   if (parser.resolveOperand(value, elemTy, result.operands) ||
725       parser.resolveOperand(addr, type, result.operands))
726     return failure();
727 
728   return success();
729 }
730 
//===----------------------------------------------------------------------===//
// LLVM::InvokeOp
//===----------------------------------------------------------------------===//
734 
735 SuccessorOperands InvokeOp::getSuccessorOperands(unsigned index) {
736   assert(index < getNumSuccessors() && "invalid successor index");
737   return SuccessorOperands(index == 0 ? getNormalDestOperandsMutable()
738                                       : getUnwindDestOperandsMutable());
739 }
740 
741 LogicalResult InvokeOp::verify() {
742   if (getNumResults() > 1)
743     return emitOpError("must have 0 or 1 result");
744 
745   Block *unwindDest = getUnwindDest();
746   if (unwindDest->empty())
747     return emitError("must have at least one operation in unwind destination");
748 
749   // In unwind destination, first operation must be LandingpadOp
750   if (!isa<LandingpadOp>(unwindDest->front()))
751     return emitError("first operation in unwind destination should be a "
752                      "llvm.landingpad operation");
753 
754   return success();
755 }
756 
757 void InvokeOp::print(OpAsmPrinter &p) {
758   auto callee = getCallee();
759   bool isDirect = callee.hasValue();
760 
761   p << ' ';
762 
763   // Either function name or pointer
764   if (isDirect)
765     p.printSymbolName(callee.getValue());
766   else
767     p << getOperand(0);
768 
769   p << '(' << getOperands().drop_front(isDirect ? 0 : 1) << ')';
770   p << " to ";
771   p.printSuccessorAndUseList(getNormalDest(), getNormalDestOperands());
772   p << " unwind ";
773   p.printSuccessorAndUseList(getUnwindDest(), getUnwindDestOperands());
774 
775   p.printOptionalAttrDict((*this)->getAttrs(),
776                           {InvokeOp::getOperandSegmentSizeAttr(), "callee"});
777   p << " : ";
778   p.printFunctionalType(llvm::drop_begin(getOperandTypes(), isDirect ? 0 : 1),
779                         getResultTypes());
780 }
781 
782 /// <operation> ::= `llvm.invoke` (function-id | ssa-use) `(` ssa-use-list `)`
783 ///                  `to` bb-id (`[` ssa-use-and-type-list `]`)?
784 ///                  `unwind` bb-id (`[` ssa-use-and-type-list `]`)?
785 ///                  attribute-dict? `:` function-type
786 ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
787   SmallVector<OpAsmParser::UnresolvedOperand, 8> operands;
788   FunctionType funcType;
789   SymbolRefAttr funcAttr;
790   SMLoc trailingTypeLoc;
791   Block *normalDest, *unwindDest;
792   SmallVector<Value, 4> normalOperands, unwindOperands;
793   Builder &builder = parser.getBuilder();
794 
795   // Parse an operand list that will, in practice, contain 0 or 1 operand.  In
796   // case of an indirect call, there will be 1 operand before `(`.  In case of a
797   // direct call, there will be no operands and the parser will stop at the
798   // function identifier without complaining.
799   if (parser.parseOperandList(operands))
800     return failure();
801   bool isDirect = operands.empty();
802 
803   // Optionally parse a function identifier.
804   if (isDirect && parser.parseAttribute(funcAttr, "callee", result.attributes))
805     return failure();
806 
807   if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
808       parser.parseKeyword("to") ||
809       parser.parseSuccessorAndUseList(normalDest, normalOperands) ||
810       parser.parseKeyword("unwind") ||
811       parser.parseSuccessorAndUseList(unwindDest, unwindOperands) ||
812       parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
813       parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(funcType))
814     return failure();
815 
816   if (isDirect) {
817     // Make sure types match.
818     if (parser.resolveOperands(operands, funcType.getInputs(),
819                                parser.getNameLoc(), result.operands))
820       return failure();
821     result.addTypes(funcType.getResults());
822   } else {
823     // Construct the LLVM IR Dialect function type that the first operand
824     // should match.
825     if (funcType.getNumResults() > 1)
826       return parser.emitError(trailingTypeLoc,
827                               "expected function with 0 or 1 result");
828 
829     Type llvmResultType;
830     if (funcType.getNumResults() == 0) {
831       llvmResultType = LLVM::LLVMVoidType::get(builder.getContext());
832     } else {
833       llvmResultType = funcType.getResult(0);
834       if (!isCompatibleType(llvmResultType))
835         return parser.emitError(trailingTypeLoc,
836                                 "expected result to have LLVM type");
837     }
838 
839     SmallVector<Type, 8> argTypes;
840     argTypes.reserve(funcType.getNumInputs());
841     for (Type ty : funcType.getInputs()) {
842       if (isCompatibleType(ty))
843         argTypes.push_back(ty);
844       else
845         return parser.emitError(trailingTypeLoc,
846                                 "expected LLVM types as inputs");
847     }
848 
849     auto llvmFuncType = LLVM::LLVMFunctionType::get(llvmResultType, argTypes);
850     auto wrappedFuncType = LLVM::LLVMPointerType::get(llvmFuncType);
851 
852     auto funcArguments = llvm::makeArrayRef(operands).drop_front();
853 
854     // Make sure that the first operand (indirect callee) matches the wrapped
855     // LLVM IR function type, and that the types of the other call operands
856     // match the types of the function arguments.
857     if (parser.resolveOperand(operands[0], wrappedFuncType, result.operands) ||
858         parser.resolveOperands(funcArguments, funcType.getInputs(),
859                                parser.getNameLoc(), result.operands))
860       return failure();
861 
862     result.addTypes(llvmResultType);
863   }
864   result.addSuccessors({normalDest, unwindDest});
865   result.addOperands(normalOperands);
866   result.addOperands(unwindOperands);
867 
868   result.addAttribute(
869       InvokeOp::getOperandSegmentSizeAttr(),
870       builder.getI32VectorAttr({static_cast<int32_t>(operands.size()),
871                                 static_cast<int32_t>(normalOperands.size()),
872                                 static_cast<int32_t>(unwindOperands.size())}));
873   return success();
874 }
875 
876 ///===----------------------------------------------------------------------===//
877 /// Verifying/Printing/Parsing for LLVM::LandingpadOp.
878 ///===----------------------------------------------------------------------===//
879 
880 LogicalResult LandingpadOp::verify() {
881   Value value;
882   if (LLVMFuncOp func = (*this)->getParentOfType<LLVMFuncOp>()) {
883     if (!func.getPersonality().hasValue())
884       return emitError(
885           "llvm.landingpad needs to be in a function with a personality");
886   }
887 
888   if (!getCleanup() && getOperands().empty())
889     return emitError("landingpad instruction expects at least one clause or "
890                      "cleanup attribute");
891 
892   for (unsigned idx = 0, ie = getNumOperands(); idx < ie; idx++) {
893     value = getOperand(idx);
894     bool isFilter = value.getType().isa<LLVMArrayType>();
895     if (isFilter) {
896       // FIXME: Verify filter clauses when arrays are appropriately handled
897     } else {
898       // catch - global addresses only.
899       // Bitcast ops should have global addresses as their args.
900       if (auto bcOp = value.getDefiningOp<BitcastOp>()) {
901         if (auto addrOp = bcOp.getArg().getDefiningOp<AddressOfOp>())
902           continue;
903         return emitError("constant clauses expected").attachNote(bcOp.getLoc())
904                << "global addresses expected as operand to "
905                   "bitcast used in clauses for landingpad";
906       }
907       // NullOp and AddressOfOp allowed
908       if (value.getDefiningOp<NullOp>())
909         continue;
910       if (value.getDefiningOp<AddressOfOp>())
911         continue;
912       return emitError("clause #")
913              << idx << " is not a known constant - null, addressof, bitcast";
914     }
915   }
916   return success();
917 }
918 
919 void LandingpadOp::print(OpAsmPrinter &p) {
920   p << (getCleanup() ? " cleanup " : " ");
921 
922   // Clauses
923   for (auto value : getOperands()) {
924     // Similar to llvm - if clause is an array type then it is filter
925     // clause else catch clause
926     bool isArrayTy = value.getType().isa<LLVMArrayType>();
927     p << '(' << (isArrayTy ? "filter " : "catch ") << value << " : "
928       << value.getType() << ") ";
929   }
930 
931   p.printOptionalAttrDict((*this)->getAttrs(), {"cleanup"});
932 
933   p << ": " << getType();
934 }
935 
936 /// <operation> ::= `llvm.landingpad` `cleanup`?
937 ///                 ((`catch` | `filter`) operand-type ssa-use)* attribute-dict?
938 ParseResult LandingpadOp::parse(OpAsmParser &parser, OperationState &result) {
939   // Check for cleanup
940   if (succeeded(parser.parseOptionalKeyword("cleanup")))
941     result.addAttribute("cleanup", parser.getBuilder().getUnitAttr());
942 
943   // Parse clauses with types
944   while (succeeded(parser.parseOptionalLParen()) &&
945          (succeeded(parser.parseOptionalKeyword("filter")) ||
946           succeeded(parser.parseOptionalKeyword("catch")))) {
947     OpAsmParser::UnresolvedOperand operand;
948     Type ty;
949     if (parser.parseOperand(operand) || parser.parseColon() ||
950         parser.parseType(ty) ||
951         parser.resolveOperand(operand, ty, result.operands) ||
952         parser.parseRParen())
953       return failure();
954   }
955 
956   Type type;
957   if (parser.parseColon() || parser.parseType(type))
958     return failure();
959 
960   result.addTypes(type);
961   return success();
962 }
963 
964 //===----------------------------------------------------------------------===//
965 // Verifying/Printing/parsing for LLVM::CallOp.
966 //===----------------------------------------------------------------------===//
967 
968 LogicalResult CallOp::verify() {
969   if (getNumResults() > 1)
970     return emitOpError("must have 0 or 1 result");
971 
972   // Type for the callee, we'll get it differently depending if it is a direct
973   // or indirect call.
974   Type fnType;
975 
976   bool isIndirect = false;
977 
978   // If this is an indirect call, the callee attribute is missing.
979   FlatSymbolRefAttr calleeName = getCalleeAttr();
980   if (!calleeName) {
981     isIndirect = true;
982     if (!getNumOperands())
983       return emitOpError(
984           "must have either a `callee` attribute or at least an operand");
985     auto ptrType = getOperand(0).getType().dyn_cast<LLVMPointerType>();
986     if (!ptrType)
987       return emitOpError("indirect call expects a pointer as callee: ")
988              << ptrType;
989     fnType = ptrType.getElementType();
990   } else {
991     Operation *callee =
992         SymbolTable::lookupNearestSymbolFrom(*this, calleeName.getAttr());
993     if (!callee)
994       return emitOpError()
995              << "'" << calleeName.getValue()
996              << "' does not reference a symbol in the current scope";
997     auto fn = dyn_cast<LLVMFuncOp>(callee);
998     if (!fn)
999       return emitOpError() << "'" << calleeName.getValue()
1000                            << "' does not reference a valid LLVM function";
1001 
1002     fnType = fn.getFunctionType();
1003   }
1004 
1005   LLVMFunctionType funcType = fnType.dyn_cast<LLVMFunctionType>();
1006   if (!funcType)
1007     return emitOpError("callee does not have a functional type: ") << fnType;
1008 
1009   // Verify that the operand and result types match the callee.
1010 
1011   if (!funcType.isVarArg() &&
1012       funcType.getNumParams() != (getNumOperands() - isIndirect))
1013     return emitOpError() << "incorrect number of operands ("
1014                          << (getNumOperands() - isIndirect)
1015                          << ") for callee (expecting: "
1016                          << funcType.getNumParams() << ")";
1017 
1018   if (funcType.getNumParams() > (getNumOperands() - isIndirect))
1019     return emitOpError() << "incorrect number of operands ("
1020                          << (getNumOperands() - isIndirect)
1021                          << ") for varargs callee (expecting at least: "
1022                          << funcType.getNumParams() << ")";
1023 
1024   for (unsigned i = 0, e = funcType.getNumParams(); i != e; ++i)
1025     if (getOperand(i + isIndirect).getType() != funcType.getParamType(i))
1026       return emitOpError() << "operand type mismatch for operand " << i << ": "
1027                            << getOperand(i + isIndirect).getType()
1028                            << " != " << funcType.getParamType(i);
1029 
1030   if (getNumResults() == 0 &&
1031       !funcType.getReturnType().isa<LLVM::LLVMVoidType>())
1032     return emitOpError() << "expected function call to produce a value";
1033 
1034   if (getNumResults() != 0 &&
1035       funcType.getReturnType().isa<LLVM::LLVMVoidType>())
1036     return emitOpError()
1037            << "calling function with void result must not produce values";
1038 
1039   if (getNumResults() > 1)
1040     return emitOpError()
1041            << "expected LLVM function call to produce 0 or 1 result";
1042 
1043   if (getNumResults() && getResult(0).getType() != funcType.getReturnType())
1044     return emitOpError() << "result type mismatch: " << getResult(0).getType()
1045                          << " != " << funcType.getReturnType();
1046 
1047   return success();
1048 }
1049 
1050 void CallOp::print(OpAsmPrinter &p) {
1051   auto callee = getCallee();
1052   bool isDirect = callee.hasValue();
1053 
1054   // Print the direct callee if present as a function attribute, or an indirect
1055   // callee (first operand) otherwise.
1056   p << ' ';
1057   if (isDirect)
1058     p.printSymbolName(callee.getValue());
1059   else
1060     p << getOperand(0);
1061 
1062   auto args = getOperands().drop_front(isDirect ? 0 : 1);
1063   p << '(' << args << ')';
1064   p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()), {"callee"});
1065 
1066   // Reconstruct the function MLIR function type from operand and result types.
1067   p << " : ";
1068   p.printFunctionalType(args.getTypes(), getResultTypes());
1069 }
1070 
1071 // <operation> ::= `llvm.call` (function-id | ssa-use) `(` ssa-use-list `)`
1072 //                 attribute-dict? `:` function-type
1073 ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
1074   SmallVector<OpAsmParser::UnresolvedOperand, 8> operands;
1075   Type type;
1076   SymbolRefAttr funcAttr;
1077   SMLoc trailingTypeLoc;
1078 
1079   // Parse an operand list that will, in practice, contain 0 or 1 operand.  In
1080   // case of an indirect call, there will be 1 operand before `(`.  In case of a
1081   // direct call, there will be no operands and the parser will stop at the
1082   // function identifier without complaining.
1083   if (parser.parseOperandList(operands))
1084     return failure();
1085   bool isDirect = operands.empty();
1086 
1087   // Optionally parse a function identifier.
1088   if (isDirect)
1089     if (parser.parseAttribute(funcAttr, "callee", result.attributes))
1090       return failure();
1091 
1092   if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
1093       parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
1094       parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type))
1095     return failure();
1096 
1097   auto funcType = type.dyn_cast<FunctionType>();
1098   if (!funcType)
1099     return parser.emitError(trailingTypeLoc, "expected function type");
1100   if (funcType.getNumResults() > 1)
1101     return parser.emitError(trailingTypeLoc,
1102                             "expected function with 0 or 1 result");
1103   if (isDirect) {
1104     // Make sure types match.
1105     if (parser.resolveOperands(operands, funcType.getInputs(),
1106                                parser.getNameLoc(), result.operands))
1107       return failure();
1108     if (funcType.getNumResults() != 0 &&
1109         !funcType.getResult(0).isa<LLVM::LLVMVoidType>())
1110       result.addTypes(funcType.getResults());
1111   } else {
1112     Builder &builder = parser.getBuilder();
1113     Type llvmResultType;
1114     if (funcType.getNumResults() == 0) {
1115       llvmResultType = LLVM::LLVMVoidType::get(builder.getContext());
1116     } else {
1117       llvmResultType = funcType.getResult(0);
1118       if (!isCompatibleType(llvmResultType))
1119         return parser.emitError(trailingTypeLoc,
1120                                 "expected result to have LLVM type");
1121     }
1122 
1123     SmallVector<Type, 8> argTypes;
1124     argTypes.reserve(funcType.getNumInputs());
1125     for (int i = 0, e = funcType.getNumInputs(); i < e; ++i) {
1126       auto argType = funcType.getInput(i);
1127       if (!isCompatibleType(argType))
1128         return parser.emitError(trailingTypeLoc,
1129                                 "expected LLVM types as inputs");
1130       argTypes.push_back(argType);
1131     }
1132     auto llvmFuncType = LLVM::LLVMFunctionType::get(llvmResultType, argTypes);
1133     auto wrappedFuncType = LLVM::LLVMPointerType::get(llvmFuncType);
1134 
1135     auto funcArguments =
1136         ArrayRef<OpAsmParser::UnresolvedOperand>(operands).drop_front();
1137 
1138     // Make sure that the first operand (indirect callee) matches the wrapped
1139     // LLVM IR function type, and that the types of the other call operands
1140     // match the types of the function arguments.
1141     if (parser.resolveOperand(operands[0], wrappedFuncType, result.operands) ||
1142         parser.resolveOperands(funcArguments, funcType.getInputs(),
1143                                parser.getNameLoc(), result.operands))
1144       return failure();
1145 
1146     if (!llvmResultType.isa<LLVM::LLVMVoidType>())
1147       result.addTypes(llvmResultType);
1148   }
1149 
1150   return success();
1151 }
1152 
1153 //===----------------------------------------------------------------------===//
1154 // Printing/parsing for LLVM::ExtractElementOp.
1155 //===----------------------------------------------------------------------===//
1156 // Expects vector to be of wrapped LLVM vector type and position to be of
1157 // wrapped LLVM i32 type.
1158 void LLVM::ExtractElementOp::build(OpBuilder &b, OperationState &result,
1159                                    Value vector, Value position,
1160                                    ArrayRef<NamedAttribute> attrs) {
1161   auto vectorType = vector.getType();
1162   auto llvmType = LLVM::getVectorElementType(vectorType);
1163   build(b, result, llvmType, vector, position);
1164   result.addAttributes(attrs);
1165 }
1166 
1167 void ExtractElementOp::print(OpAsmPrinter &p) {
1168   p << ' ' << getVector() << "[" << getPosition() << " : "
1169     << getPosition().getType() << "]";
1170   p.printOptionalAttrDict((*this)->getAttrs());
1171   p << " : " << getVector().getType();
1172 }
1173 
1174 // <operation> ::= `llvm.extractelement` ssa-use `, ` ssa-use
1175 //                 attribute-dict? `:` type
1176 ParseResult ExtractElementOp::parse(OpAsmParser &parser,
1177                                     OperationState &result) {
1178   SMLoc loc;
1179   OpAsmParser::UnresolvedOperand vector, position;
1180   Type type, positionType;
1181   if (parser.getCurrentLocation(&loc) || parser.parseOperand(vector) ||
1182       parser.parseLSquare() || parser.parseOperand(position) ||
1183       parser.parseColonType(positionType) || parser.parseRSquare() ||
1184       parser.parseOptionalAttrDict(result.attributes) ||
1185       parser.parseColonType(type) ||
1186       parser.resolveOperand(vector, type, result.operands) ||
1187       parser.resolveOperand(position, positionType, result.operands))
1188     return failure();
1189   if (!LLVM::isCompatibleVectorType(type))
1190     return parser.emitError(
1191         loc, "expected LLVM dialect-compatible vector type for operand #1");
1192   result.addTypes(LLVM::getVectorElementType(type));
1193   return success();
1194 }
1195 
1196 LogicalResult ExtractElementOp::verify() {
1197   Type vectorType = getVector().getType();
1198   if (!LLVM::isCompatibleVectorType(vectorType))
1199     return emitOpError("expected LLVM dialect-compatible vector type for "
1200                        "operand #1, got")
1201            << vectorType;
1202   Type valueType = LLVM::getVectorElementType(vectorType);
1203   if (valueType != getRes().getType())
1204     return emitOpError() << "Type mismatch: extracting from " << vectorType
1205                          << " should produce " << valueType
1206                          << " but this op returns " << getRes().getType();
1207   return success();
1208 }
1209 
1210 //===----------------------------------------------------------------------===//
1211 // Printing/parsing for LLVM::ExtractValueOp.
1212 //===----------------------------------------------------------------------===//
1213 
1214 void ExtractValueOp::print(OpAsmPrinter &p) {
1215   p << ' ' << getContainer() << getPosition();
1216   p.printOptionalAttrDict((*this)->getAttrs(), {"position"});
1217   p << " : " << getContainer().getType();
1218 }
1219 
1220 // Extract the type at `position` in the wrapped LLVM IR aggregate type
1221 // `containerType`.  Position is an integer array attribute where each value
1222 // is a zero-based position of the element in the aggregate type.  Return the
1223 // resulting type wrapped in MLIR, or nullptr on error.
1224 static Type getInsertExtractValueElementType(OpAsmParser &parser,
1225                                              Type containerType,
1226                                              ArrayAttr positionAttr,
1227                                              SMLoc attributeLoc,
1228                                              SMLoc typeLoc) {
1229   Type llvmType = containerType;
1230   if (!isCompatibleType(containerType))
1231     return parser.emitError(typeLoc, "expected LLVM IR Dialect type"), nullptr;
1232 
1233   // Infer the element type from the structure type: iteratively step inside the
1234   // type by taking the element type, indexed by the position attribute for
1235   // structures.  Check the position index before accessing, it is supposed to
1236   // be in bounds.
1237   for (Attribute subAttr : positionAttr) {
1238     auto positionElementAttr = subAttr.dyn_cast<IntegerAttr>();
1239     if (!positionElementAttr)
1240       return parser.emitError(attributeLoc,
1241                               "expected an array of integer literals"),
1242              nullptr;
1243     int position = positionElementAttr.getInt();
1244     if (auto arrayType = llvmType.dyn_cast<LLVMArrayType>()) {
1245       if (position < 0 ||
1246           static_cast<unsigned>(position) >= arrayType.getNumElements())
1247         return parser.emitError(attributeLoc, "position out of bounds"),
1248                nullptr;
1249       llvmType = arrayType.getElementType();
1250     } else if (auto structType = llvmType.dyn_cast<LLVMStructType>()) {
1251       if (position < 0 ||
1252           static_cast<unsigned>(position) >= structType.getBody().size())
1253         return parser.emitError(attributeLoc, "position out of bounds"),
1254                nullptr;
1255       llvmType = structType.getBody()[position];
1256     } else {
1257       return parser.emitError(typeLoc, "expected LLVM IR structure/array type"),
1258              nullptr;
1259     }
1260   }
1261   return llvmType;
1262 }
1263 
1264 // Extract the type at `position` in the wrapped LLVM IR aggregate type
1265 // `containerType`. Returns null on failure.
1266 static Type getInsertExtractValueElementType(Type containerType,
1267                                              ArrayAttr positionAttr,
1268                                              Operation *op) {
1269   Type llvmType = containerType;
1270   if (!isCompatibleType(containerType)) {
1271     op->emitError("expected LLVM IR Dialect type, got ") << containerType;
1272     return {};
1273   }
1274 
1275   // Infer the element type from the structure type: iteratively step inside the
1276   // type by taking the element type, indexed by the position attribute for
1277   // structures.  Check the position index before accessing, it is supposed to
1278   // be in bounds.
1279   for (Attribute subAttr : positionAttr) {
1280     auto positionElementAttr = subAttr.dyn_cast<IntegerAttr>();
1281     if (!positionElementAttr) {
1282       op->emitOpError("expected an array of integer literals, got: ")
1283           << subAttr;
1284       return {};
1285     }
1286     int position = positionElementAttr.getInt();
1287     if (auto arrayType = llvmType.dyn_cast<LLVMArrayType>()) {
1288       if (position < 0 ||
1289           static_cast<unsigned>(position) >= arrayType.getNumElements()) {
1290         op->emitOpError("position out of bounds: ") << position;
1291         return {};
1292       }
1293       llvmType = arrayType.getElementType();
1294     } else if (auto structType = llvmType.dyn_cast<LLVMStructType>()) {
1295       if (position < 0 ||
1296           static_cast<unsigned>(position) >= structType.getBody().size()) {
1297         op->emitOpError("position out of bounds") << position;
1298         return {};
1299       }
1300       llvmType = structType.getBody()[position];
1301     } else {
1302       op->emitOpError("expected LLVM IR structure/array type, got: ")
1303           << llvmType;
1304       return {};
1305     }
1306   }
1307   return llvmType;
1308 }
1309 
1310 // <operation> ::= `llvm.extractvalue` ssa-use
1311 //                 `[` integer-literal (`,` integer-literal)* `]`
1312 //                 attribute-dict? `:` type
1313 ParseResult ExtractValueOp::parse(OpAsmParser &parser, OperationState &result) {
1314   OpAsmParser::UnresolvedOperand container;
1315   Type containerType;
1316   ArrayAttr positionAttr;
1317   SMLoc attributeLoc, trailingTypeLoc;
1318 
1319   if (parser.parseOperand(container) ||
1320       parser.getCurrentLocation(&attributeLoc) ||
1321       parser.parseAttribute(positionAttr, "position", result.attributes) ||
1322       parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
1323       parser.getCurrentLocation(&trailingTypeLoc) ||
1324       parser.parseType(containerType) ||
1325       parser.resolveOperand(container, containerType, result.operands))
1326     return failure();
1327 
1328   auto elementType = getInsertExtractValueElementType(
1329       parser, containerType, positionAttr, attributeLoc, trailingTypeLoc);
1330   if (!elementType)
1331     return failure();
1332 
1333   result.addTypes(elementType);
1334   return success();
1335 }
1336 
/// Folds this extractvalue by walking back through the chain of insertvalue
/// ops that define its container operand. For each insertvalue in the chain:
///  - same position as this op: fold to the inserted value;
///  - one position is a prefix of the other: stop, because the insertion may
///    overlap the extracted element (see the example below);
///  - otherwise the positions are disjoint, so the insertion cannot affect the
///    extracted element. This op's container operand is rewired in place to
///    that insertvalue's container, and the walk continues with the new
///    container's defining insertvalue, if any.
///
/// Returns the inserted value on an exact match. If the container operand was
/// rewired at least once (the op was updated in place), returns this op's own
/// result. Otherwise returns a null OpFoldResult. `operands` is not used.
OpFoldResult LLVM::ExtractValueOp::fold(ArrayRef<Attribute> operands) {
  auto insertValueOp = getContainer().getDefiningOp<InsertValueOp>();
  // Stays null until the container operand has been rewired in place.
  OpFoldResult result = {};
  while (insertValueOp) {
    // Exact match: the extracted element is the value being inserted.
    if (getPosition() == insertValueOp.getPosition())
      return insertValueOp.getValue();
    unsigned min =
        std::min(getPosition().size(), insertValueOp.getPosition().size());
    // If one is fully prefix of the other, stop propagating back as it will
    // miss dependencies. For instance, %3 should not fold to %f0 in the
    // following example:
    // ```
    //   %1 = llvm.insertvalue %f0, %0[0, 0] :
    //     !llvm.array<4 x !llvm.array<4xf32>>
    //   %2 = llvm.insertvalue %arr, %1[0] :
    //     !llvm.array<4 x !llvm.array<4xf32>>
    //   %3 = llvm.extractvalue %2[0, 0] : !llvm.array<4 x !llvm.array<4xf32>>
    // ```
    if (getPosition().getValue().take_front(min) ==
        insertValueOp.getPosition().getValue().take_front(min))
      return result;

    // If neither a prefix, nor the exact position, we can extract out of the
    // value being inserted into. Moreover, we can try again if that operand
    // is itself an insertvalue expression.
    getContainerMutable().assign(insertValueOp.getContainer());
    // Record the in-place update so the fold reports a change.
    result = getResult();
    insertValueOp = insertValueOp.getContainer().getDefiningOp<InsertValueOp>();
  }
  return result;
}
1368 
1369 LogicalResult ExtractValueOp::verify() {
1370   Type valueType = getInsertExtractValueElementType(getContainer().getType(),
1371                                                     getPositionAttr(), *this);
1372   if (!valueType)
1373     return failure();
1374 
1375   if (getRes().getType() != valueType)
1376     return emitOpError() << "Type mismatch: extracting from "
1377                          << getContainer().getType() << " should produce "
1378                          << valueType << " but this op returns "
1379                          << getRes().getType();
1380   return success();
1381 }
1382 
1383 //===----------------------------------------------------------------------===//
1384 // Printing/parsing for LLVM::InsertElementOp.
1385 //===----------------------------------------------------------------------===//
1386 
1387 void InsertElementOp::print(OpAsmPrinter &p) {
1388   p << ' ' << getValue() << ", " << getVector() << "[" << getPosition() << " : "
1389     << getPosition().getType() << "]";
1390   p.printOptionalAttrDict((*this)->getAttrs());
1391   p << " : " << getVector().getType();
1392 }
1393 
1394 // <operation> ::= `llvm.insertelement` ssa-use `,` ssa-use `,` ssa-use
1395 //                 attribute-dict? `:` type
1396 ParseResult InsertElementOp::parse(OpAsmParser &parser,
1397                                    OperationState &result) {
1398   SMLoc loc;
1399   OpAsmParser::UnresolvedOperand vector, value, position;
1400   Type vectorType, positionType;
1401   if (parser.getCurrentLocation(&loc) || parser.parseOperand(value) ||
1402       parser.parseComma() || parser.parseOperand(vector) ||
1403       parser.parseLSquare() || parser.parseOperand(position) ||
1404       parser.parseColonType(positionType) || parser.parseRSquare() ||
1405       parser.parseOptionalAttrDict(result.attributes) ||
1406       parser.parseColonType(vectorType))
1407     return failure();
1408 
1409   if (!LLVM::isCompatibleVectorType(vectorType))
1410     return parser.emitError(
1411         loc, "expected LLVM dialect-compatible vector type for operand #1");
1412   Type valueType = LLVM::getVectorElementType(vectorType);
1413   if (!valueType)
1414     return failure();
1415 
1416   if (parser.resolveOperand(vector, vectorType, result.operands) ||
1417       parser.resolveOperand(value, valueType, result.operands) ||
1418       parser.resolveOperand(position, positionType, result.operands))
1419     return failure();
1420 
1421   result.addTypes(vectorType);
1422   return success();
1423 }
1424 
1425 LogicalResult InsertElementOp::verify() {
1426   Type valueType = LLVM::getVectorElementType(getVector().getType());
1427   if (valueType != getValue().getType())
1428     return emitOpError() << "Type mismatch: cannot insert "
1429                          << getValue().getType() << " into "
1430                          << getVector().getType();
1431   return success();
1432 }
1433 
1434 //===----------------------------------------------------------------------===//
1435 // Printing/parsing for LLVM::InsertValueOp.
1436 //===----------------------------------------------------------------------===//
1437 
1438 void InsertValueOp::print(OpAsmPrinter &p) {
1439   p << ' ' << getValue() << ", " << getContainer() << getPosition();
1440   p.printOptionalAttrDict((*this)->getAttrs(), {"position"});
1441   p << " : " << getContainer().getType();
1442 }
1443 
// Parses an `llvm.insertvalue` op. Only the container type is spelled out; the
// type of the inserted value is derived from it and the position attribute.
//
// <operation> ::= `llvm.insertvalue` ssa-use `,` ssa-use
//                 `[` integer-literal (`,` integer-literal)* `]`
//                 attribute-dict? `:` type
ParseResult InsertValueOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::UnresolvedOperand container, value;
  Type containerType;
  ArrayAttr positionAttr;
  // Locations handed to getInsertExtractValueElementType, presumably so it can
  // anchor its diagnostics on the position attribute and the trailing type.
  SMLoc attributeLoc, trailingTypeLoc;

  // Textual order is: inserted value, container, position, attributes, type.
  if (parser.parseOperand(value) || parser.parseComma() ||
      parser.parseOperand(container) ||
      parser.getCurrentLocation(&attributeLoc) ||
      parser.parseAttribute(positionAttr, "position", result.attributes) ||
      parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
      parser.getCurrentLocation(&trailingTypeLoc) ||
      parser.parseType(containerType))
    return failure();

  // Derive the inserted value's type from the container type and position; a
  // null type signals failure (the helper is expected to have emitted the
  // diagnostic).
  auto valueType = getInsertExtractValueElementType(
      parser, containerType, positionAttr, attributeLoc, trailingTypeLoc);
  if (!valueType)
    return failure();

  // Operands are resolved container-first, which fixes the op's operand order.
  if (parser.resolveOperand(container, containerType, result.operands) ||
      parser.resolveOperand(value, valueType, result.operands))
    return failure();

  // The result has the same type as the container.
  result.addTypes(containerType);
  return success();
}
1474 
1475 LogicalResult InsertValueOp::verify() {
1476   Type valueType = getInsertExtractValueElementType(getContainer().getType(),
1477                                                     getPositionAttr(), *this);
1478   if (!valueType)
1479     return failure();
1480 
1481   if (getValue().getType() != valueType)
1482     return emitOpError() << "Type mismatch: cannot insert "
1483                          << getValue().getType() << " into "
1484                          << getContainer().getType();
1485 
1486   return success();
1487 }
1488 
1489 //===----------------------------------------------------------------------===//
1490 // Printing, parsing and verification for LLVM::ReturnOp.
1491 //===----------------------------------------------------------------------===//
1492 
1493 LogicalResult ReturnOp::verify() {
1494   if (getNumOperands() > 1)
1495     return emitOpError("expected at most 1 operand");
1496 
1497   if (auto parent = (*this)->getParentOfType<LLVMFuncOp>()) {
1498     Type expectedType = parent.getFunctionType().getReturnType();
1499     if (expectedType.isa<LLVMVoidType>()) {
1500       if (getNumOperands() == 0)
1501         return success();
1502       InFlightDiagnostic diag = emitOpError("expected no operands");
1503       diag.attachNote(parent->getLoc()) << "when returning from function";
1504       return diag;
1505     }
1506     if (getNumOperands() == 0) {
1507       if (expectedType.isa<LLVMVoidType>())
1508         return success();
1509       InFlightDiagnostic diag = emitOpError("expected 1 operand");
1510       diag.attachNote(parent->getLoc()) << "when returning from function";
1511       return diag;
1512     }
1513     if (expectedType != getOperand(0).getType()) {
1514       InFlightDiagnostic diag = emitOpError("mismatching result types");
1515       diag.attachNote(parent->getLoc()) << "when returning from function";
1516       return diag;
1517     }
1518   }
1519   return success();
1520 }
1521 
1522 //===----------------------------------------------------------------------===//
1523 // ResumeOp
1524 //===----------------------------------------------------------------------===//
1525 
1526 LogicalResult ResumeOp::verify() {
1527   if (!getValue().getDefiningOp<LandingpadOp>())
1528     return emitOpError("expects landingpad value as operand");
1529   // No check for personality of function - landingpad op verifies it.
1530   return success();
1531 }
1532 
1533 //===----------------------------------------------------------------------===//
1534 // Verifier for LLVM::AddressOfOp.
1535 //===----------------------------------------------------------------------===//
1536 
1537 template <typename OpTy>
1538 static OpTy lookupSymbolInModule(Operation *parent, StringRef name) {
1539   Operation *module = parent;
1540   while (module && !satisfiesLLVMModule(module))
1541     module = module->getParentOp();
1542   assert(module && "unexpected operation outside of a module");
1543   return dyn_cast_or_null<OpTy>(
1544       mlir::SymbolTable::lookupSymbolIn(module, name));
1545 }
1546 
1547 GlobalOp AddressOfOp::getGlobal() {
1548   return lookupSymbolInModule<LLVM::GlobalOp>((*this)->getParentOp(),
1549                                               getGlobalName());
1550 }
1551 
1552 LLVMFuncOp AddressOfOp::getFunction() {
1553   return lookupSymbolInModule<LLVM::LLVMFuncOp>((*this)->getParentOp(),
1554                                                 getGlobalName());
1555 }
1556 
1557 LogicalResult AddressOfOp::verify() {
1558   auto global = getGlobal();
1559   auto function = getFunction();
1560   if (!global && !function)
1561     return emitOpError(
1562         "must reference a global defined by 'llvm.mlir.global' or 'llvm.func'");
1563 
1564   if (global &&
1565       LLVM::LLVMPointerType::get(global.getType(), global.getAddrSpace()) !=
1566           getResult().getType())
1567     return emitOpError(
1568         "the type must be a pointer to the type of the referenced global");
1569 
1570   if (function && LLVM::LLVMPointerType::get(function.getFunctionType()) !=
1571                       getResult().getType())
1572     return emitOpError(
1573         "the type must be a pointer to the type of the referenced function");
1574 
1575   return success();
1576 }
1577 
1578 //===----------------------------------------------------------------------===//
1579 // Builder, printer and verifier for LLVM::GlobalOp.
1580 //===----------------------------------------------------------------------===//
1581 
1582 /// Returns the name used for the linkage attribute. This *must* correspond to
1583 /// the name of the attribute in ODS.
1584 static StringRef getLinkageAttrName() { return "linkage"; }
1585 
1586 /// Returns the name used for the unnamed_addr attribute. This *must* correspond
1587 /// to the name of the attribute in ODS.
1588 static StringRef getUnnamedAddrAttrName() { return "unnamed_addr"; }
1589 
1590 void GlobalOp::build(OpBuilder &builder, OperationState &result, Type type,
1591                      bool isConstant, Linkage linkage, StringRef name,
1592                      Attribute value, uint64_t alignment, unsigned addrSpace,
1593                      bool dsoLocal, ArrayRef<NamedAttribute> attrs) {
1594   result.addAttribute(SymbolTable::getSymbolAttrName(),
1595                       builder.getStringAttr(name));
1596   result.addAttribute("global_type", TypeAttr::get(type));
1597   if (isConstant)
1598     result.addAttribute("constant", builder.getUnitAttr());
1599   if (value)
1600     result.addAttribute("value", value);
1601   if (dsoLocal)
1602     result.addAttribute("dso_local", builder.getUnitAttr());
1603 
1604   // Only add an alignment attribute if the "alignment" input
1605   // is different from 0. The value must also be a power of two, but
1606   // this is tested in GlobalOp::verify, not here.
1607   if (alignment != 0)
1608     result.addAttribute("alignment", builder.getI64IntegerAttr(alignment));
1609 
1610   result.addAttribute(::getLinkageAttrName(),
1611                       LinkageAttr::get(builder.getContext(), linkage));
1612   if (addrSpace != 0)
1613     result.addAttribute("addr_space", builder.getI32IntegerAttr(addrSpace));
1614   result.attributes.append(attrs.begin(), attrs.end());
1615   result.addRegion();
1616 }
1617 
1618 void GlobalOp::print(OpAsmPrinter &p) {
1619   p << ' ' << stringifyLinkage(getLinkage()) << ' ';
1620   if (auto unnamedAddr = getUnnamedAddr()) {
1621     StringRef str = stringifyUnnamedAddr(*unnamedAddr);
1622     if (!str.empty())
1623       p << str << ' ';
1624   }
1625   if (getConstant())
1626     p << "constant ";
1627   p.printSymbolName(getSymName());
1628   p << '(';
1629   if (auto value = getValueOrNull())
1630     p.printAttribute(value);
1631   p << ')';
1632   // Note that the alignment attribute is printed using the
1633   // default syntax here, even though it is an inherent attribute
1634   // (as defined in https://mlir.llvm.org/docs/LangRef/#attributes)
1635   p.printOptionalAttrDict((*this)->getAttrs(),
1636                           {SymbolTable::getSymbolAttrName(), "global_type",
1637                            "constant", "value", getLinkageAttrName(),
1638                            getUnnamedAddrAttrName()});
1639 
1640   // Print the trailing type unless it's a string global.
1641   if (getValueOrNull().dyn_cast_or_null<StringAttr>())
1642     return;
1643   p << " : " << getType();
1644 
1645   Region &initializer = getInitializerRegion();
1646   if (!initializer.empty()) {
1647     p << ' ';
1648     p.printRegion(initializer, /*printEntryBlockArgs=*/false);
1649   }
1650 }
1651 
1652 // Parses one of the keywords provided in the list `keywords` and returns the
1653 // position of the parsed keyword in the list. If none of the keywords from the
1654 // list is parsed, returns -1.
1655 static int parseOptionalKeywordAlternative(OpAsmParser &parser,
1656                                            ArrayRef<StringRef> keywords) {
1657   for (const auto &en : llvm::enumerate(keywords)) {
1658     if (succeeded(parser.parseOptionalKeyword(en.value())))
1659       return en.index();
1660   }
1661   return -1;
1662 }
1663 
namespace {
/// Compile-time traits mapping an enum type `Ty` to its string conversion and
/// its largest enumerator value. The primary template is intentionally empty:
/// only enums registered below with REGISTER_ENUM_TYPE can be used with it.
template <typename Ty>
struct EnumTraits {};

/// Specializes EnumTraits for `Ty` by forwarding to the `stringify<Ty>` and
/// `getMaxEnumValFor<Ty>` free functions (presumably generated from the ODS
/// enum definitions pulled in via LLVMOpsEnums.cpp.inc).
#define REGISTER_ENUM_TYPE(Ty)                                                 \
  template <>                                                                  \
  struct EnumTraits<Ty> {                                                      \
    static StringRef stringify(Ty value) { return stringify##Ty(value); }      \
    static unsigned getMaxEnumVal() { return getMaxEnumValFor##Ty(); }         \
  }

REGISTER_ENUM_TYPE(Linkage);
REGISTER_ENUM_TYPE(UnnamedAddr);
} // namespace
1678 
1679 /// Parse an enum from the keyword, or default to the provided default value.
1680 /// The return type is the enum type by default, unless overriden with the
1681 /// second template argument.
1682 template <typename EnumTy, typename RetTy = EnumTy>
1683 static RetTy parseOptionalLLVMKeyword(OpAsmParser &parser,
1684                                       OperationState &result,
1685                                       EnumTy defaultValue) {
1686   SmallVector<StringRef, 10> names;
1687   for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i)
1688     names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i)));
1689 
1690   int index = parseOptionalKeywordAlternative(parser, names);
1691   if (index == -1)
1692     return static_cast<RetTy>(defaultValue);
1693   return static_cast<RetTy>(index);
1694 }
1695 
// Parses an `llvm.mlir.global` op.
//
// operation ::= `llvm.mlir.global` linkage? unnamed-addr? `constant`?
//               `@` identifier `(` attribute? `)` attribute-dict?
//               (`:` type region?)?
//
// When their keywords are absent, the linkage defaults to External and the
// unnamed_addr kind defaults to None. An alignment, if any, is spelled in the
// attribute dictionary, e.g. `{alignment = 8 : i64}`.
//
// The type can be omitted for string attributes. In that case it is inferred
// from the value of the string as [strlen(value) x i8]. An initializer region
// can only follow an explicitly spelled type.
ParseResult GlobalOp::parse(OpAsmParser &parser, OperationState &result) {
  MLIRContext *ctx = parser.getContext();
  // Parse optional linkage, default to External.
  result.addAttribute(::getLinkageAttrName(),
                      LLVM::LinkageAttr::get(
                          ctx, parseOptionalLLVMKeyword<Linkage>(
                                   parser, result, LLVM::Linkage::External)));
  // Parse optional UnnamedAddr, default to None. The kind is stored as a plain
  // i64 integer attribute.
  result.addAttribute(::getUnnamedAddrAttrName(),
                      parser.getBuilder().getI64IntegerAttr(
                          parseOptionalLLVMKeyword<UnnamedAddr, int64_t>(
                              parser, result, LLVM::UnnamedAddr::None)));

  if (succeeded(parser.parseOptionalKeyword("constant")))
    result.addAttribute("constant", parser.getBuilder().getUnitAttr());

  StringAttr name;
  if (parser.parseSymbolName(name, SymbolTable::getSymbolAttrName(),
                             result.attributes) ||
      parser.parseLParen())
    return failure();

  // An immediate `)` means there is no initializer value; otherwise parse it
  // as the "value" attribute and expect the closing parenthesis.
  Attribute value;
  if (parser.parseOptionalRParen()) {
    if (parser.parseAttribute(value, "value", result.attributes) ||
        parser.parseRParen())
      return failure();
  }

  SmallVector<Type, 1> types;
  if (parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseOptionalColonTypeList(types))
    return failure();

  if (types.size() > 1)
    return parser.emitError(parser.getNameLoc(), "expected zero or one type");

  // The region is always added so the op has its initializer region, even
  // when it stays empty.
  Region &initRegion = *result.addRegion();
  if (types.empty()) {
    // No explicit type: only allowed for string initializers, whose type is
    // an i8 array as long as the string.
    if (auto strAttr = value.dyn_cast_or_null<StringAttr>()) {
      MLIRContext *context = parser.getContext();
      auto arrayType = LLVM::LLVMArrayType::get(IntegerType::get(context, 8),
                                                strAttr.getValue().size());
      types.push_back(arrayType);
    } else {
      return parser.emitError(parser.getNameLoc(),
                              "type can only be omitted for string globals");
    }
  } else {
    // With an explicit type, an initializer region (no block arguments) may
    // follow.
    OptionalParseResult parseResult =
        parser.parseOptionalRegion(initRegion, /*arguments=*/{},
                                   /*argTypes=*/{});
    if (parseResult.hasValue() && failed(*parseResult))
      return failure();
  }

  result.addAttribute("global_type", TypeAttr::get(types[0]));
  return success();
}
1761 
1762 static bool isZeroAttribute(Attribute value) {
1763   if (auto intValue = value.dyn_cast<IntegerAttr>())
1764     return intValue.getValue().isNullValue();
1765   if (auto fpValue = value.dyn_cast<FloatAttr>())
1766     return fpValue.getValue().isZero();
1767   if (auto splatValue = value.dyn_cast<SplatElementsAttr>())
1768     return isZeroAttribute(splatValue.getSplatValue<Attribute>());
1769   if (auto elementsValue = value.dyn_cast<ElementsAttr>())
1770     return llvm::all_of(elementsValue.getValues<Attribute>(), isZeroAttribute);
1771   if (auto arrayValue = value.dyn_cast<ArrayAttr>())
1772     return llvm::all_of(arrayValue.getValue(), isZeroAttribute);
1773   return false;
1774 }
1775 
1776 LogicalResult GlobalOp::verify() {
1777   if (!LLVMPointerType::isValidElementType(getType()))
1778     return emitOpError(
1779         "expects type to be a valid element type for an LLVM pointer");
1780   if ((*this)->getParentOp() && !satisfiesLLVMModule((*this)->getParentOp()))
1781     return emitOpError("must appear at the module level");
1782 
1783   if (auto strAttr = getValueOrNull().dyn_cast_or_null<StringAttr>()) {
1784     auto type = getType().dyn_cast<LLVMArrayType>();
1785     IntegerType elementType =
1786         type ? type.getElementType().dyn_cast<IntegerType>() : nullptr;
1787     if (!elementType || elementType.getWidth() != 8 ||
1788         type.getNumElements() != strAttr.getValue().size())
1789       return emitOpError(
1790           "requires an i8 array type of the length equal to that of the string "
1791           "attribute");
1792   }
1793 
1794   if (getLinkage() == Linkage::Common) {
1795     if (Attribute value = getValueOrNull()) {
1796       if (!isZeroAttribute(value)) {
1797         return emitOpError()
1798                << "expected zero value for '"
1799                << stringifyLinkage(Linkage::Common) << "' linkage";
1800       }
1801     }
1802   }
1803 
1804   if (getLinkage() == Linkage::Appending) {
1805     if (!getType().isa<LLVMArrayType>()) {
1806       return emitOpError() << "expected array type for '"
1807                            << stringifyLinkage(Linkage::Appending)
1808                            << "' linkage";
1809     }
1810   }
1811 
1812   Optional<uint64_t> alignAttr = getAlignment();
1813   if (alignAttr.hasValue()) {
1814     uint64_t value = alignAttr.getValue();
1815     if (!llvm::isPowerOf2_64(value))
1816       return emitError() << "alignment attribute is not a power of 2";
1817   }
1818 
1819   return success();
1820 }
1821 
1822 LogicalResult GlobalOp::verifyRegions() {
1823   if (Block *b = getInitializerBlock()) {
1824     ReturnOp ret = cast<ReturnOp>(b->getTerminator());
1825     if (ret.operand_type_begin() == ret.operand_type_end())
1826       return emitOpError("initializer region cannot return void");
1827     if (*ret.operand_type_begin() != getType())
1828       return emitOpError("initializer region type ")
1829              << *ret.operand_type_begin() << " does not match global type "
1830              << getType();
1831 
1832     for (Operation &op : *b) {
1833       auto iface = dyn_cast<MemoryEffectOpInterface>(op);
1834       if (!iface || !iface.hasNoEffect())
1835         return op.emitError()
1836                << "ops with side effects not allowed in global initializers";
1837     }
1838 
1839     if (getValueOrNull())
1840       return emitOpError("cannot have both initializer value and region");
1841   }
1842 
1843   return success();
1844 }
1845 
1846 //===----------------------------------------------------------------------===//
1847 // LLVM::GlobalCtorsOp
1848 //===----------------------------------------------------------------------===//
1849 
1850 LogicalResult
1851 GlobalCtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1852   for (Attribute ctor : getCtors()) {
1853     if (failed(verifySymbolAttrUse(ctor.cast<FlatSymbolRefAttr>(), *this,
1854                                    symbolTable)))
1855       return failure();
1856   }
1857   return success();
1858 }
1859 
1860 LogicalResult GlobalCtorsOp::verify() {
1861   if (getCtors().size() != getPriorities().size())
1862     return emitError(
1863         "mismatch between the number of ctors and the number of priorities");
1864   return success();
1865 }
1866 
1867 //===----------------------------------------------------------------------===//
1868 // LLVM::GlobalDtorsOp
1869 //===----------------------------------------------------------------------===//
1870 
1871 LogicalResult
1872 GlobalDtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1873   for (Attribute dtor : getDtors()) {
1874     if (failed(verifySymbolAttrUse(dtor.cast<FlatSymbolRefAttr>(), *this,
1875                                    symbolTable)))
1876       return failure();
1877   }
1878   return success();
1879 }
1880 
1881 LogicalResult GlobalDtorsOp::verify() {
1882   if (getDtors().size() != getPriorities().size())
1883     return emitError(
1884         "mismatch between the number of dtors and the number of priorities");
1885   return success();
1886 }
1887 
1888 //===----------------------------------------------------------------------===//
1889 // Printing/parsing for LLVM::ShuffleVectorOp.
1890 //===----------------------------------------------------------------------===//
1891 // Expects vector to be of wrapped LLVM vector type and position to be of
1892 // wrapped LLVM i32 type.
1893 void LLVM::ShuffleVectorOp::build(OpBuilder &b, OperationState &result,
1894                                   Value v1, Value v2, ArrayAttr mask,
1895                                   ArrayRef<NamedAttribute> attrs) {
1896   auto containerType = v1.getType();
1897   auto vType = LLVM::getVectorType(
1898       LLVM::getVectorElementType(containerType), mask.size(),
1899       containerType.cast<VectorType>().isScalable());
1900   build(b, result, vType, v1, v2, mask);
1901   result.addAttributes(attrs);
1902 }
1903 
1904 void ShuffleVectorOp::print(OpAsmPrinter &p) {
1905   p << ' ' << getV1() << ", " << getV2() << " " << getMask();
1906   p.printOptionalAttrDict((*this)->getAttrs(), {"mask"});
1907   p << " : " << getV1().getType() << ", " << getV2().getType();
1908 }
1909 
1910 // <operation> ::= `llvm.shufflevector` ssa-use `, ` ssa-use
1911 //                 `[` integer-literal (`,` integer-literal)* `]`
1912 //                 attribute-dict? `:` type
1913 ParseResult ShuffleVectorOp::parse(OpAsmParser &parser,
1914                                    OperationState &result) {
1915   SMLoc loc;
1916   OpAsmParser::UnresolvedOperand v1, v2;
1917   ArrayAttr maskAttr;
1918   Type typeV1, typeV2;
1919   if (parser.getCurrentLocation(&loc) || parser.parseOperand(v1) ||
1920       parser.parseComma() || parser.parseOperand(v2) ||
1921       parser.parseAttribute(maskAttr, "mask", result.attributes) ||
1922       parser.parseOptionalAttrDict(result.attributes) ||
1923       parser.parseColonType(typeV1) || parser.parseComma() ||
1924       parser.parseType(typeV2) ||
1925       parser.resolveOperand(v1, typeV1, result.operands) ||
1926       parser.resolveOperand(v2, typeV2, result.operands))
1927     return failure();
1928   if (!LLVM::isCompatibleVectorType(typeV1))
1929     return parser.emitError(
1930         loc, "expected LLVM IR dialect vector type for operand #1");
1931   auto vType =
1932       LLVM::getVectorType(LLVM::getVectorElementType(typeV1), maskAttr.size(),
1933                           typeV1.cast<VectorType>().isScalable());
1934   result.addTypes(vType);
1935   return success();
1936 }
1937 
1938 LogicalResult ShuffleVectorOp::verify() {
1939   Type type1 = getV1().getType();
1940   Type type2 = getV2().getType();
1941   if (LLVM::getVectorElementType(type1) != LLVM::getVectorElementType(type2))
1942     return emitOpError("expected matching LLVM IR Dialect element types");
1943   if (LLVM::isScalableVectorType(type1))
1944     if (llvm::any_of(getMask(), [](Attribute attr) {
1945           return attr.cast<IntegerAttr>().getInt() != 0;
1946         }))
1947       return emitOpError("expected a splat operation for scalable vectors");
1948   return success();
1949 }
1950 
1951 //===----------------------------------------------------------------------===//
1952 // Implementations for LLVM::LLVMFuncOp.
1953 //===----------------------------------------------------------------------===//
1954 
1955 // Add the entry block to the function.
1956 Block *LLVMFuncOp::addEntryBlock() {
1957   assert(empty() && "function already has an entry block");
1958   assert(!isVarArg() && "unimplemented: non-external variadic functions");
1959 
1960   auto *entry = new Block;
1961   push_back(entry);
1962 
1963   // FIXME: Allow passing in proper locations for the entry arguments.
1964   LLVMFunctionType type = getFunctionType();
1965   for (unsigned i = 0, e = type.getNumParams(); i < e; ++i)
1966     entry->addArgument(type.getParamType(i), getLoc());
1967   return entry;
1968 }
1969 
1970 void LLVMFuncOp::build(OpBuilder &builder, OperationState &result,
1971                        StringRef name, Type type, LLVM::Linkage linkage,
1972                        bool dsoLocal, ArrayRef<NamedAttribute> attrs,
1973                        ArrayRef<DictionaryAttr> argAttrs) {
1974   result.addRegion();
1975   result.addAttribute(SymbolTable::getSymbolAttrName(),
1976                       builder.getStringAttr(name));
1977   result.addAttribute(getFunctionTypeAttrName(result.name),
1978                       TypeAttr::get(type));
1979   result.addAttribute(::getLinkageAttrName(),
1980                       LinkageAttr::get(builder.getContext(), linkage));
1981   result.attributes.append(attrs.begin(), attrs.end());
1982   if (dsoLocal)
1983     result.addAttribute("dso_local", builder.getUnitAttr());
1984   if (argAttrs.empty())
1985     return;
1986 
1987   assert(type.cast<LLVMFunctionType>().getNumParams() == argAttrs.size() &&
1988          "expected as many argument attribute lists as arguments");
1989   function_interface_impl::addArgAndResultAttrs(builder, result, argAttrs,
1990                                                 /*resultAttrs=*/llvm::None);
1991 }
1992 
1993 // Builds an LLVM function type from the given lists of input and output types.
1994 // Returns a null type if any of the types provided are non-LLVM types, or if
1995 // there is more than one output type.
1996 static Type
1997 buildLLVMFunctionType(OpAsmParser &parser, SMLoc loc, ArrayRef<Type> inputs,
1998                       ArrayRef<Type> outputs,
1999                       function_interface_impl::VariadicFlag variadicFlag) {
2000   Builder &b = parser.getBuilder();
2001   if (outputs.size() > 1) {
2002     parser.emitError(loc, "failed to construct function type: expected zero or "
2003                           "one function result");
2004     return {};
2005   }
2006 
2007   // Convert inputs to LLVM types, exit early on error.
2008   SmallVector<Type, 4> llvmInputs;
2009   for (auto t : inputs) {
2010     if (!isCompatibleType(t)) {
2011       parser.emitError(loc, "failed to construct function type: expected LLVM "
2012                             "type for function arguments");
2013       return {};
2014     }
2015     llvmInputs.push_back(t);
2016   }
2017 
2018   // No output is denoted as "void" in LLVM type system.
2019   Type llvmOutput =
2020       outputs.empty() ? LLVMVoidType::get(b.getContext()) : outputs.front();
2021   if (!isCompatibleType(llvmOutput)) {
2022     parser.emitError(loc, "failed to construct function type: expected LLVM "
2023                           "type for function results")
2024         << llvmOutput;
2025     return {};
2026   }
2027   return LLVMFunctionType::get(llvmOutput, llvmInputs,
2028                                variadicFlag.isVariadic());
2029 }
2030 
// Parses an LLVM function.
//
// operation ::= `llvm.func` linkage? function-signature function-attributes?
//               function-body
//
// The linkage defaults to 'external' when no keyword is given. The signature
// may be variadic. The body region is optional; without it the function is an
// external declaration.
ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) {
  // Default to external linkage if no keyword is provided.
  result.addAttribute(
      ::getLinkageAttrName(),
      LinkageAttr::get(parser.getContext(),
                       parseOptionalLLVMKeyword<Linkage>(
                           parser, result, LLVM::Linkage::External)));

  StringAttr nameAttr;
  SmallVector<OpAsmParser::UnresolvedOperand> entryArgs;
  SmallVector<NamedAttrList> argAttrs;
  SmallVector<NamedAttrList> resultAttrs;
  SmallVector<Type> argTypes;
  SmallVector<Type> resultTypes;
  SmallVector<Location> argLocations;
  bool isVariadic;

  // Remember where the signature starts so type-construction errors point at
  // it.
  auto signatureLocation = parser.getCurrentLocation();
  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
                             result.attributes) ||
      function_interface_impl::parseFunctionSignature(
          parser, /*allowVariadic=*/true, entryArgs, argTypes, argAttrs,
          argLocations, isVariadic, resultTypes, resultAttrs))
    return failure();

  // Turn the parsed signature into an LLVMFunctionType; a null type means an
  // error was already emitted.
  auto type =
      buildLLVMFunctionType(parser, signatureLocation, argTypes, resultTypes,
                            function_interface_impl::VariadicFlag(isVariadic));
  if (!type)
    return failure();
  result.addAttribute(FunctionOpInterface::getTypeAttrName(),
                      TypeAttr::get(type));

  if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
    return failure();
  function_interface_impl::addArgAndResultAttrs(parser.getBuilder(), result,
                                                argAttrs, resultAttrs);

  // Entry block arguments are only typed when named arguments were parsed in
  // the signature; otherwise any body is parsed without entry arguments.
  auto *body = result.addRegion();
  OptionalParseResult parseResult = parser.parseOptionalRegion(
      *body, entryArgs, entryArgs.empty() ? ArrayRef<Type>() : argTypes);
  // Success when the body is absent or was parsed without error.
  return failure(parseResult.hasValue() && failed(*parseResult));
}
2079 
2080 // Print the LLVMFuncOp. Collects argument and result types and passes them to
2081 // helper functions. Drops "void" result since it cannot be parsed back. Skips
2082 // the external linkage since it is the default value.
2083 void LLVMFuncOp::print(OpAsmPrinter &p) {
2084   p << ' ';
2085   if (getLinkage() != LLVM::Linkage::External)
2086     p << stringifyLinkage(getLinkage()) << ' ';
2087   p.printSymbolName(getName());
2088 
2089   LLVMFunctionType fnType = getFunctionType();
2090   SmallVector<Type, 8> argTypes;
2091   SmallVector<Type, 1> resTypes;
2092   argTypes.reserve(fnType.getNumParams());
2093   for (unsigned i = 0, e = fnType.getNumParams(); i < e; ++i)
2094     argTypes.push_back(fnType.getParamType(i));
2095 
2096   Type returnType = fnType.getReturnType();
2097   if (!returnType.isa<LLVMVoidType>())
2098     resTypes.push_back(returnType);
2099 
2100   function_interface_impl::printFunctionSignature(p, *this, argTypes,
2101                                                   isVarArg(), resTypes);
2102   function_interface_impl::printFunctionAttributes(
2103       p, *this, argTypes.size(), resTypes.size(), {getLinkageAttrName()});
2104 
2105   // Print the body if this is not an external function.
2106   Region &body = getBody();
2107   if (!body.empty()) {
2108     p << ' ';
2109     p.printRegion(body, /*printEntryBlockArgs=*/false,
2110                   /*printBlockTerminators=*/true);
2111   }
2112 }
2113 
2114 // Verifies LLVM- and implementation-specific properties of the LLVM func Op:
2115 // - functions don't have 'common' linkage
2116 // - external functions have 'external' or 'extern_weak' linkage;
2117 // - vararg is (currently) only supported for external functions;
2118 LogicalResult LLVMFuncOp::verify() {
2119   if (getLinkage() == LLVM::Linkage::Common)
2120     return emitOpError() << "functions cannot have '"
2121                          << stringifyLinkage(LLVM::Linkage::Common)
2122                          << "' linkage";
2123 
2124   // Check to see if this function has a void return with a result attribute to
2125   // it. It isn't clear what semantics we would assign to that.
2126   if (getFunctionType().getReturnType().isa<LLVMVoidType>() &&
2127       !getResultAttrs(0).empty()) {
2128     return emitOpError()
2129            << "cannot attach result attributes to functions with a void return";
2130   }
2131 
2132   if (isExternal()) {
2133     if (getLinkage() != LLVM::Linkage::External &&
2134         getLinkage() != LLVM::Linkage::ExternWeak)
2135       return emitOpError() << "external functions must have '"
2136                            << stringifyLinkage(LLVM::Linkage::External)
2137                            << "' or '"
2138                            << stringifyLinkage(LLVM::Linkage::ExternWeak)
2139                            << "' linkage";
2140     return success();
2141   }
2142 
2143   if (isVarArg())
2144     return emitOpError("only external functions can be variadic");
2145 
2146   return success();
2147 }
2148 
2149 /// Verifies LLVM- and implementation-specific properties of the LLVM func Op:
2150 /// - entry block arguments are of LLVM types.
2151 LogicalResult LLVMFuncOp::verifyRegions() {
2152   if (isExternal())
2153     return success();
2154 
2155   unsigned numArguments = getFunctionType().getNumParams();
2156   Block &entryBlock = front();
2157   for (unsigned i = 0; i < numArguments; ++i) {
2158     Type argType = entryBlock.getArgument(i).getType();
2159     if (!isCompatibleType(argType))
2160       return emitOpError("entry block argument #")
2161              << i << " is not of LLVM type";
2162   }
2163 
2164   return success();
2165 }
2166 
2167 //===----------------------------------------------------------------------===//
2168 // Verification for LLVM::ConstantOp.
2169 //===----------------------------------------------------------------------===//
2170 
2171 LogicalResult LLVM::ConstantOp::verify() {
2172   if (StringAttr sAttr = getValue().dyn_cast<StringAttr>()) {
2173     auto arrayType = getType().dyn_cast<LLVMArrayType>();
2174     if (!arrayType || arrayType.getNumElements() != sAttr.getValue().size() ||
2175         !arrayType.getElementType().isInteger(8)) {
2176       return emitOpError() << "expected array type of "
2177                            << sAttr.getValue().size()
2178                            << " i8 elements for the string constant";
2179     }
2180     return success();
2181   }
2182   if (auto structType = getType().dyn_cast<LLVMStructType>()) {
2183     if (structType.getBody().size() != 2 ||
2184         structType.getBody()[0] != structType.getBody()[1]) {
2185       return emitError() << "expected struct type with two elements of the "
2186                             "same type, the type of a complex constant";
2187     }
2188 
2189     auto arrayAttr = getValue().dyn_cast<ArrayAttr>();
2190     if (!arrayAttr || arrayAttr.size() != 2 ||
2191         arrayAttr[0].getType() != arrayAttr[1].getType()) {
2192       return emitOpError() << "expected array attribute with two elements, "
2193                               "representing a complex constant";
2194     }
2195 
2196     Type elementType = structType.getBody()[0];
2197     if (!elementType
2198              .isa<IntegerType, Float16Type, Float32Type, Float64Type>()) {
2199       return emitError()
2200              << "expected struct element types to be floating point type or "
2201                 "integer type";
2202     }
2203     return success();
2204   }
2205   if (!getValue().isa<IntegerAttr, ArrayAttr, FloatAttr, ElementsAttr>())
2206     return emitOpError()
2207            << "only supports integer, float, string or elements attributes";
2208   return success();
2209 }
2210 
2211 // Constant op constant-folds to its value.
2212 OpFoldResult LLVM::ConstantOp::fold(ArrayRef<Attribute>) { return getValue(); }
2213 
2214 //===----------------------------------------------------------------------===//
2215 // Utility functions for parsing atomic ops
2216 //===----------------------------------------------------------------------===//
2217 
2218 // Helper function to parse a keyword into the specified attribute named by
2219 // `attrName`. The keyword must match one of the string values defined by the
2220 // AtomicBinOp enum. The resulting I64 attribute is added to the `result`
2221 // state.
2222 static ParseResult parseAtomicBinOp(OpAsmParser &parser, OperationState &result,
2223                                     StringRef attrName) {
2224   SMLoc loc;
2225   StringRef keyword;
2226   if (parser.getCurrentLocation(&loc) || parser.parseKeyword(&keyword))
2227     return failure();
2228 
2229   // Replace the keyword `keyword` with an integer attribute.
2230   auto kind = symbolizeAtomicBinOp(keyword);
2231   if (!kind) {
2232     return parser.emitError(loc)
2233            << "'" << keyword << "' is an incorrect value of the '" << attrName
2234            << "' attribute";
2235   }
2236 
2237   auto value = static_cast<int64_t>(kind.getValue());
2238   auto attr = parser.getBuilder().getI64IntegerAttr(value);
2239   result.addAttribute(attrName, attr);
2240 
2241   return success();
2242 }
2243 
2244 // Helper function to parse a keyword into the specified attribute named by
2245 // `attrName`. The keyword must match one of the string values defined by the
2246 // AtomicOrdering enum. The resulting I64 attribute is added to the `result`
2247 // state.
2248 static ParseResult parseAtomicOrdering(OpAsmParser &parser,
2249                                        OperationState &result,
2250                                        StringRef attrName) {
2251   SMLoc loc;
2252   StringRef ordering;
2253   if (parser.getCurrentLocation(&loc) || parser.parseKeyword(&ordering))
2254     return failure();
2255 
2256   // Replace the keyword `ordering` with an integer attribute.
2257   auto kind = symbolizeAtomicOrdering(ordering);
2258   if (!kind) {
2259     return parser.emitError(loc)
2260            << "'" << ordering << "' is an incorrect value of the '" << attrName
2261            << "' attribute";
2262   }
2263 
2264   auto value = static_cast<int64_t>(kind.getValue());
2265   auto attr = parser.getBuilder().getI64IntegerAttr(value);
2266   result.addAttribute(attrName, attr);
2267 
2268   return success();
2269 }
2270 
2271 //===----------------------------------------------------------------------===//
2272 // Printer, parser and verifier for LLVM::AtomicRMWOp.
2273 //===----------------------------------------------------------------------===//
2274 
2275 void AtomicRMWOp::print(OpAsmPrinter &p) {
2276   p << ' ' << stringifyAtomicBinOp(getBinOp()) << ' ' << getPtr() << ", "
2277     << getVal() << ' ' << stringifyAtomicOrdering(getOrdering()) << ' ';
2278   p.printOptionalAttrDict((*this)->getAttrs(), {"bin_op", "ordering"});
2279   p << " : " << getRes().getType();
2280 }
2281 
2282 // <operation> ::= `llvm.atomicrmw` keyword ssa-use `,` ssa-use keyword
2283 //                 attribute-dict? `:` type
2284 ParseResult AtomicRMWOp::parse(OpAsmParser &parser, OperationState &result) {
2285   Type type;
2286   OpAsmParser::UnresolvedOperand ptr, val;
2287   if (parseAtomicBinOp(parser, result, "bin_op") || parser.parseOperand(ptr) ||
2288       parser.parseComma() || parser.parseOperand(val) ||
2289       parseAtomicOrdering(parser, result, "ordering") ||
2290       parser.parseOptionalAttrDict(result.attributes) ||
2291       parser.parseColonType(type) ||
2292       parser.resolveOperand(ptr, LLVM::LLVMPointerType::get(type),
2293                             result.operands) ||
2294       parser.resolveOperand(val, type, result.operands))
2295     return failure();
2296 
2297   result.addTypes(type);
2298   return success();
2299 }
2300 
2301 LogicalResult AtomicRMWOp::verify() {
2302   auto ptrType = getPtr().getType().cast<LLVM::LLVMPointerType>();
2303   auto valType = getVal().getType();
2304   if (valType != ptrType.getElementType())
2305     return emitOpError("expected LLVM IR element type for operand #0 to "
2306                        "match type for operand #1");
2307   auto resType = getRes().getType();
2308   if (resType != valType)
2309     return emitOpError(
2310         "expected LLVM IR result type to match type for operand #1");
2311   if (getBinOp() == AtomicBinOp::fadd || getBinOp() == AtomicBinOp::fsub) {
2312     if (!mlir::LLVM::isCompatibleFloatingPointType(valType))
2313       return emitOpError("expected LLVM IR floating point type");
2314   } else if (getBinOp() == AtomicBinOp::xchg) {
2315     auto intType = valType.dyn_cast<IntegerType>();
2316     unsigned intBitWidth = intType ? intType.getWidth() : 0;
2317     if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 &&
2318         intBitWidth != 64 && !valType.isa<BFloat16Type>() &&
2319         !valType.isa<Float16Type>() && !valType.isa<Float32Type>() &&
2320         !valType.isa<Float64Type>())
2321       return emitOpError("unexpected LLVM IR type for 'xchg' bin_op");
2322   } else {
2323     auto intType = valType.dyn_cast<IntegerType>();
2324     unsigned intBitWidth = intType ? intType.getWidth() : 0;
2325     if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 &&
2326         intBitWidth != 64)
2327       return emitOpError("expected LLVM IR integer type");
2328   }
2329 
2330   if (static_cast<unsigned>(getOrdering()) <
2331       static_cast<unsigned>(AtomicOrdering::monotonic))
2332     return emitOpError() << "expected at least '"
2333                          << stringifyAtomicOrdering(AtomicOrdering::monotonic)
2334                          << "' ordering";
2335 
2336   return success();
2337 }
2338 
2339 //===----------------------------------------------------------------------===//
2340 // Printer, parser and verifier for LLVM::AtomicCmpXchgOp.
2341 //===----------------------------------------------------------------------===//
2342 
2343 void AtomicCmpXchgOp::print(OpAsmPrinter &p) {
2344   p << ' ' << getPtr() << ", " << getCmp() << ", " << getVal() << ' '
2345     << stringifyAtomicOrdering(getSuccessOrdering()) << ' '
2346     << stringifyAtomicOrdering(getFailureOrdering());
2347   p.printOptionalAttrDict((*this)->getAttrs(),
2348                           {"success_ordering", "failure_ordering"});
2349   p << " : " << getVal().getType();
2350 }
2351 
2352 // <operation> ::= `llvm.cmpxchg` ssa-use `,` ssa-use `,` ssa-use
2353 //                 keyword keyword attribute-dict? `:` type
2354 ParseResult AtomicCmpXchgOp::parse(OpAsmParser &parser,
2355                                    OperationState &result) {
2356   auto &builder = parser.getBuilder();
2357   Type type;
2358   OpAsmParser::UnresolvedOperand ptr, cmp, val;
2359   if (parser.parseOperand(ptr) || parser.parseComma() ||
2360       parser.parseOperand(cmp) || parser.parseComma() ||
2361       parser.parseOperand(val) ||
2362       parseAtomicOrdering(parser, result, "success_ordering") ||
2363       parseAtomicOrdering(parser, result, "failure_ordering") ||
2364       parser.parseOptionalAttrDict(result.attributes) ||
2365       parser.parseColonType(type) ||
2366       parser.resolveOperand(ptr, LLVM::LLVMPointerType::get(type),
2367                             result.operands) ||
2368       parser.resolveOperand(cmp, type, result.operands) ||
2369       parser.resolveOperand(val, type, result.operands))
2370     return failure();
2371 
2372   auto boolType = IntegerType::get(builder.getContext(), 1);
2373   auto resultType =
2374       LLVMStructType::getLiteral(builder.getContext(), {type, boolType});
2375   result.addTypes(resultType);
2376 
2377   return success();
2378 }
2379 
2380 LogicalResult AtomicCmpXchgOp::verify() {
2381   auto ptrType = getPtr().getType().cast<LLVM::LLVMPointerType>();
2382   if (!ptrType)
2383     return emitOpError("expected LLVM IR pointer type for operand #0");
2384   auto cmpType = getCmp().getType();
2385   auto valType = getVal().getType();
2386   if (cmpType != ptrType.getElementType() || cmpType != valType)
2387     return emitOpError("expected LLVM IR element type for operand #0 to "
2388                        "match type for all other operands");
2389   auto intType = valType.dyn_cast<IntegerType>();
2390   unsigned intBitWidth = intType ? intType.getWidth() : 0;
2391   if (!valType.isa<LLVMPointerType>() && intBitWidth != 8 &&
2392       intBitWidth != 16 && intBitWidth != 32 && intBitWidth != 64 &&
2393       !valType.isa<BFloat16Type>() && !valType.isa<Float16Type>() &&
2394       !valType.isa<Float32Type>() && !valType.isa<Float64Type>())
2395     return emitOpError("unexpected LLVM IR type");
2396   if (getSuccessOrdering() < AtomicOrdering::monotonic ||
2397       getFailureOrdering() < AtomicOrdering::monotonic)
2398     return emitOpError("ordering must be at least 'monotonic'");
2399   if (getFailureOrdering() == AtomicOrdering::release ||
2400       getFailureOrdering() == AtomicOrdering::acq_rel)
2401     return emitOpError("failure ordering cannot be 'release' or 'acq_rel'");
2402   return success();
2403 }
2404 
2405 //===----------------------------------------------------------------------===//
2406 // Printer, parser and verifier for LLVM::FenceOp.
2407 //===----------------------------------------------------------------------===//
2408 
2409 // <operation> ::= `llvm.fence` (`syncscope(`strAttr`)`)? keyword
2410 // attribute-dict?
2411 ParseResult FenceOp::parse(OpAsmParser &parser, OperationState &result) {
2412   StringAttr sScope;
2413   StringRef syncscopeKeyword = "syncscope";
2414   if (!failed(parser.parseOptionalKeyword(syncscopeKeyword))) {
2415     if (parser.parseLParen() ||
2416         parser.parseAttribute(sScope, syncscopeKeyword, result.attributes) ||
2417         parser.parseRParen())
2418       return failure();
2419   } else {
2420     result.addAttribute(syncscopeKeyword,
2421                         parser.getBuilder().getStringAttr(""));
2422   }
2423   if (parseAtomicOrdering(parser, result, "ordering") ||
2424       parser.parseOptionalAttrDict(result.attributes))
2425     return failure();
2426   return success();
2427 }
2428 
2429 void FenceOp::print(OpAsmPrinter &p) {
2430   StringRef syncscopeKeyword = "syncscope";
2431   p << ' ';
2432   if (!(*this)->getAttr(syncscopeKeyword).cast<StringAttr>().getValue().empty())
2433     p << "syncscope(" << (*this)->getAttr(syncscopeKeyword) << ") ";
2434   p << stringifyAtomicOrdering(getOrdering());
2435 }
2436 
2437 LogicalResult FenceOp::verify() {
2438   if (getOrdering() == AtomicOrdering::not_atomic ||
2439       getOrdering() == AtomicOrdering::unordered ||
2440       getOrdering() == AtomicOrdering::monotonic)
2441     return emitOpError("can be given only acquire, release, acq_rel, "
2442                        "and seq_cst orderings");
2443   return success();
2444 }
2445 
2446 //===----------------------------------------------------------------------===//
2447 // Folder for LLVM::BitcastOp
2448 //===----------------------------------------------------------------------===//
2449 
2450 OpFoldResult LLVM::BitcastOp::fold(ArrayRef<Attribute> operands) {
2451   // bitcast(x : T0, T0) -> x
2452   if (getArg().getType() == getType())
2453     return getArg();
2454   // bitcast(bitcast(x : T0, T1), T0) -> x
2455   if (auto prev = getArg().getDefiningOp<BitcastOp>())
2456     if (prev.getArg().getType() == getType())
2457       return prev.getArg();
2458   return {};
2459 }
2460 
2461 //===----------------------------------------------------------------------===//
2462 // Folder for LLVM::AddrSpaceCastOp
2463 //===----------------------------------------------------------------------===//
2464 
2465 OpFoldResult LLVM::AddrSpaceCastOp::fold(ArrayRef<Attribute> operands) {
2466   // addrcast(x : T0, T0) -> x
2467   if (getArg().getType() == getType())
2468     return getArg();
2469   // addrcast(addrcast(x : T0, T1), T0) -> x
2470   if (auto prev = getArg().getDefiningOp<AddrSpaceCastOp>())
2471     if (prev.getArg().getType() == getType())
2472       return prev.getArg();
2473   return {};
2474 }
2475 
2476 //===----------------------------------------------------------------------===//
2477 // Folder for LLVM::GEPOp
2478 //===----------------------------------------------------------------------===//
2479 
2480 OpFoldResult LLVM::GEPOp::fold(ArrayRef<Attribute> operands) {
2481   // gep %x:T, 0 -> %x
2482   if (getBase().getType() == getType() && getIndices().size() == 1 &&
2483       matchPattern(getIndices()[0], m_Zero()))
2484     return getBase();
2485   return {};
2486 }
2487 
2488 //===----------------------------------------------------------------------===//
2489 // LLVMDialect initialization, type parsing, and registration.
2490 //===----------------------------------------------------------------------===//
2491 
/// Registers the LLVM dialect's attributes, types and operations, and allows
/// operations of this dialect that are not registered.
void LLVMDialect::initialize() {
  // Attributes with C++ parse/print implementations in this file.
  addAttributes<FMFAttr, LinkageAttr, LoopOptionsAttr>();

  // clang-format off
  addTypes<LLVMVoidType,
           LLVMPPCFP128Type,
           LLVMX86MMXType,
           LLVMTokenType,
           LLVMLabelType,
           LLVMMetadataType,
           LLVMFunctionType,
           LLVMPointerType,
           LLVMFixedVectorType,
           LLVMScalableVectorType,
           LLVMArrayType,
           LLVMStructType>();
  // clang-format on
  // The operation list is spliced in from the generated LLVMOps.cpp.inc
  // (GET_OP_LIST section).
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc"
      >();

  // Support unknown operations because not all LLVM operations are registered.
  allowUnknownOperations();
}
2517 
2518 #define GET_OP_CLASSES
2519 #include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc"
2520 
2521 /// Parse a type registered to this dialect.
2522 Type LLVMDialect::parseType(DialectAsmParser &parser) const {
2523   return detail::parseType(parser);
2524 }
2525 
2526 /// Print a type registered to this dialect.
2527 void LLVMDialect::printType(Type type, DialectAsmPrinter &os) const {
2528   return detail::printType(type, os);
2529 }
2530 
2531 LogicalResult LLVMDialect::verifyDataLayoutString(
2532     StringRef descr, llvm::function_ref<void(const Twine &)> reportError) {
2533   llvm::Expected<llvm::DataLayout> maybeDataLayout =
2534       llvm::DataLayout::parse(descr);
2535   if (maybeDataLayout)
2536     return success();
2537 
2538   std::string message;
2539   llvm::raw_string_ostream messageStream(message);
2540   llvm::logAllUnhandledErrors(maybeDataLayout.takeError(), messageStream);
2541   reportError("invalid data layout descriptor: " + messageStream.str());
2542   return failure();
2543 }
2544 
2545 /// Verify LLVM dialect attributes.
2546 LogicalResult LLVMDialect::verifyOperationAttribute(Operation *op,
2547                                                     NamedAttribute attr) {
2548   // If the `llvm.loop` attribute is present, enforce the following structure,
2549   // which the module translation can assume.
2550   if (attr.getName() == LLVMDialect::getLoopAttrName()) {
2551     auto loopAttr = attr.getValue().dyn_cast<DictionaryAttr>();
2552     if (!loopAttr)
2553       return op->emitOpError() << "expected '" << LLVMDialect::getLoopAttrName()
2554                                << "' to be a dictionary attribute";
2555     Optional<NamedAttribute> parallelAccessGroup =
2556         loopAttr.getNamed(LLVMDialect::getParallelAccessAttrName());
2557     if (parallelAccessGroup.hasValue()) {
2558       auto accessGroups = parallelAccessGroup->getValue().dyn_cast<ArrayAttr>();
2559       if (!accessGroups)
2560         return op->emitOpError()
2561                << "expected '" << LLVMDialect::getParallelAccessAttrName()
2562                << "' to be an array attribute";
2563       for (Attribute attr : accessGroups) {
2564         auto accessGroupRef = attr.dyn_cast<SymbolRefAttr>();
2565         if (!accessGroupRef)
2566           return op->emitOpError()
2567                  << "expected '" << attr << "' to be a symbol reference";
2568         StringAttr metadataName = accessGroupRef.getRootReference();
2569         auto metadataOp =
2570             SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>(
2571                 op->getParentOp(), metadataName);
2572         if (!metadataOp)
2573           return op->emitOpError()
2574                  << "expected '" << attr << "' to reference a metadata op";
2575         StringAttr accessGroupName = accessGroupRef.getLeafReference();
2576         Operation *accessGroupOp =
2577             SymbolTable::lookupNearestSymbolFrom(metadataOp, accessGroupName);
2578         if (!accessGroupOp)
2579           return op->emitOpError()
2580                  << "expected '" << attr << "' to reference an access_group op";
2581       }
2582     }
2583 
2584     Optional<NamedAttribute> loopOptions =
2585         loopAttr.getNamed(LLVMDialect::getLoopOptionsAttrName());
2586     if (loopOptions.hasValue() &&
2587         !loopOptions->getValue().isa<LoopOptionsAttr>())
2588       return op->emitOpError()
2589              << "expected '" << LLVMDialect::getLoopOptionsAttrName()
2590              << "' to be a `loopopts` attribute";
2591   }
2592 
2593   if (attr.getName() == LLVMDialect::getStructAttrsAttrName()) {
2594     return op->emitOpError()
2595            << "'" << LLVM::LLVMDialect::getStructAttrsAttrName()
2596            << "' is permitted only in argument or result attributes";
2597   }
2598 
2599   // If the data layout attribute is present, it must use the LLVM data layout
2600   // syntax. Try parsing it and report errors in case of failure. Users of this
2601   // attribute may assume it is well-formed and can pass it to the (asserting)
2602   // llvm::DataLayout constructor.
2603   if (attr.getName() != LLVM::LLVMDialect::getDataLayoutAttrName())
2604     return success();
2605   if (auto stringAttr = attr.getValue().dyn_cast<StringAttr>())
2606     return verifyDataLayoutString(
2607         stringAttr.getValue(),
2608         [op](const Twine &message) { op->emitOpError() << message.str(); });
2609 
2610   return op->emitOpError() << "expected '"
2611                            << LLVM::LLVMDialect::getDataLayoutAttrName()
2612                            << "' to be a string attribute";
2613 }
2614 
2615 LogicalResult LLVMDialect::verifyStructAttr(Operation *op, Attribute attr,
2616                                             Type annotatedType) {
2617   auto structType = annotatedType.dyn_cast<LLVMStructType>();
2618   if (!structType) {
2619     const auto emitIncorrectAnnotatedType = [&op]() {
2620       return op->emitError()
2621              << "expected '" << LLVMDialect::getStructAttrsAttrName()
2622              << "' to annotate '!llvm.struct' or '!llvm.ptr<struct<...>>'";
2623     };
2624     const auto ptrType = annotatedType.dyn_cast<LLVMPointerType>();
2625     if (!ptrType)
2626       return emitIncorrectAnnotatedType();
2627     structType = ptrType.getElementType().dyn_cast<LLVMStructType>();
2628     if (!structType)
2629       return emitIncorrectAnnotatedType();
2630   }
2631 
2632   const auto arrAttrs = attr.dyn_cast<ArrayAttr>();
2633   if (!arrAttrs)
2634     return op->emitError() << "expected '"
2635                            << LLVMDialect::getStructAttrsAttrName()
2636                            << "' to be an array attribute";
2637 
2638   if (structType.getBody().size() != arrAttrs.size())
2639     return op->emitError()
2640            << "size of '" << LLVMDialect::getStructAttrsAttrName()
2641            << "' must match the size of the annotated '!llvm.struct'";
2642   return success();
2643 }
2644 
2645 static LogicalResult verifyFuncOpInterfaceStructAttr(
2646     Operation *op, Attribute attr,
2647     std::function<Type(FunctionOpInterface)> getAnnotatedType) {
2648   if (auto funcOp = dyn_cast<FunctionOpInterface>(op))
2649     return LLVMDialect::verifyStructAttr(op, attr, getAnnotatedType(funcOp));
2650   return op->emitError() << "expected '"
2651                          << LLVMDialect::getStructAttrsAttrName()
2652                          << "' to be used on function-like operations";
2653 }
2654 
2655 /// Verify LLVMIR function argument attributes.
2656 LogicalResult LLVMDialect::verifyRegionArgAttribute(Operation *op,
2657                                                     unsigned regionIdx,
2658                                                     unsigned argIdx,
2659                                                     NamedAttribute argAttr) {
2660   // Check that llvm.noalias is a unit attribute.
2661   if (argAttr.getName() == LLVMDialect::getNoAliasAttrName() &&
2662       !argAttr.getValue().isa<UnitAttr>())
2663     return op->emitError()
2664            << "expected llvm.noalias argument attribute to be a unit attribute";
2665   // Check that llvm.align is an integer attribute.
2666   if (argAttr.getName() == LLVMDialect::getAlignAttrName() &&
2667       !argAttr.getValue().isa<IntegerAttr>())
2668     return op->emitError()
2669            << "llvm.align argument attribute of non integer type";
2670   if (argAttr.getName() == LLVMDialect::getStructAttrsAttrName()) {
2671     return verifyFuncOpInterfaceStructAttr(
2672         op, argAttr.getValue(), [argIdx](FunctionOpInterface funcOp) {
2673           return funcOp.getArgumentTypes()[argIdx];
2674         });
2675   }
2676   return success();
2677 }
2678 
2679 LogicalResult LLVMDialect::verifyRegionResultAttribute(Operation *op,
2680                                                        unsigned regionIdx,
2681                                                        unsigned resIdx,
2682                                                        NamedAttribute resAttr) {
2683   if (resAttr.getName() == LLVMDialect::getStructAttrsAttrName()) {
2684     return verifyFuncOpInterfaceStructAttr(
2685         op, resAttr.getValue(), [resIdx](FunctionOpInterface funcOp) {
2686           return funcOp.getResultTypes()[resIdx];
2687         });
2688   }
2689   return success();
2690 }
2691 
2692 //===----------------------------------------------------------------------===//
2693 // Utility functions.
2694 //===----------------------------------------------------------------------===//
2695 
2696 Value mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder,
2697                                      StringRef name, StringRef value,
2698                                      LLVM::Linkage linkage) {
2699   assert(builder.getInsertionBlock() &&
2700          builder.getInsertionBlock()->getParentOp() &&
2701          "expected builder to point to a block constrained in an op");
2702   auto module =
2703       builder.getInsertionBlock()->getParentOp()->getParentOfType<ModuleOp>();
2704   assert(module && "builder points to an op outside of a module");
2705 
2706   // Create the global at the entry of the module.
2707   OpBuilder moduleBuilder(module.getBodyRegion(), builder.getListener());
2708   MLIRContext *ctx = builder.getContext();
2709   auto type = LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), value.size());
2710   auto global = moduleBuilder.create<LLVM::GlobalOp>(
2711       loc, type, /*isConstant=*/true, linkage, name,
2712       builder.getStringAttr(value), /*alignment=*/0);
2713 
2714   // Get the pointer to the first character in the global string.
2715   Value globalPtr = builder.create<LLVM::AddressOfOp>(loc, global);
2716   Value cst0 = builder.create<LLVM::ConstantOp>(
2717       loc, IntegerType::get(ctx, 64),
2718       builder.getIntegerAttr(builder.getIndexType(), 0));
2719   return builder.create<LLVM::GEPOp>(
2720       loc, LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)), globalPtr,
2721       ValueRange{cst0, cst0});
2722 }
2723 
2724 bool mlir::LLVM::satisfiesLLVMModule(Operation *op) {
2725   return op->hasTrait<OpTrait::SymbolTable>() &&
2726          op->hasTrait<OpTrait::IsIsolatedFromAbove>();
2727 }
2728 
// Fastmath flags in the order FMFAttr::print emits them; each entry is printed
// when bitEnumContains reports it as set in the attribute's flags.
// NOTE(review): `fast` is listed alongside the individual flags; if it is the
// union of the others, a fully-fast value prints every flag plus `fast` —
// confirm against the enum definition.
static constexpr const FastmathFlags fastmathFlagsList[] = {
    // clang-format off
    FastmathFlags::nnan,
    FastmathFlags::ninf,
    FastmathFlags::nsz,
    FastmathFlags::arcp,
    FastmathFlags::contract,
    FastmathFlags::afn,
    FastmathFlags::reassoc,
    FastmathFlags::fast,
    // clang-format on
};
2741 
2742 void FMFAttr::print(AsmPrinter &printer) const {
2743   printer << "<";
2744   auto flags = llvm::make_filter_range(fastmathFlagsList, [&](auto flag) {
2745     return bitEnumContains(this->getFlags(), flag);
2746   });
2747   llvm::interleaveComma(flags, printer,
2748                         [&](auto flag) { printer << stringifyEnum(flag); });
2749   printer << ">";
2750 }
2751 
2752 Attribute FMFAttr::parse(AsmParser &parser, Type type) {
2753   if (failed(parser.parseLess()))
2754     return {};
2755 
2756   FastmathFlags flags = {};
2757   if (failed(parser.parseOptionalGreater())) {
2758     do {
2759       StringRef elemName;
2760       if (failed(parser.parseKeyword(&elemName)))
2761         return {};
2762 
2763       auto elem = symbolizeFastmathFlags(elemName);
2764       if (!elem) {
2765         parser.emitError(parser.getNameLoc(), "Unknown fastmath flag: ")
2766             << elemName;
2767         return {};
2768       }
2769 
2770       flags = flags | *elem;
2771     } while (succeeded(parser.parseOptionalComma()));
2772 
2773     if (failed(parser.parseGreater()))
2774       return {};
2775   }
2776 
2777   return FMFAttr::get(parser.getContext(), flags);
2778 }
2779 
2780 void LinkageAttr::print(AsmPrinter &printer) const {
2781   printer << "<";
2782   if (static_cast<uint64_t>(getLinkage()) <= getMaxEnumValForLinkage())
2783     printer << stringifyEnum(getLinkage());
2784   else
2785     printer << static_cast<uint64_t>(getLinkage());
2786   printer << ">";
2787 }
2788 
2789 Attribute LinkageAttr::parse(AsmParser &parser, Type type) {
2790   StringRef elemName;
2791   if (parser.parseLess() || parser.parseKeyword(&elemName) ||
2792       parser.parseGreater())
2793     return {};
2794   auto elem = linkage::symbolizeLinkage(elemName);
2795   if (!elem) {
2796     parser.emitError(parser.getNameLoc(), "Unknown linkage: ") << elemName;
2797     return {};
2798   }
2799   Linkage linkage = *elem;
2800   return LinkageAttr::get(parser.getContext(), linkage);
2801 }
2802 
2803 LoopOptionsAttrBuilder::LoopOptionsAttrBuilder(LoopOptionsAttr attr)
2804     : options(attr.getOptions().begin(), attr.getOptions().end()) {}
2805 
2806 template <typename T>
2807 LoopOptionsAttrBuilder &LoopOptionsAttrBuilder::setOption(LoopOptionCase tag,
2808                                                           Optional<T> value) {
2809   auto option = llvm::find_if(
2810       options, [tag](auto option) { return option.first == tag; });
2811   if (option != options.end()) {
2812     if (value.hasValue())
2813       option->second = *value;
2814     else
2815       options.erase(option);
2816   } else {
2817     options.push_back(LoopOptionsAttr::OptionValuePair(tag, *value));
2818   }
2819   return *this;
2820 }
2821 
2822 LoopOptionsAttrBuilder &
2823 LoopOptionsAttrBuilder::setDisableLICM(Optional<bool> value) {
2824   return setOption(LoopOptionCase::disable_licm, value);
2825 }
2826 
2827 /// Set the `interleave_count` option to the provided value. If no value
2828 /// is provided the option is deleted.
2829 LoopOptionsAttrBuilder &
2830 LoopOptionsAttrBuilder::setInterleaveCount(Optional<uint64_t> count) {
2831   return setOption(LoopOptionCase::interleave_count, count);
2832 }
2833 
2834 /// Set the `disable_unroll` option to the provided value. If no value
2835 /// is provided the option is deleted.
2836 LoopOptionsAttrBuilder &
2837 LoopOptionsAttrBuilder::setDisableUnroll(Optional<bool> value) {
2838   return setOption(LoopOptionCase::disable_unroll, value);
2839 }
2840 
2841 /// Set the `disable_pipeline` option to the provided value. If no value
2842 /// is provided the option is deleted.
2843 LoopOptionsAttrBuilder &
2844 LoopOptionsAttrBuilder::setDisablePipeline(Optional<bool> value) {
2845   return setOption(LoopOptionCase::disable_pipeline, value);
2846 }
2847 
2848 /// Set the `pipeline_initiation_interval` option to the provided value.
2849 /// If no value is provided the option is deleted.
2850 LoopOptionsAttrBuilder &LoopOptionsAttrBuilder::setPipelineInitiationInterval(
2851     Optional<uint64_t> count) {
2852   return setOption(LoopOptionCase::pipeline_initiation_interval, count);
2853 }
2854 
2855 template <typename T>
2856 static Optional<T>
2857 getOption(ArrayRef<std::pair<LoopOptionCase, int64_t>> options,
2858           LoopOptionCase option) {
2859   auto it =
2860       lower_bound(options, option, [](auto optionPair, LoopOptionCase option) {
2861         return optionPair.first < option;
2862       });
2863   if (it == options.end())
2864     return {};
2865   return static_cast<T>(it->second);
2866 }
2867 
2868 Optional<bool> LoopOptionsAttr::disableUnroll() {
2869   return getOption<bool>(getOptions(), LoopOptionCase::disable_unroll);
2870 }
2871 
2872 Optional<bool> LoopOptionsAttr::disableLICM() {
2873   return getOption<bool>(getOptions(), LoopOptionCase::disable_licm);
2874 }
2875 
2876 Optional<int64_t> LoopOptionsAttr::interleaveCount() {
2877   return getOption<int64_t>(getOptions(), LoopOptionCase::interleave_count);
2878 }
2879 
/// Build the LoopOptions Attribute from a sorted array of individual options.
/// `sortedOptions` must already be ordered by option kind (compared on the
/// pair's first element); this is only checked by the assert below, i.e. in
/// builds with assertions enabled. The binary-search lookup in getOption
/// relies on this ordering.
LoopOptionsAttr LoopOptionsAttr::get(
    MLIRContext *context,
    ArrayRef<std::pair<LoopOptionCase, int64_t>> sortedOptions) {
  assert(llvm::is_sorted(sortedOptions, llvm::less_first()) &&
         "LoopOptionsAttr ctor expects a sorted options array");
  return Base::get(context, sortedOptions);
}
2888 
2889 /// Build the LoopOptions Attribute from a sorted array of individual options.
2890 LoopOptionsAttr LoopOptionsAttr::get(MLIRContext *context,
2891                                      LoopOptionsAttrBuilder &optionBuilders) {
2892   llvm::sort(optionBuilders.options, llvm::less_first());
2893   return Base::get(context, optionBuilders.options);
2894 }
2895 
2896 void LoopOptionsAttr::print(AsmPrinter &printer) const {
2897   printer << "<";
2898   llvm::interleaveComma(getOptions(), printer, [&](auto option) {
2899     printer << stringifyEnum(option.first) << " = ";
2900     switch (option.first) {
2901     case LoopOptionCase::disable_licm:
2902     case LoopOptionCase::disable_unroll:
2903     case LoopOptionCase::disable_pipeline:
2904       printer << (option.second ? "true" : "false");
2905       break;
2906     case LoopOptionCase::interleave_count:
2907     case LoopOptionCase::pipeline_initiation_interval:
2908       printer << option.second;
2909       break;
2910     }
2911   });
2912   printer << ">";
2913 }
2914 
2915 Attribute LoopOptionsAttr::parse(AsmParser &parser, Type type) {
2916   if (failed(parser.parseLess()))
2917     return {};
2918 
2919   SmallVector<std::pair<LoopOptionCase, int64_t>> options;
2920   llvm::SmallDenseSet<LoopOptionCase> seenOptions;
2921   do {
2922     StringRef optionName;
2923     if (parser.parseKeyword(&optionName))
2924       return {};
2925 
2926     auto option = symbolizeLoopOptionCase(optionName);
2927     if (!option) {
2928       parser.emitError(parser.getNameLoc(), "unknown loop option: ")
2929           << optionName;
2930       return {};
2931     }
2932     if (!seenOptions.insert(*option).second) {
2933       parser.emitError(parser.getNameLoc(), "loop option present twice");
2934       return {};
2935     }
2936     if (failed(parser.parseEqual()))
2937       return {};
2938 
2939     int64_t value;
2940     switch (*option) {
2941     case LoopOptionCase::disable_licm:
2942     case LoopOptionCase::disable_unroll:
2943     case LoopOptionCase::disable_pipeline:
2944       if (succeeded(parser.parseOptionalKeyword("true")))
2945         value = 1;
2946       else if (succeeded(parser.parseOptionalKeyword("false")))
2947         value = 0;
2948       else {
2949         parser.emitError(parser.getNameLoc(),
2950                          "expected boolean value 'true' or 'false'");
2951         return {};
2952       }
2953       break;
2954     case LoopOptionCase::interleave_count:
2955     case LoopOptionCase::pipeline_initiation_interval:
2956       if (failed(parser.parseInteger(value))) {
2957         parser.emitError(parser.getNameLoc(), "expected integer value");
2958         return {};
2959       }
2960       break;
2961     }
2962     options.push_back(std::make_pair(*option, value));
2963   } while (succeeded(parser.parseOptionalComma()));
2964   if (failed(parser.parseGreater()))
2965     return {};
2966 
2967   llvm::sort(options, llvm::less_first());
2968   return get(parser.getContext(), options);
2969 }
2970