1 //===- LLVMDialect.cpp - LLVM IR Ops and Dialect registration -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the types and operation details for the LLVM IR dialect in
10 // MLIR, and the LLVM IR dialect.  It also registers the dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14 #include "TypeDetail.h"
15 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/DialectImplementation.h"
20 #include "mlir/IR/FunctionImplementation.h"
21 #include "mlir/IR/MLIRContext.h"
22 #include "mlir/IR/Matchers.h"
23 
24 #include "llvm/ADT/StringSwitch.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 #include "llvm/AsmParser/Parser.h"
27 #include "llvm/Bitcode/BitcodeReader.h"
28 #include "llvm/Bitcode/BitcodeWriter.h"
29 #include "llvm/IR/Attributes.h"
30 #include "llvm/IR/Function.h"
31 #include "llvm/IR/Type.h"
32 #include "llvm/Support/Mutex.h"
33 #include "llvm/Support/SourceMgr.h"
34 
35 #include <numeric>
36 
37 using namespace mlir;
38 using namespace mlir::LLVM;
39 using mlir::LLVM::linkage::getMaxEnumValForLinkage;
40 
41 #include "mlir/Dialect/LLVMIR/LLVMOpsDialect.cpp.inc"
42 
// Name of the unit attribute that marks a load/store as volatile. The trailing
// underscore presumably keeps the generated accessor (`getVolatile_()`) from
// clashing with the C++ keyword — NOTE(review): confirm against the ODS file.
static constexpr const char kVolatileAttrName[] = "volatile_";
// Name of the unit attribute that marks a load/store as non-temporal.
static constexpr const char kNonTemporalAttrName[] = "nontemporal";
45 
46 #include "mlir/Dialect/LLVMIR/LLVMOpsEnums.cpp.inc"
47 #include "mlir/Dialect/LLVMIR/LLVMOpsInterfaces.cpp.inc"
48 #define GET_ATTRDEF_CLASSES
49 #include "mlir/Dialect/LLVMIR/LLVMOpsAttrDefs.cpp.inc"
50 
51 static auto processFMFAttr(ArrayRef<NamedAttribute> attrs) {
52   SmallVector<NamedAttribute, 8> filteredAttrs(
53       llvm::make_filter_range(attrs, [&](NamedAttribute attr) {
54         if (attr.getName() == "fastmathFlags") {
55           auto defAttr = FMFAttr::get(attr.getValue().getContext(), {});
56           return defAttr != attr.getValue();
57         }
58         return true;
59       }));
60   return filteredAttrs;
61 }
62 
63 static ParseResult parseLLVMOpAttrs(OpAsmParser &parser,
64                                     NamedAttrList &result) {
65   return parser.parseOptionalAttrDict(result);
66 }
67 
68 static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op,
69                              DictionaryAttr attrs) {
70   printer.printOptionalAttrDict(processFMFAttr(attrs.getValue()));
71 }
72 
73 /// Verifies `symbol`'s use in `op` to ensure the symbol is a valid and
74 /// fully defined llvm.func.
75 static LogicalResult verifySymbolAttrUse(FlatSymbolRefAttr symbol,
76                                          Operation *op,
77                                          SymbolTableCollection &symbolTable) {
78   StringRef name = symbol.getValue();
79   auto func =
80       symbolTable.lookupNearestSymbolFrom<LLVMFuncOp>(op, symbol.getAttr());
81   if (!func)
82     return op->emitOpError("'")
83            << name << "' does not reference a valid LLVM function";
84   if (func.isExternal())
85     return op->emitOpError("'") << name << "' does not have a definition";
86   return success();
87 }
88 
89 //===----------------------------------------------------------------------===//
90 // Printing/parsing for LLVM::CmpOp.
91 //===----------------------------------------------------------------------===//
92 
93 void ICmpOp::print(OpAsmPrinter &p) {
94   p << " \"" << stringifyICmpPredicate(getPredicate()) << "\" " << getOperand(0)
95     << ", " << getOperand(1);
96   p.printOptionalAttrDict((*this)->getAttrs(), {"predicate"});
97   p << " : " << getLhs().getType();
98 }
99 
100 void FCmpOp::print(OpAsmPrinter &p) {
101   p << " \"" << stringifyFCmpPredicate(getPredicate()) << "\" " << getOperand(0)
102     << ", " << getOperand(1);
103   p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()), {"predicate"});
104   p << " : " << getLhs().getType();
105 }
106 
107 // <operation> ::= `llvm.icmp` string-literal ssa-use `,` ssa-use
108 //                 attribute-dict? `:` type
109 // <operation> ::= `llvm.fcmp` string-literal ssa-use `,` ssa-use
110 //                 attribute-dict? `:` type
111 template <typename CmpPredicateType>
112 static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) {
113   Builder &builder = parser.getBuilder();
114 
115   StringAttr predicateAttr;
116   OpAsmParser::OperandType lhs, rhs;
117   Type type;
118   SMLoc predicateLoc, trailingTypeLoc;
119   if (parser.getCurrentLocation(&predicateLoc) ||
120       parser.parseAttribute(predicateAttr, "predicate", result.attributes) ||
121       parser.parseOperand(lhs) || parser.parseComma() ||
122       parser.parseOperand(rhs) ||
123       parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
124       parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) ||
125       parser.resolveOperand(lhs, type, result.operands) ||
126       parser.resolveOperand(rhs, type, result.operands))
127     return failure();
128 
129   // Replace the string attribute `predicate` with an integer attribute.
130   int64_t predicateValue = 0;
131   if (std::is_same<CmpPredicateType, ICmpPredicate>()) {
132     Optional<ICmpPredicate> predicate =
133         symbolizeICmpPredicate(predicateAttr.getValue());
134     if (!predicate)
135       return parser.emitError(predicateLoc)
136              << "'" << predicateAttr.getValue()
137              << "' is an incorrect value of the 'predicate' attribute";
138     predicateValue = static_cast<int64_t>(predicate.getValue());
139   } else {
140     Optional<FCmpPredicate> predicate =
141         symbolizeFCmpPredicate(predicateAttr.getValue());
142     if (!predicate)
143       return parser.emitError(predicateLoc)
144              << "'" << predicateAttr.getValue()
145              << "' is an incorrect value of the 'predicate' attribute";
146     predicateValue = static_cast<int64_t>(predicate.getValue());
147   }
148 
149   result.attributes.set("predicate",
150                         parser.getBuilder().getI64IntegerAttr(predicateValue));
151 
152   // The result type is either i1 or a vector type <? x i1> if the inputs are
153   // vectors.
154   Type resultType = IntegerType::get(builder.getContext(), 1);
155   if (!isCompatibleType(type))
156     return parser.emitError(trailingTypeLoc,
157                             "expected LLVM dialect-compatible type");
158   if (LLVM::isCompatibleVectorType(type)) {
159     if (LLVM::isScalableVectorType(type)) {
160       resultType = LLVM::getVectorType(
161           resultType, LLVM::getVectorNumElements(type).getKnownMinValue(),
162           /*isScalable=*/true);
163     } else {
164       resultType = LLVM::getVectorType(
165           resultType, LLVM::getVectorNumElements(type).getFixedValue(),
166           /*isScalable=*/false);
167     }
168   }
169 
170   result.addTypes({resultType});
171   return success();
172 }
173 
174 ParseResult ICmpOp::parse(OpAsmParser &parser, OperationState &result) {
175   return parseCmpOp<ICmpPredicate>(parser, result);
176 }
177 
178 ParseResult FCmpOp::parse(OpAsmParser &parser, OperationState &result) {
179   return parseCmpOp<FCmpPredicate>(parser, result);
180 }
181 
182 //===----------------------------------------------------------------------===//
183 // Printing/parsing for LLVM::AllocaOp.
184 //===----------------------------------------------------------------------===//
185 
186 void AllocaOp::print(OpAsmPrinter &p) {
187   auto elemTy = getType().cast<LLVM::LLVMPointerType>().getElementType();
188 
189   auto funcTy =
190       FunctionType::get(getContext(), {getArraySize().getType()}, {getType()});
191 
192   p << ' ' << getArraySize() << " x " << elemTy;
193   if (getAlignment().hasValue() && *getAlignment() != 0)
194     p.printOptionalAttrDict((*this)->getAttrs());
195   else
196     p.printOptionalAttrDict((*this)->getAttrs(), {"alignment"});
197   p << " : " << funcTy;
198 }
199 
200 // <operation> ::= `llvm.alloca` ssa-use `x` type attribute-dict?
201 //                 `:` type `,` type
202 ParseResult AllocaOp::parse(OpAsmParser &parser, OperationState &result) {
203   OpAsmParser::OperandType arraySize;
204   Type type, elemType;
205   SMLoc trailingTypeLoc;
206   if (parser.parseOperand(arraySize) || parser.parseKeyword("x") ||
207       parser.parseType(elemType) ||
208       parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
209       parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type))
210     return failure();
211 
212   Optional<NamedAttribute> alignmentAttr =
213       result.attributes.getNamed("alignment");
214   if (alignmentAttr.hasValue()) {
215     auto alignmentInt =
216         alignmentAttr.getValue().getValue().dyn_cast<IntegerAttr>();
217     if (!alignmentInt)
218       return parser.emitError(parser.getNameLoc(),
219                               "expected integer alignment");
220     if (alignmentInt.getValue().isNullValue())
221       result.attributes.erase("alignment");
222   }
223 
224   // Extract the result type from the trailing function type.
225   auto funcType = type.dyn_cast<FunctionType>();
226   if (!funcType || funcType.getNumInputs() != 1 ||
227       funcType.getNumResults() != 1)
228     return parser.emitError(
229         trailingTypeLoc,
230         "expected trailing function type with one argument and one result");
231 
232   if (parser.resolveOperand(arraySize, funcType.getInput(0), result.operands))
233     return failure();
234 
235   result.addTypes({funcType.getResult(0)});
236   return success();
237 }
238 
239 //===----------------------------------------------------------------------===//
240 // LLVM::BrOp
241 //===----------------------------------------------------------------------===//
242 
243 Optional<MutableOperandRange>
244 BrOp::getMutableSuccessorOperands(unsigned index) {
245   assert(index == 0 && "invalid successor index");
246   return getDestOperandsMutable();
247 }
248 
249 //===----------------------------------------------------------------------===//
250 // LLVM::CondBrOp
251 //===----------------------------------------------------------------------===//
252 
253 Optional<MutableOperandRange>
254 CondBrOp::getMutableSuccessorOperands(unsigned index) {
255   assert(index < getNumSuccessors() && "invalid successor index");
256   return index == 0 ? getTrueDestOperandsMutable()
257                     : getFalseDestOperandsMutable();
258 }
259 
260 //===----------------------------------------------------------------------===//
261 // LLVM::SwitchOp
262 //===----------------------------------------------------------------------===//
263 
264 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
265                      Block *defaultDestination, ValueRange defaultOperands,
266                      ArrayRef<int32_t> caseValues, BlockRange caseDestinations,
267                      ArrayRef<ValueRange> caseOperands,
268                      ArrayRef<int32_t> branchWeights) {
269   ElementsAttr caseValuesAttr;
270   if (!caseValues.empty())
271     caseValuesAttr = builder.getI32VectorAttr(caseValues);
272 
273   ElementsAttr weightsAttr;
274   if (!branchWeights.empty())
275     weightsAttr = builder.getI32VectorAttr(llvm::to_vector<4>(branchWeights));
276 
277   build(builder, result, value, defaultOperands, caseOperands, caseValuesAttr,
278         weightsAttr, defaultDestination, caseDestinations);
279 }
280 
281 /// <cases> ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)?
282 ///             ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )?
283 static ParseResult parseSwitchOpCases(
284     OpAsmParser &parser, Type flagType, ElementsAttr &caseValues,
285     SmallVectorImpl<Block *> &caseDestinations,
286     SmallVectorImpl<SmallVector<OpAsmParser::OperandType>> &caseOperands,
287     SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
288   SmallVector<APInt> values;
289   unsigned bitWidth = flagType.getIntOrFloatBitWidth();
290   do {
291     int64_t value = 0;
292     OptionalParseResult integerParseResult = parser.parseOptionalInteger(value);
293     if (values.empty() && !integerParseResult.hasValue())
294       return success();
295 
296     if (!integerParseResult.hasValue() || integerParseResult.getValue())
297       return failure();
298     values.push_back(APInt(bitWidth, value));
299 
300     Block *destination;
301     SmallVector<OpAsmParser::OperandType> operands;
302     SmallVector<Type> operandTypes;
303     if (parser.parseColon() || parser.parseSuccessor(destination))
304       return failure();
305     if (!parser.parseOptionalLParen()) {
306       if (parser.parseRegionArgumentList(operands) ||
307           parser.parseColonTypeList(operandTypes) || parser.parseRParen())
308         return failure();
309     }
310     caseDestinations.push_back(destination);
311     caseOperands.emplace_back(operands);
312     caseOperandTypes.emplace_back(operandTypes);
313   } while (!parser.parseOptionalComma());
314 
315   ShapedType caseValueType =
316       VectorType::get(static_cast<int64_t>(values.size()), flagType);
317   caseValues = DenseIntElementsAttr::get(caseValueType, values);
318   return success();
319 }
320 
321 static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type flagType,
322                                ElementsAttr caseValues,
323                                SuccessorRange caseDestinations,
324                                OperandRangeRange caseOperands,
325                                const TypeRangeRange &caseOperandTypes) {
326   if (!caseValues)
327     return;
328 
329   size_t index = 0;
330   llvm::interleave(
331       llvm::zip(caseValues.cast<DenseIntElementsAttr>(), caseDestinations),
332       [&](auto i) {
333         p << "  ";
334         p << std::get<0>(i).getLimitedValue();
335         p << ": ";
336         p.printSuccessorAndUseList(std::get<1>(i), caseOperands[index++]);
337       },
338       [&] {
339         p << ',';
340         p.printNewline();
341       });
342   p.printNewline();
343 }
344 
345 LogicalResult SwitchOp::verify() {
346   if ((!getCaseValues() && !getCaseDestinations().empty()) ||
347       (getCaseValues() &&
348        getCaseValues()->size() !=
349            static_cast<int64_t>(getCaseDestinations().size())))
350     return emitOpError("expects number of case values to match number of "
351                        "case destinations");
352   if (getBranchWeights() && getBranchWeights()->size() != getNumSuccessors())
353     return emitError("expects number of branch weights to match number of "
354                      "successors: ")
355            << getBranchWeights()->size() << " vs " << getNumSuccessors();
356   return success();
357 }
358 
359 Optional<MutableOperandRange>
360 SwitchOp::getMutableSuccessorOperands(unsigned index) {
361   assert(index < getNumSuccessors() && "invalid successor index");
362   return index == 0 ? getDefaultOperandsMutable()
363                     : getCaseOperandsMutable(index - 1);
364 }
365 
366 //===----------------------------------------------------------------------===//
367 // Code for LLVM::GEPOp.
368 //===----------------------------------------------------------------------===//
369 
// Out-of-line definition of the static constexpr member. Before C++17 inline
// variables this is required because the member is ODR-used, e.g. when it is
// bound to the `const T &` parameter of std::count in GEPOp::build.
constexpr int GEPOp::kDynamicIndex;
371 
372 /// Populates `indices` with positions of GEP indices that would correspond to
373 /// LLVMStructTypes potentially nested in the given type. The type currently
374 /// visited gets `currentIndex` and LLVM container types are visited
375 /// recursively. The recursion is bounded and takes care of recursive types by
376 /// means of the `visited` set.
377 static void recordStructIndices(Type type, unsigned currentIndex,
378                                 SmallVectorImpl<unsigned> &indices,
379                                 SmallVectorImpl<unsigned> *structSizes,
380                                 SmallPtrSet<Type, 4> &visited) {
381   if (visited.contains(type))
382     return;
383 
384   visited.insert(type);
385 
386   llvm::TypeSwitch<Type>(type)
387       .Case<LLVMStructType>([&](LLVMStructType structType) {
388         indices.push_back(currentIndex);
389         if (structSizes)
390           structSizes->push_back(structType.getBody().size());
391         for (Type elementType : structType.getBody())
392           recordStructIndices(elementType, currentIndex + 1, indices,
393                               structSizes, visited);
394       })
395       .Case<VectorType, LLVMScalableVectorType, LLVMFixedVectorType,
396             LLVMArrayType>([&](auto containerType) {
397         recordStructIndices(containerType.getElementType(), currentIndex + 1,
398                             indices, structSizes, visited);
399       });
400 }
401 
402 /// Populates `indices` with positions of GEP indices that correspond to
403 /// LLVMStructTypes potentially nested in the given `baseGEPType`, which must
404 /// be either an LLVMPointer type or a vector thereof. If `structSizes` is
405 /// provided, it is populated with sizes of the indexed structs for bounds
406 /// verification purposes.
407 static void
408 findKnownStructIndices(Type baseGEPType, SmallVectorImpl<unsigned> &indices,
409                        SmallVectorImpl<unsigned> *structSizes = nullptr) {
410   Type type = baseGEPType;
411   if (auto vectorType = type.dyn_cast<VectorType>())
412     type = vectorType.getElementType();
413   if (auto scalableVectorType = type.dyn_cast<LLVMScalableVectorType>())
414     type = scalableVectorType.getElementType();
415   if (auto fixedVectorType = type.dyn_cast<LLVMFixedVectorType>())
416     type = fixedVectorType.getElementType();
417 
418   Type pointeeType = type.cast<LLVMPointerType>().getElementType();
419   SmallPtrSet<Type, 4> visited;
420   recordStructIndices(pointeeType, /*currentIndex=*/1, indices, structSizes,
421                       visited);
422 }
423 
424 void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
425                   Value basePtr, ValueRange operands,
426                   ArrayRef<NamedAttribute> attributes) {
427   build(builder, result, resultType, basePtr, operands,
428         SmallVector<int32_t>(operands.size(), LLVM::GEPOp::kDynamicIndex),
429         attributes);
430 }
431 
432 void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
433                   Value basePtr, ValueRange indices,
434                   ArrayRef<int32_t> structIndices,
435                   ArrayRef<NamedAttribute> attributes) {
436   SmallVector<Value> remainingIndices;
437   SmallVector<int32_t> updatedStructIndices(structIndices.begin(),
438                                             structIndices.end());
439   SmallVector<unsigned> structRelatedPositions;
440   findKnownStructIndices(basePtr.getType(), structRelatedPositions);
441 
442   SmallVector<unsigned> operandsToErase;
443   for (unsigned pos : structRelatedPositions) {
444     // GEP may not be indexing as deep as some structs are located.
445     if (pos >= structIndices.size())
446       continue;
447 
448     // If the index is already static, it's fine.
449     if (structIndices[pos] != kDynamicIndex)
450       continue;
451 
452     // Find the corresponding operand.
453     unsigned operandPos =
454         std::count(structIndices.begin(), std::next(structIndices.begin(), pos),
455                    kDynamicIndex);
456 
457     // Extract the constant value from the operand and put it into the attribute
458     // instead.
459     APInt staticIndexValue;
460     bool matched =
461         matchPattern(indices[operandPos], m_ConstantInt(&staticIndexValue));
462     (void)matched;
463     assert(matched && "index into a struct must be a constant");
464     assert(staticIndexValue.sge(APInt::getSignedMinValue(/*numBits=*/32)) &&
465            "struct index underflows 32-bit integer");
466     assert(staticIndexValue.sle(APInt::getSignedMaxValue(/*numBits=*/32)) &&
467            "struct index overflows 32-bit integer");
468     auto staticIndex = static_cast<int32_t>(staticIndexValue.getSExtValue());
469     updatedStructIndices[pos] = staticIndex;
470     operandsToErase.push_back(operandPos);
471   }
472 
473   for (unsigned i = 0, e = indices.size(); i < e; ++i) {
474     if (!llvm::is_contained(operandsToErase, i))
475       remainingIndices.push_back(indices[i]);
476   }
477 
478   assert(remainingIndices.size() == static_cast<size_t>(llvm::count(
479                                         updatedStructIndices, kDynamicIndex)) &&
480          "expected as many index operands as dynamic index attr elements");
481 
482   result.addTypes(resultType);
483   result.addAttributes(attributes);
484   result.addAttribute("structIndices",
485                       builder.getI32TensorAttr(updatedStructIndices));
486   result.addOperands(basePtr);
487   result.addOperands(remainingIndices);
488 }
489 
490 static ParseResult
491 parseGEPIndices(OpAsmParser &parser,
492                 SmallVectorImpl<OpAsmParser::OperandType> &indices,
493                 DenseIntElementsAttr &structIndices) {
494   SmallVector<int32_t> constantIndices;
495   do {
496     int32_t constantIndex;
497     OptionalParseResult parsedInteger =
498         parser.parseOptionalInteger(constantIndex);
499     if (parsedInteger.hasValue()) {
500       if (failed(parsedInteger.getValue()))
501         return failure();
502       constantIndices.push_back(constantIndex);
503       continue;
504     }
505 
506     constantIndices.push_back(LLVM::GEPOp::kDynamicIndex);
507     if (failed(parser.parseOperand(indices.emplace_back())))
508       return failure();
509   } while (succeeded(parser.parseOptionalComma()));
510 
511   structIndices = parser.getBuilder().getI32TensorAttr(constantIndices);
512   return success();
513 }
514 
515 static void printGEPIndices(OpAsmPrinter &printer, LLVM::GEPOp gepOp,
516                             OperandRange indices,
517                             DenseIntElementsAttr structIndices) {
518   unsigned operandIdx = 0;
519   llvm::interleaveComma(structIndices.getValues<int32_t>(), printer,
520                         [&](int32_t cst) {
521                           if (cst == LLVM::GEPOp::kDynamicIndex)
522                             printer.printOperand(indices[operandIdx++]);
523                           else
524                             printer << cst;
525                         });
526 }
527 
528 LogicalResult LLVM::GEPOp::verify() {
529   SmallVector<unsigned> indices;
530   SmallVector<unsigned> structSizes;
531   findKnownStructIndices(getBase().getType(), indices, &structSizes);
532   DenseIntElementsAttr structIndices = getStructIndices();
533   for (unsigned i : llvm::seq<unsigned>(0, indices.size())) {
534     unsigned index = indices[i];
535     // GEP may not be indexing as deep as some structs nested in the type.
536     if (index >= structIndices.getNumElements())
537       continue;
538 
539     int32_t staticIndex = structIndices.getValues<int32_t>()[index];
540     if (staticIndex == LLVM::GEPOp::kDynamicIndex)
541       return emitOpError() << "expected index " << index
542                            << " indexing a struct to be constant";
543     if (staticIndex < 0 || static_cast<unsigned>(staticIndex) >= structSizes[i])
544       return emitOpError() << "index " << index
545                            << " indexing a struct is out of bounds";
546   }
547   return success();
548 }
549 
550 //===----------------------------------------------------------------------===//
// Builder, printer and parser for LLVM::LoadOp.
552 //===----------------------------------------------------------------------===//
553 
554 LogicalResult verifySymbolAttribute(
555     Operation *op, StringRef attributeName,
556     llvm::function_ref<LogicalResult(Operation *, SymbolRefAttr)>
557         verifySymbolType) {
558   if (Attribute attribute = op->getAttr(attributeName)) {
559     // The attribute is already verified to be a symbol ref array attribute via
560     // a constraint in the operation definition.
561     for (SymbolRefAttr symbolRef :
562          attribute.cast<ArrayAttr>().getAsRange<SymbolRefAttr>()) {
563       StringAttr metadataName = symbolRef.getRootReference();
564       StringAttr symbolName = symbolRef.getLeafReference();
565       // We want @metadata::@symbol, not just @symbol
566       if (metadataName == symbolName) {
567         return op->emitOpError() << "expected '" << symbolRef
568                                  << "' to specify a fully qualified reference";
569       }
570       auto metadataOp = SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>(
571           op->getParentOp(), metadataName);
572       if (!metadataOp)
573         return op->emitOpError()
574                << "expected '" << symbolRef << "' to reference a metadata op";
575       Operation *symbolOp =
576           SymbolTable::lookupNearestSymbolFrom(metadataOp, symbolName);
577       if (!symbolOp)
578         return op->emitOpError()
579                << "expected '" << symbolRef << "' to be a valid reference";
580       if (failed(verifySymbolType(symbolOp, symbolRef))) {
581         return failure();
582       }
583     }
584   }
585   return success();
586 }
587 
588 // Verifies that metadata ops are wired up properly.
589 template <typename OpTy>
590 static LogicalResult verifyOpMetadata(Operation *op, StringRef attributeName) {
591   auto verifySymbolType = [op](Operation *symbolOp,
592                                SymbolRefAttr symbolRef) -> LogicalResult {
593     if (!isa<OpTy>(symbolOp)) {
594       return op->emitOpError()
595              << "expected '" << symbolRef << "' to resolve to a "
596              << OpTy::getOperationName();
597     }
598     return success();
599   };
600 
601   return verifySymbolAttribute(op, attributeName, verifySymbolType);
602 }
603 
604 static LogicalResult verifyMemoryOpMetadata(Operation *op) {
605   // access_groups
606   if (failed(verifyOpMetadata<LLVM::AccessGroupMetadataOp>(
607           op, LLVMDialect::getAccessGroupsAttrName())))
608     return failure();
609 
610   // alias_scopes
611   if (failed(verifyOpMetadata<LLVM::AliasScopeMetadataOp>(
612           op, LLVMDialect::getAliasScopesAttrName())))
613     return failure();
614 
615   // noalias_scopes
616   if (failed(verifyOpMetadata<LLVM::AliasScopeMetadataOp>(
617           op, LLVMDialect::getNoAliasScopesAttrName())))
618     return failure();
619 
620   return success();
621 }
622 
623 LogicalResult LoadOp::verify() { return verifyMemoryOpMetadata(*this); }
624 
625 void LoadOp::build(OpBuilder &builder, OperationState &result, Type t,
626                    Value addr, unsigned alignment, bool isVolatile,
627                    bool isNonTemporal) {
628   result.addOperands(addr);
629   result.addTypes(t);
630   if (isVolatile)
631     result.addAttribute(kVolatileAttrName, builder.getUnitAttr());
632   if (isNonTemporal)
633     result.addAttribute(kNonTemporalAttrName, builder.getUnitAttr());
634   if (alignment != 0)
635     result.addAttribute("alignment", builder.getI64IntegerAttr(alignment));
636 }
637 
638 void LoadOp::print(OpAsmPrinter &p) {
639   p << ' ';
640   if (getVolatile_())
641     p << "volatile ";
642   p << getAddr();
643   p.printOptionalAttrDict((*this)->getAttrs(), {kVolatileAttrName});
644   p << " : " << getAddr().getType();
645 }
646 
647 // Extract the pointee type from the LLVM pointer type wrapped in MLIR.  Return
648 // the resulting type wrapped in MLIR, or nullptr on error.
649 static Type getLoadStoreElementType(OpAsmParser &parser, Type type,
650                                     SMLoc trailingTypeLoc) {
651   auto llvmTy = type.dyn_cast<LLVM::LLVMPointerType>();
652   if (!llvmTy)
653     return parser.emitError(trailingTypeLoc, "expected LLVM pointer type"),
654            nullptr;
655   return llvmTy.getElementType();
656 }
657 
658 // <operation> ::= `llvm.load` `volatile` ssa-use attribute-dict? `:` type
659 ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) {
660   OpAsmParser::OperandType addr;
661   Type type;
662   SMLoc trailingTypeLoc;
663 
664   if (succeeded(parser.parseOptionalKeyword("volatile")))
665     result.addAttribute(kVolatileAttrName, parser.getBuilder().getUnitAttr());
666 
667   if (parser.parseOperand(addr) ||
668       parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
669       parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) ||
670       parser.resolveOperand(addr, type, result.operands))
671     return failure();
672 
673   Type elemTy = getLoadStoreElementType(parser, type, trailingTypeLoc);
674 
675   result.addTypes(elemTy);
676   return success();
677 }
678 
679 //===----------------------------------------------------------------------===//
680 // Builder, printer and parser for LLVM::StoreOp.
681 //===----------------------------------------------------------------------===//
682 
683 LogicalResult StoreOp::verify() { return verifyMemoryOpMetadata(*this); }
684 
685 void StoreOp::build(OpBuilder &builder, OperationState &result, Value value,
686                     Value addr, unsigned alignment, bool isVolatile,
687                     bool isNonTemporal) {
688   result.addOperands({value, addr});
689   result.addTypes({});
690   if (isVolatile)
691     result.addAttribute(kVolatileAttrName, builder.getUnitAttr());
692   if (isNonTemporal)
693     result.addAttribute(kNonTemporalAttrName, builder.getUnitAttr());
694   if (alignment != 0)
695     result.addAttribute("alignment", builder.getI64IntegerAttr(alignment));
696 }
697 
698 void StoreOp::print(OpAsmPrinter &p) {
699   p << ' ';
700   if (getVolatile_())
701     p << "volatile ";
702   p << getValue() << ", " << getAddr();
703   p.printOptionalAttrDict((*this)->getAttrs(), {kVolatileAttrName});
704   p << " : " << getAddr().getType();
705 }
706 
707 // <operation> ::= `llvm.store` `volatile` ssa-use `,` ssa-use
708 //                 attribute-dict? `:` type
709 ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) {
710   OpAsmParser::OperandType addr, value;
711   Type type;
712   SMLoc trailingTypeLoc;
713 
714   if (succeeded(parser.parseOptionalKeyword("volatile")))
715     result.addAttribute(kVolatileAttrName, parser.getBuilder().getUnitAttr());
716 
717   if (parser.parseOperand(value) || parser.parseComma() ||
718       parser.parseOperand(addr) ||
719       parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
720       parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type))
721     return failure();
722 
723   Type elemTy = getLoadStoreElementType(parser, type, trailingTypeLoc);
724   if (!elemTy)
725     return failure();
726 
727   if (parser.resolveOperand(value, elemTy, result.operands) ||
728       parser.resolveOperand(addr, type, result.operands))
729     return failure();
730 
731   return success();
732 }
733 
734 ///===---------------------------------------------------------------------===//
735 /// LLVM::InvokeOp
736 ///===---------------------------------------------------------------------===//
737 
738 Optional<MutableOperandRange>
739 InvokeOp::getMutableSuccessorOperands(unsigned index) {
740   assert(index < getNumSuccessors() && "invalid successor index");
741   return index == 0 ? getNormalDestOperandsMutable()
742                     : getUnwindDestOperandsMutable();
743 }
744 
745 LogicalResult InvokeOp::verify() {
746   if (getNumResults() > 1)
747     return emitOpError("must have 0 or 1 result");
748 
749   Block *unwindDest = getUnwindDest();
750   if (unwindDest->empty())
751     return emitError("must have at least one operation in unwind destination");
752 
753   // In unwind destination, first operation must be LandingpadOp
754   if (!isa<LandingpadOp>(unwindDest->front()))
755     return emitError("first operation in unwind destination should be a "
756                      "llvm.landingpad operation");
757 
758   return success();
759 }
760 
761 void InvokeOp::print(OpAsmPrinter &p) {
762   auto callee = getCallee();
763   bool isDirect = callee.hasValue();
764 
765   p << ' ';
766 
767   // Either function name or pointer
768   if (isDirect)
769     p.printSymbolName(callee.getValue());
770   else
771     p << getOperand(0);
772 
773   p << '(' << getOperands().drop_front(isDirect ? 0 : 1) << ')';
774   p << " to ";
775   p.printSuccessorAndUseList(getNormalDest(), getNormalDestOperands());
776   p << " unwind ";
777   p.printSuccessorAndUseList(getUnwindDest(), getUnwindDestOperands());
778 
779   p.printOptionalAttrDict((*this)->getAttrs(),
780                           {InvokeOp::getOperandSegmentSizeAttr(), "callee"});
781   p << " : ";
782   p.printFunctionalType(llvm::drop_begin(getOperandTypes(), isDirect ? 0 : 1),
783                         getResultTypes());
784 }
785 
786 /// <operation> ::= `llvm.invoke` (function-id | ssa-use) `(` ssa-use-list `)`
787 ///                  `to` bb-id (`[` ssa-use-and-type-list `]`)?
788 ///                  `unwind` bb-id (`[` ssa-use-and-type-list `]`)?
789 ///                  attribute-dict? `:` function-type
790 ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
791   SmallVector<OpAsmParser::OperandType, 8> operands;
792   FunctionType funcType;
793   SymbolRefAttr funcAttr;
794   SMLoc trailingTypeLoc;
795   Block *normalDest, *unwindDest;
796   SmallVector<Value, 4> normalOperands, unwindOperands;
797   Builder &builder = parser.getBuilder();
798 
799   // Parse an operand list that will, in practice, contain 0 or 1 operand.  In
800   // case of an indirect call, there will be 1 operand before `(`.  In case of a
801   // direct call, there will be no operands and the parser will stop at the
802   // function identifier without complaining.
803   if (parser.parseOperandList(operands))
804     return failure();
805   bool isDirect = operands.empty();
806 
807   // Optionally parse a function identifier.
808   if (isDirect && parser.parseAttribute(funcAttr, "callee", result.attributes))
809     return failure();
810 
811   if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
812       parser.parseKeyword("to") ||
813       parser.parseSuccessorAndUseList(normalDest, normalOperands) ||
814       parser.parseKeyword("unwind") ||
815       parser.parseSuccessorAndUseList(unwindDest, unwindOperands) ||
816       parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
817       parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(funcType))
818     return failure();
819 
820   if (isDirect) {
821     // Make sure types match.
822     if (parser.resolveOperands(operands, funcType.getInputs(),
823                                parser.getNameLoc(), result.operands))
824       return failure();
825     result.addTypes(funcType.getResults());
826   } else {
827     // Construct the LLVM IR Dialect function type that the first operand
828     // should match.
829     if (funcType.getNumResults() > 1)
830       return parser.emitError(trailingTypeLoc,
831                               "expected function with 0 or 1 result");
832 
833     Type llvmResultType;
834     if (funcType.getNumResults() == 0) {
835       llvmResultType = LLVM::LLVMVoidType::get(builder.getContext());
836     } else {
837       llvmResultType = funcType.getResult(0);
838       if (!isCompatibleType(llvmResultType))
839         return parser.emitError(trailingTypeLoc,
840                                 "expected result to have LLVM type");
841     }
842 
843     SmallVector<Type, 8> argTypes;
844     argTypes.reserve(funcType.getNumInputs());
845     for (Type ty : funcType.getInputs()) {
846       if (isCompatibleType(ty))
847         argTypes.push_back(ty);
848       else
849         return parser.emitError(trailingTypeLoc,
850                                 "expected LLVM types as inputs");
851     }
852 
853     auto llvmFuncType = LLVM::LLVMFunctionType::get(llvmResultType, argTypes);
854     auto wrappedFuncType = LLVM::LLVMPointerType::get(llvmFuncType);
855 
856     auto funcArguments = llvm::makeArrayRef(operands).drop_front();
857 
858     // Make sure that the first operand (indirect callee) matches the wrapped
859     // LLVM IR function type, and that the types of the other call operands
860     // match the types of the function arguments.
861     if (parser.resolveOperand(operands[0], wrappedFuncType, result.operands) ||
862         parser.resolveOperands(funcArguments, funcType.getInputs(),
863                                parser.getNameLoc(), result.operands))
864       return failure();
865 
866     result.addTypes(llvmResultType);
867   }
868   result.addSuccessors({normalDest, unwindDest});
869   result.addOperands(normalOperands);
870   result.addOperands(unwindOperands);
871 
872   result.addAttribute(
873       InvokeOp::getOperandSegmentSizeAttr(),
874       builder.getI32VectorAttr({static_cast<int32_t>(operands.size()),
875                                 static_cast<int32_t>(normalOperands.size()),
876                                 static_cast<int32_t>(unwindOperands.size())}));
877   return success();
878 }
879 
880 ///===----------------------------------------------------------------------===//
881 /// Verifying/Printing/Parsing for LLVM::LandingpadOp.
882 ///===----------------------------------------------------------------------===//
883 
884 LogicalResult LandingpadOp::verify() {
885   Value value;
886   if (LLVMFuncOp func = (*this)->getParentOfType<LLVMFuncOp>()) {
887     if (!func.getPersonality().hasValue())
888       return emitError(
889           "llvm.landingpad needs to be in a function with a personality");
890   }
891 
892   if (!getCleanup() && getOperands().empty())
893     return emitError("landingpad instruction expects at least one clause or "
894                      "cleanup attribute");
895 
896   for (unsigned idx = 0, ie = getNumOperands(); idx < ie; idx++) {
897     value = getOperand(idx);
898     bool isFilter = value.getType().isa<LLVMArrayType>();
899     if (isFilter) {
900       // FIXME: Verify filter clauses when arrays are appropriately handled
901     } else {
902       // catch - global addresses only.
903       // Bitcast ops should have global addresses as their args.
904       if (auto bcOp = value.getDefiningOp<BitcastOp>()) {
905         if (auto addrOp = bcOp.getArg().getDefiningOp<AddressOfOp>())
906           continue;
907         return emitError("constant clauses expected").attachNote(bcOp.getLoc())
908                << "global addresses expected as operand to "
909                   "bitcast used in clauses for landingpad";
910       }
911       // NullOp and AddressOfOp allowed
912       if (value.getDefiningOp<NullOp>())
913         continue;
914       if (value.getDefiningOp<AddressOfOp>())
915         continue;
916       return emitError("clause #")
917              << idx << " is not a known constant - null, addressof, bitcast";
918     }
919   }
920   return success();
921 }
922 
923 void LandingpadOp::print(OpAsmPrinter &p) {
924   p << (getCleanup() ? " cleanup " : " ");
925 
926   // Clauses
927   for (auto value : getOperands()) {
928     // Similar to llvm - if clause is an array type then it is filter
929     // clause else catch clause
930     bool isArrayTy = value.getType().isa<LLVMArrayType>();
931     p << '(' << (isArrayTy ? "filter " : "catch ") << value << " : "
932       << value.getType() << ") ";
933   }
934 
935   p.printOptionalAttrDict((*this)->getAttrs(), {"cleanup"});
936 
937   p << ": " << getType();
938 }
939 
940 /// <operation> ::= `llvm.landingpad` `cleanup`?
941 ///                 ((`catch` | `filter`) operand-type ssa-use)* attribute-dict?
942 ParseResult LandingpadOp::parse(OpAsmParser &parser, OperationState &result) {
943   // Check for cleanup
944   if (succeeded(parser.parseOptionalKeyword("cleanup")))
945     result.addAttribute("cleanup", parser.getBuilder().getUnitAttr());
946 
947   // Parse clauses with types
948   while (succeeded(parser.parseOptionalLParen()) &&
949          (succeeded(parser.parseOptionalKeyword("filter")) ||
950           succeeded(parser.parseOptionalKeyword("catch")))) {
951     OpAsmParser::OperandType operand;
952     Type ty;
953     if (parser.parseOperand(operand) || parser.parseColon() ||
954         parser.parseType(ty) ||
955         parser.resolveOperand(operand, ty, result.operands) ||
956         parser.parseRParen())
957       return failure();
958   }
959 
960   Type type;
961   if (parser.parseColon() || parser.parseType(type))
962     return failure();
963 
964   result.addTypes(type);
965   return success();
966 }
967 
968 //===----------------------------------------------------------------------===//
969 // Verifying/Printing/parsing for LLVM::CallOp.
970 //===----------------------------------------------------------------------===//
971 
972 LogicalResult CallOp::verify() {
973   if (getNumResults() > 1)
974     return emitOpError("must have 0 or 1 result");
975 
976   // Type for the callee, we'll get it differently depending if it is a direct
977   // or indirect call.
978   Type fnType;
979 
980   bool isIndirect = false;
981 
982   // If this is an indirect call, the callee attribute is missing.
983   FlatSymbolRefAttr calleeName = getCalleeAttr();
984   if (!calleeName) {
985     isIndirect = true;
986     if (!getNumOperands())
987       return emitOpError(
988           "must have either a `callee` attribute or at least an operand");
989     auto ptrType = getOperand(0).getType().dyn_cast<LLVMPointerType>();
990     if (!ptrType)
991       return emitOpError("indirect call expects a pointer as callee: ")
992              << ptrType;
993     fnType = ptrType.getElementType();
994   } else {
995     Operation *callee =
996         SymbolTable::lookupNearestSymbolFrom(*this, calleeName.getAttr());
997     if (!callee)
998       return emitOpError()
999              << "'" << calleeName.getValue()
1000              << "' does not reference a symbol in the current scope";
1001     auto fn = dyn_cast<LLVMFuncOp>(callee);
1002     if (!fn)
1003       return emitOpError() << "'" << calleeName.getValue()
1004                            << "' does not reference a valid LLVM function";
1005 
1006     fnType = fn.getType();
1007   }
1008 
1009   LLVMFunctionType funcType = fnType.dyn_cast<LLVMFunctionType>();
1010   if (!funcType)
1011     return emitOpError("callee does not have a functional type: ") << fnType;
1012 
1013   // Verify that the operand and result types match the callee.
1014 
1015   if (!funcType.isVarArg() &&
1016       funcType.getNumParams() != (getNumOperands() - isIndirect))
1017     return emitOpError() << "incorrect number of operands ("
1018                          << (getNumOperands() - isIndirect)
1019                          << ") for callee (expecting: "
1020                          << funcType.getNumParams() << ")";
1021 
1022   if (funcType.getNumParams() > (getNumOperands() - isIndirect))
1023     return emitOpError() << "incorrect number of operands ("
1024                          << (getNumOperands() - isIndirect)
1025                          << ") for varargs callee (expecting at least: "
1026                          << funcType.getNumParams() << ")";
1027 
1028   for (unsigned i = 0, e = funcType.getNumParams(); i != e; ++i)
1029     if (getOperand(i + isIndirect).getType() != funcType.getParamType(i))
1030       return emitOpError() << "operand type mismatch for operand " << i << ": "
1031                            << getOperand(i + isIndirect).getType()
1032                            << " != " << funcType.getParamType(i);
1033 
1034   if (getNumResults() == 0 &&
1035       !funcType.getReturnType().isa<LLVM::LLVMVoidType>())
1036     return emitOpError() << "expected function call to produce a value";
1037 
1038   if (getNumResults() != 0 &&
1039       funcType.getReturnType().isa<LLVM::LLVMVoidType>())
1040     return emitOpError()
1041            << "calling function with void result must not produce values";
1042 
1043   if (getNumResults() > 1)
1044     return emitOpError()
1045            << "expected LLVM function call to produce 0 or 1 result";
1046 
1047   if (getNumResults() && getResult(0).getType() != funcType.getReturnType())
1048     return emitOpError() << "result type mismatch: " << getResult(0).getType()
1049                          << " != " << funcType.getReturnType();
1050 
1051   return success();
1052 }
1053 
1054 void CallOp::print(OpAsmPrinter &p) {
1055   auto callee = getCallee();
1056   bool isDirect = callee.hasValue();
1057 
1058   // Print the direct callee if present as a function attribute, or an indirect
1059   // callee (first operand) otherwise.
1060   p << ' ';
1061   if (isDirect)
1062     p.printSymbolName(callee.getValue());
1063   else
1064     p << getOperand(0);
1065 
1066   auto args = getOperands().drop_front(isDirect ? 0 : 1);
1067   p << '(' << args << ')';
1068   p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()), {"callee"});
1069 
1070   // Reconstruct the function MLIR function type from operand and result types.
1071   p << " : ";
1072   p.printFunctionalType(args.getTypes(), getResultTypes());
1073 }
1074 
1075 // <operation> ::= `llvm.call` (function-id | ssa-use) `(` ssa-use-list `)`
1076 //                 attribute-dict? `:` function-type
1077 ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
1078   SmallVector<OpAsmParser::OperandType, 8> operands;
1079   Type type;
1080   SymbolRefAttr funcAttr;
1081   SMLoc trailingTypeLoc;
1082 
1083   // Parse an operand list that will, in practice, contain 0 or 1 operand.  In
1084   // case of an indirect call, there will be 1 operand before `(`.  In case of a
1085   // direct call, there will be no operands and the parser will stop at the
1086   // function identifier without complaining.
1087   if (parser.parseOperandList(operands))
1088     return failure();
1089   bool isDirect = operands.empty();
1090 
1091   // Optionally parse a function identifier.
1092   if (isDirect)
1093     if (parser.parseAttribute(funcAttr, "callee", result.attributes))
1094       return failure();
1095 
1096   if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
1097       parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
1098       parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type))
1099     return failure();
1100 
1101   auto funcType = type.dyn_cast<FunctionType>();
1102   if (!funcType)
1103     return parser.emitError(trailingTypeLoc, "expected function type");
1104   if (funcType.getNumResults() > 1)
1105     return parser.emitError(trailingTypeLoc,
1106                             "expected function with 0 or 1 result");
1107   if (isDirect) {
1108     // Make sure types match.
1109     if (parser.resolveOperands(operands, funcType.getInputs(),
1110                                parser.getNameLoc(), result.operands))
1111       return failure();
1112     if (funcType.getNumResults() != 0 &&
1113         !funcType.getResult(0).isa<LLVM::LLVMVoidType>())
1114       result.addTypes(funcType.getResults());
1115   } else {
1116     Builder &builder = parser.getBuilder();
1117     Type llvmResultType;
1118     if (funcType.getNumResults() == 0) {
1119       llvmResultType = LLVM::LLVMVoidType::get(builder.getContext());
1120     } else {
1121       llvmResultType = funcType.getResult(0);
1122       if (!isCompatibleType(llvmResultType))
1123         return parser.emitError(trailingTypeLoc,
1124                                 "expected result to have LLVM type");
1125     }
1126 
1127     SmallVector<Type, 8> argTypes;
1128     argTypes.reserve(funcType.getNumInputs());
1129     for (int i = 0, e = funcType.getNumInputs(); i < e; ++i) {
1130       auto argType = funcType.getInput(i);
1131       if (!isCompatibleType(argType))
1132         return parser.emitError(trailingTypeLoc,
1133                                 "expected LLVM types as inputs");
1134       argTypes.push_back(argType);
1135     }
1136     auto llvmFuncType = LLVM::LLVMFunctionType::get(llvmResultType, argTypes);
1137     auto wrappedFuncType = LLVM::LLVMPointerType::get(llvmFuncType);
1138 
1139     auto funcArguments =
1140         ArrayRef<OpAsmParser::OperandType>(operands).drop_front();
1141 
1142     // Make sure that the first operand (indirect callee) matches the wrapped
1143     // LLVM IR function type, and that the types of the other call operands
1144     // match the types of the function arguments.
1145     if (parser.resolveOperand(operands[0], wrappedFuncType, result.operands) ||
1146         parser.resolveOperands(funcArguments, funcType.getInputs(),
1147                                parser.getNameLoc(), result.operands))
1148       return failure();
1149 
1150     if (!llvmResultType.isa<LLVM::LLVMVoidType>())
1151       result.addTypes(llvmResultType);
1152   }
1153 
1154   return success();
1155 }
1156 
1157 //===----------------------------------------------------------------------===//
1158 // Printing/parsing for LLVM::ExtractElementOp.
1159 //===----------------------------------------------------------------------===//
1160 // Expects vector to be of wrapped LLVM vector type and position to be of
1161 // wrapped LLVM i32 type.
1162 void LLVM::ExtractElementOp::build(OpBuilder &b, OperationState &result,
1163                                    Value vector, Value position,
1164                                    ArrayRef<NamedAttribute> attrs) {
1165   auto vectorType = vector.getType();
1166   auto llvmType = LLVM::getVectorElementType(vectorType);
1167   build(b, result, llvmType, vector, position);
1168   result.addAttributes(attrs);
1169 }
1170 
1171 void ExtractElementOp::print(OpAsmPrinter &p) {
1172   p << ' ' << getVector() << "[" << getPosition() << " : "
1173     << getPosition().getType() << "]";
1174   p.printOptionalAttrDict((*this)->getAttrs());
1175   p << " : " << getVector().getType();
1176 }
1177 
1178 // <operation> ::= `llvm.extractelement` ssa-use `, ` ssa-use
1179 //                 attribute-dict? `:` type
1180 ParseResult ExtractElementOp::parse(OpAsmParser &parser,
1181                                     OperationState &result) {
1182   SMLoc loc;
1183   OpAsmParser::OperandType vector, position;
1184   Type type, positionType;
1185   if (parser.getCurrentLocation(&loc) || parser.parseOperand(vector) ||
1186       parser.parseLSquare() || parser.parseOperand(position) ||
1187       parser.parseColonType(positionType) || parser.parseRSquare() ||
1188       parser.parseOptionalAttrDict(result.attributes) ||
1189       parser.parseColonType(type) ||
1190       parser.resolveOperand(vector, type, result.operands) ||
1191       parser.resolveOperand(position, positionType, result.operands))
1192     return failure();
1193   if (!LLVM::isCompatibleVectorType(type))
1194     return parser.emitError(
1195         loc, "expected LLVM dialect-compatible vector type for operand #1");
1196   result.addTypes(LLVM::getVectorElementType(type));
1197   return success();
1198 }
1199 
1200 LogicalResult ExtractElementOp::verify() {
1201   Type vectorType = getVector().getType();
1202   if (!LLVM::isCompatibleVectorType(vectorType))
1203     return emitOpError("expected LLVM dialect-compatible vector type for "
1204                        "operand #1, got")
1205            << vectorType;
1206   Type valueType = LLVM::getVectorElementType(vectorType);
1207   if (valueType != getRes().getType())
1208     return emitOpError() << "Type mismatch: extracting from " << vectorType
1209                          << " should produce " << valueType
1210                          << " but this op returns " << getRes().getType();
1211   return success();
1212 }
1213 
1214 //===----------------------------------------------------------------------===//
1215 // Printing/parsing for LLVM::ExtractValueOp.
1216 //===----------------------------------------------------------------------===//
1217 
1218 void ExtractValueOp::print(OpAsmPrinter &p) {
1219   p << ' ' << getContainer() << getPosition();
1220   p.printOptionalAttrDict((*this)->getAttrs(), {"position"});
1221   p << " : " << getContainer().getType();
1222 }
1223 
1224 // Extract the type at `position` in the wrapped LLVM IR aggregate type
1225 // `containerType`.  Position is an integer array attribute where each value
1226 // is a zero-based position of the element in the aggregate type.  Return the
1227 // resulting type wrapped in MLIR, or nullptr on error.
1228 static Type getInsertExtractValueElementType(OpAsmParser &parser,
1229                                              Type containerType,
1230                                              ArrayAttr positionAttr,
1231                                              SMLoc attributeLoc,
1232                                              SMLoc typeLoc) {
1233   Type llvmType = containerType;
1234   if (!isCompatibleType(containerType))
1235     return parser.emitError(typeLoc, "expected LLVM IR Dialect type"), nullptr;
1236 
1237   // Infer the element type from the structure type: iteratively step inside the
1238   // type by taking the element type, indexed by the position attribute for
1239   // structures.  Check the position index before accessing, it is supposed to
1240   // be in bounds.
1241   for (Attribute subAttr : positionAttr) {
1242     auto positionElementAttr = subAttr.dyn_cast<IntegerAttr>();
1243     if (!positionElementAttr)
1244       return parser.emitError(attributeLoc,
1245                               "expected an array of integer literals"),
1246              nullptr;
1247     int position = positionElementAttr.getInt();
1248     if (auto arrayType = llvmType.dyn_cast<LLVMArrayType>()) {
1249       if (position < 0 ||
1250           static_cast<unsigned>(position) >= arrayType.getNumElements())
1251         return parser.emitError(attributeLoc, "position out of bounds"),
1252                nullptr;
1253       llvmType = arrayType.getElementType();
1254     } else if (auto structType = llvmType.dyn_cast<LLVMStructType>()) {
1255       if (position < 0 ||
1256           static_cast<unsigned>(position) >= structType.getBody().size())
1257         return parser.emitError(attributeLoc, "position out of bounds"),
1258                nullptr;
1259       llvmType = structType.getBody()[position];
1260     } else {
1261       return parser.emitError(typeLoc, "expected LLVM IR structure/array type"),
1262              nullptr;
1263     }
1264   }
1265   return llvmType;
1266 }
1267 
1268 // Extract the type at `position` in the wrapped LLVM IR aggregate type
1269 // `containerType`. Returns null on failure.
1270 static Type getInsertExtractValueElementType(Type containerType,
1271                                              ArrayAttr positionAttr,
1272                                              Operation *op) {
1273   Type llvmType = containerType;
1274   if (!isCompatibleType(containerType)) {
1275     op->emitError("expected LLVM IR Dialect type, got ") << containerType;
1276     return {};
1277   }
1278 
1279   // Infer the element type from the structure type: iteratively step inside the
1280   // type by taking the element type, indexed by the position attribute for
1281   // structures.  Check the position index before accessing, it is supposed to
1282   // be in bounds.
1283   for (Attribute subAttr : positionAttr) {
1284     auto positionElementAttr = subAttr.dyn_cast<IntegerAttr>();
1285     if (!positionElementAttr) {
1286       op->emitOpError("expected an array of integer literals, got: ")
1287           << subAttr;
1288       return {};
1289     }
1290     int position = positionElementAttr.getInt();
1291     if (auto arrayType = llvmType.dyn_cast<LLVMArrayType>()) {
1292       if (position < 0 ||
1293           static_cast<unsigned>(position) >= arrayType.getNumElements()) {
1294         op->emitOpError("position out of bounds: ") << position;
1295         return {};
1296       }
1297       llvmType = arrayType.getElementType();
1298     } else if (auto structType = llvmType.dyn_cast<LLVMStructType>()) {
1299       if (position < 0 ||
1300           static_cast<unsigned>(position) >= structType.getBody().size()) {
1301         op->emitOpError("position out of bounds") << position;
1302         return {};
1303       }
1304       llvmType = structType.getBody()[position];
1305     } else {
1306       op->emitOpError("expected LLVM IR structure/array type, got: ")
1307           << llvmType;
1308       return {};
1309     }
1310   }
1311   return llvmType;
1312 }
1313 
1314 // <operation> ::= `llvm.extractvalue` ssa-use
1315 //                 `[` integer-literal (`,` integer-literal)* `]`
1316 //                 attribute-dict? `:` type
1317 ParseResult ExtractValueOp::parse(OpAsmParser &parser, OperationState &result) {
1318   OpAsmParser::OperandType container;
1319   Type containerType;
1320   ArrayAttr positionAttr;
1321   SMLoc attributeLoc, trailingTypeLoc;
1322 
1323   if (parser.parseOperand(container) ||
1324       parser.getCurrentLocation(&attributeLoc) ||
1325       parser.parseAttribute(positionAttr, "position", result.attributes) ||
1326       parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
1327       parser.getCurrentLocation(&trailingTypeLoc) ||
1328       parser.parseType(containerType) ||
1329       parser.resolveOperand(container, containerType, result.operands))
1330     return failure();
1331 
1332   auto elementType = getInsertExtractValueElementType(
1333       parser, containerType, positionAttr, attributeLoc, trailingTypeLoc);
1334   if (!elementType)
1335     return failure();
1336 
1337   result.addTypes(elementType);
1338   return success();
1339 }
1340 
/// Folds an extractvalue by walking back through the chain of insertvalue ops
/// that define its container:
///  - if an insertvalue writes exactly the extracted position, the fold result
///    is the inserted value;
///  - if one position is a strict prefix of the other, the walk stops, since
///    the inserted value and the extracted element overlap;
///  - otherwise the insertvalue cannot affect the extracted element, so this
///    op's container operand is rewritten in place to that insertvalue's
///    container and the walk continues.
/// Once the op has been rewritten in place, its own result is returned, which
/// in MLIR's folding convention signals an in-place update. An empty result
/// means nothing was folded.
OpFoldResult LLVM::ExtractValueOp::fold(ArrayRef<Attribute> operands) {
  auto insertValueOp = getContainer().getDefiningOp<InsertValueOp>();
  // Stays empty until the container operand has been rewritten at least once.
  OpFoldResult result = {};
  while (insertValueOp) {
    if (getPosition() == insertValueOp.getPosition())
      return insertValueOp.getValue();
    unsigned min =
        std::min(getPosition().size(), insertValueOp.getPosition().size());
    // If one is fully prefix of the other, stop propagating back as it will
    // miss dependencies. For instance, %3 should not fold to %f0 in the
    // following example:
    // ```
    //   %1 = llvm.insertvalue %f0, %0[0, 0] :
    //     !llvm.array<4 x !llvm.array<4xf32>>
    //   %2 = llvm.insertvalue %arr, %1[0] :
    //     !llvm.array<4 x !llvm.array<4xf32>>
    //   %3 = llvm.extractvalue %2[0, 0] : !llvm.array<4 x !llvm.array<4xf32>>
    // ```
    if (getPosition().getValue().take_front(min) ==
        insertValueOp.getPosition().getValue().take_front(min))
      return result;

    // If neither a prefix, nor the exact position, we can extract out of the
    // value being inserted into. Moreover, we can try again if that operand
    // is itself an insertvalue expression.
    getContainerMutable().assign(insertValueOp.getContainer());
    // The op now differs from its original form, so report an in-place fold.
    result = getResult();
    insertValueOp = insertValueOp.getContainer().getDefiningOp<InsertValueOp>();
  }
  return result;
}
1372 
1373 LogicalResult ExtractValueOp::verify() {
1374   Type valueType = getInsertExtractValueElementType(getContainer().getType(),
1375                                                     getPositionAttr(), *this);
1376   if (!valueType)
1377     return failure();
1378 
1379   if (getRes().getType() != valueType)
1380     return emitOpError() << "Type mismatch: extracting from "
1381                          << getContainer().getType() << " should produce "
1382                          << valueType << " but this op returns "
1383                          << getRes().getType();
1384   return success();
1385 }
1386 
1387 //===----------------------------------------------------------------------===//
1388 // Printing/parsing for LLVM::InsertElementOp.
1389 //===----------------------------------------------------------------------===//
1390 
1391 void InsertElementOp::print(OpAsmPrinter &p) {
1392   p << ' ' << getValue() << ", " << getVector() << "[" << getPosition() << " : "
1393     << getPosition().getType() << "]";
1394   p.printOptionalAttrDict((*this)->getAttrs());
1395   p << " : " << getVector().getType();
1396 }
1397 
1398 // <operation> ::= `llvm.insertelement` ssa-use `,` ssa-use `,` ssa-use
1399 //                 attribute-dict? `:` type
1400 ParseResult InsertElementOp::parse(OpAsmParser &parser,
1401                                    OperationState &result) {
1402   SMLoc loc;
1403   OpAsmParser::OperandType vector, value, position;
1404   Type vectorType, positionType;
1405   if (parser.getCurrentLocation(&loc) || parser.parseOperand(value) ||
1406       parser.parseComma() || parser.parseOperand(vector) ||
1407       parser.parseLSquare() || parser.parseOperand(position) ||
1408       parser.parseColonType(positionType) || parser.parseRSquare() ||
1409       parser.parseOptionalAttrDict(result.attributes) ||
1410       parser.parseColonType(vectorType))
1411     return failure();
1412 
1413   if (!LLVM::isCompatibleVectorType(vectorType))
1414     return parser.emitError(
1415         loc, "expected LLVM dialect-compatible vector type for operand #1");
1416   Type valueType = LLVM::getVectorElementType(vectorType);
1417   if (!valueType)
1418     return failure();
1419 
1420   if (parser.resolveOperand(vector, vectorType, result.operands) ||
1421       parser.resolveOperand(value, valueType, result.operands) ||
1422       parser.resolveOperand(position, positionType, result.operands))
1423     return failure();
1424 
1425   result.addTypes(vectorType);
1426   return success();
1427 }
1428 
1429 LogicalResult InsertElementOp::verify() {
1430   Type valueType = LLVM::getVectorElementType(getVector().getType());
1431   if (valueType != getValue().getType())
1432     return emitOpError() << "Type mismatch: cannot insert "
1433                          << getValue().getType() << " into "
1434                          << getVector().getType();
1435   return success();
1436 }
1437 
1438 //===----------------------------------------------------------------------===//
1439 // Printing/parsing for LLVM::InsertValueOp.
1440 //===----------------------------------------------------------------------===//
1441 
1442 void InsertValueOp::print(OpAsmPrinter &p) {
1443   p << ' ' << getValue() << ", " << getContainer() << getPosition();
1444   p.printOptionalAttrDict((*this)->getAttrs(), {"position"});
1445   p << " : " << getContainer().getType();
1446 }
1447 
1448 // <operation> ::= `llvm.insertvaluevalue` ssa-use `,` ssa-use
1449 //                 `[` integer-literal (`,` integer-literal)* `]`
1450 //                 attribute-dict? `:` type
1451 ParseResult InsertValueOp::parse(OpAsmParser &parser, OperationState &result) {
1452   OpAsmParser::OperandType container, value;
1453   Type containerType;
1454   ArrayAttr positionAttr;
1455   SMLoc attributeLoc, trailingTypeLoc;
1456 
1457   if (parser.parseOperand(value) || parser.parseComma() ||
1458       parser.parseOperand(container) ||
1459       parser.getCurrentLocation(&attributeLoc) ||
1460       parser.parseAttribute(positionAttr, "position", result.attributes) ||
1461       parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
1462       parser.getCurrentLocation(&trailingTypeLoc) ||
1463       parser.parseType(containerType))
1464     return failure();
1465 
1466   auto valueType = getInsertExtractValueElementType(
1467       parser, containerType, positionAttr, attributeLoc, trailingTypeLoc);
1468   if (!valueType)
1469     return failure();
1470 
1471   if (parser.resolveOperand(container, containerType, result.operands) ||
1472       parser.resolveOperand(value, valueType, result.operands))
1473     return failure();
1474 
1475   result.addTypes(containerType);
1476   return success();
1477 }
1478 
1479 LogicalResult InsertValueOp::verify() {
1480   Type valueType = getInsertExtractValueElementType(getContainer().getType(),
1481                                                     getPositionAttr(), *this);
1482   if (!valueType)
1483     return failure();
1484 
1485   if (getValue().getType() != valueType)
1486     return emitOpError() << "Type mismatch: cannot insert "
1487                          << getValue().getType() << " into "
1488                          << getContainer().getType();
1489 
1490   return success();
1491 }
1492 
1493 //===----------------------------------------------------------------------===//
1494 // Printing, parsing and verification for LLVM::ReturnOp.
1495 //===----------------------------------------------------------------------===//
1496 
1497 LogicalResult ReturnOp::verify() {
1498   if (getNumOperands() > 1)
1499     return emitOpError("expected at most 1 operand");
1500 
1501   if (auto parent = (*this)->getParentOfType<LLVMFuncOp>()) {
1502     Type expectedType = parent.getType().getReturnType();
1503     if (expectedType.isa<LLVMVoidType>()) {
1504       if (getNumOperands() == 0)
1505         return success();
1506       InFlightDiagnostic diag = emitOpError("expected no operands");
1507       diag.attachNote(parent->getLoc()) << "when returning from function";
1508       return diag;
1509     }
1510     if (getNumOperands() == 0) {
1511       if (expectedType.isa<LLVMVoidType>())
1512         return success();
1513       InFlightDiagnostic diag = emitOpError("expected 1 operand");
1514       diag.attachNote(parent->getLoc()) << "when returning from function";
1515       return diag;
1516     }
1517     if (expectedType != getOperand(0).getType()) {
1518       InFlightDiagnostic diag = emitOpError("mismatching result types");
1519       diag.attachNote(parent->getLoc()) << "when returning from function";
1520       return diag;
1521     }
1522   }
1523   return success();
1524 }
1525 
1526 //===----------------------------------------------------------------------===//
1527 // ResumeOp
1528 //===----------------------------------------------------------------------===//
1529 
1530 LogicalResult ResumeOp::verify() {
1531   if (!getValue().getDefiningOp<LandingpadOp>())
1532     return emitOpError("expects landingpad value as operand");
1533   // No check for personality of function - landingpad op verifies it.
1534   return success();
1535 }
1536 
1537 //===----------------------------------------------------------------------===//
1538 // Verifier for LLVM::AddressOfOp.
1539 //===----------------------------------------------------------------------===//
1540 
1541 template <typename OpTy>
1542 static OpTy lookupSymbolInModule(Operation *parent, StringRef name) {
1543   Operation *module = parent;
1544   while (module && !satisfiesLLVMModule(module))
1545     module = module->getParentOp();
1546   assert(module && "unexpected operation outside of a module");
1547   return dyn_cast_or_null<OpTy>(
1548       mlir::SymbolTable::lookupSymbolIn(module, name));
1549 }
1550 
1551 GlobalOp AddressOfOp::getGlobal() {
1552   return lookupSymbolInModule<LLVM::GlobalOp>((*this)->getParentOp(),
1553                                               getGlobalName());
1554 }
1555 
1556 LLVMFuncOp AddressOfOp::getFunction() {
1557   return lookupSymbolInModule<LLVM::LLVMFuncOp>((*this)->getParentOp(),
1558                                                 getGlobalName());
1559 }
1560 
1561 LogicalResult AddressOfOp::verify() {
1562   auto global = getGlobal();
1563   auto function = getFunction();
1564   if (!global && !function)
1565     return emitOpError(
1566         "must reference a global defined by 'llvm.mlir.global' or 'llvm.func'");
1567 
1568   if (global &&
1569       LLVM::LLVMPointerType::get(global.getType(), global.getAddrSpace()) !=
1570           getResult().getType())
1571     return emitOpError(
1572         "the type must be a pointer to the type of the referenced global");
1573 
1574   if (function &&
1575       LLVM::LLVMPointerType::get(function.getType()) != getResult().getType())
1576     return emitOpError(
1577         "the type must be a pointer to the type of the referenced function");
1578 
1579   return success();
1580 }
1581 
1582 //===----------------------------------------------------------------------===//
1583 // Builder, printer and verifier for LLVM::GlobalOp.
1584 //===----------------------------------------------------------------------===//
1585 
1586 /// Returns the name used for the linkage attribute. This *must* correspond to
1587 /// the name of the attribute in ODS.
1588 static StringRef getLinkageAttrName() { return "linkage"; }
1589 
1590 /// Returns the name used for the unnamed_addr attribute. This *must* correspond
1591 /// to the name of the attribute in ODS.
1592 static StringRef getUnnamedAddrAttrName() { return "unnamed_addr"; }
1593 
1594 void GlobalOp::build(OpBuilder &builder, OperationState &result, Type type,
1595                      bool isConstant, Linkage linkage, StringRef name,
1596                      Attribute value, uint64_t alignment, unsigned addrSpace,
1597                      bool dsoLocal, ArrayRef<NamedAttribute> attrs) {
1598   result.addAttribute(SymbolTable::getSymbolAttrName(),
1599                       builder.getStringAttr(name));
1600   result.addAttribute("global_type", TypeAttr::get(type));
1601   if (isConstant)
1602     result.addAttribute("constant", builder.getUnitAttr());
1603   if (value)
1604     result.addAttribute("value", value);
1605   if (dsoLocal)
1606     result.addAttribute("dso_local", builder.getUnitAttr());
1607 
1608   // Only add an alignment attribute if the "alignment" input
1609   // is different from 0. The value must also be a power of two, but
1610   // this is tested in GlobalOp::verify, not here.
1611   if (alignment != 0)
1612     result.addAttribute("alignment", builder.getI64IntegerAttr(alignment));
1613 
1614   result.addAttribute(::getLinkageAttrName(),
1615                       LinkageAttr::get(builder.getContext(), linkage));
1616   if (addrSpace != 0)
1617     result.addAttribute("addr_space", builder.getI32IntegerAttr(addrSpace));
1618   result.attributes.append(attrs.begin(), attrs.end());
1619   result.addRegion();
1620 }
1621 
1622 void GlobalOp::print(OpAsmPrinter &p) {
1623   p << ' ' << stringifyLinkage(getLinkage()) << ' ';
1624   if (auto unnamedAddr = getUnnamedAddr()) {
1625     StringRef str = stringifyUnnamedAddr(*unnamedAddr);
1626     if (!str.empty())
1627       p << str << ' ';
1628   }
1629   if (getConstant())
1630     p << "constant ";
1631   p.printSymbolName(getSymName());
1632   p << '(';
1633   if (auto value = getValueOrNull())
1634     p.printAttribute(value);
1635   p << ')';
1636   // Note that the alignment attribute is printed using the
1637   // default syntax here, even though it is an inherent attribute
1638   // (as defined in https://mlir.llvm.org/docs/LangRef/#attributes)
1639   p.printOptionalAttrDict((*this)->getAttrs(),
1640                           {SymbolTable::getSymbolAttrName(), "global_type",
1641                            "constant", "value", getLinkageAttrName(),
1642                            getUnnamedAddrAttrName()});
1643 
1644   // Print the trailing type unless it's a string global.
1645   if (getValueOrNull().dyn_cast_or_null<StringAttr>())
1646     return;
1647   p << " : " << getType();
1648 
1649   Region &initializer = getInitializerRegion();
1650   if (!initializer.empty()) {
1651     p << ' ';
1652     p.printRegion(initializer, /*printEntryBlockArgs=*/false);
1653   }
1654 }
1655 
1656 // Parses one of the keywords provided in the list `keywords` and returns the
1657 // position of the parsed keyword in the list. If none of the keywords from the
1658 // list is parsed, returns -1.
1659 static int parseOptionalKeywordAlternative(OpAsmParser &parser,
1660                                            ArrayRef<StringRef> keywords) {
1661   for (const auto &en : llvm::enumerate(keywords)) {
1662     if (succeeded(parser.parseOptionalKeyword(en.value())))
1663       return en.index();
1664   }
1665   return -1;
1666 }
1667 
1668 namespace {
1669 template <typename Ty>
1670 struct EnumTraits {};
1671 
1672 #define REGISTER_ENUM_TYPE(Ty)                                                 \
1673   template <>                                                                  \
1674   struct EnumTraits<Ty> {                                                      \
1675     static StringRef stringify(Ty value) { return stringify##Ty(value); }      \
1676     static unsigned getMaxEnumVal() { return getMaxEnumValFor##Ty(); }         \
1677   }
1678 
1679 REGISTER_ENUM_TYPE(Linkage);
1680 REGISTER_ENUM_TYPE(UnnamedAddr);
1681 } // namespace
1682 
1683 /// Parse an enum from the keyword, or default to the provided default value.
1684 /// The return type is the enum type by default, unless overriden with the
1685 /// second template argument.
1686 template <typename EnumTy, typename RetTy = EnumTy>
1687 static RetTy parseOptionalLLVMKeyword(OpAsmParser &parser,
1688                                       OperationState &result,
1689                                       EnumTy defaultValue) {
1690   SmallVector<StringRef, 10> names;
1691   for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i)
1692     names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i)));
1693 
1694   int index = parseOptionalKeywordAlternative(parser, names);
1695   if (index == -1)
1696     return static_cast<RetTy>(defaultValue);
1697   return static_cast<RetTy>(index);
1698 }
1699 
1700 // operation ::= `llvm.mlir.global` linkage? `constant`? `@` identifier
1701 //               `(` attribute? `)` align? attribute-list? (`:` type)? region?
1702 // align     ::= `align` `=` UINT64
1703 //
1704 // The type can be omitted for string attributes, in which case it will be
1705 // inferred from the value of the string as [strlen(value) x i8].
1706 ParseResult GlobalOp::parse(OpAsmParser &parser, OperationState &result) {
1707   MLIRContext *ctx = parser.getContext();
1708   // Parse optional linkage, default to External.
1709   result.addAttribute(::getLinkageAttrName(),
1710                       LLVM::LinkageAttr::get(
1711                           ctx, parseOptionalLLVMKeyword<Linkage>(
1712                                    parser, result, LLVM::Linkage::External)));
1713   // Parse optional UnnamedAddr, default to None.
1714   result.addAttribute(::getUnnamedAddrAttrName(),
1715                       parser.getBuilder().getI64IntegerAttr(
1716                           parseOptionalLLVMKeyword<UnnamedAddr, int64_t>(
1717                               parser, result, LLVM::UnnamedAddr::None)));
1718 
1719   if (succeeded(parser.parseOptionalKeyword("constant")))
1720     result.addAttribute("constant", parser.getBuilder().getUnitAttr());
1721 
1722   StringAttr name;
1723   if (parser.parseSymbolName(name, SymbolTable::getSymbolAttrName(),
1724                              result.attributes) ||
1725       parser.parseLParen())
1726     return failure();
1727 
1728   Attribute value;
1729   if (parser.parseOptionalRParen()) {
1730     if (parser.parseAttribute(value, "value", result.attributes) ||
1731         parser.parseRParen())
1732       return failure();
1733   }
1734 
1735   SmallVector<Type, 1> types;
1736   if (parser.parseOptionalAttrDict(result.attributes) ||
1737       parser.parseOptionalColonTypeList(types))
1738     return failure();
1739 
1740   if (types.size() > 1)
1741     return parser.emitError(parser.getNameLoc(), "expected zero or one type");
1742 
1743   Region &initRegion = *result.addRegion();
1744   if (types.empty()) {
1745     if (auto strAttr = value.dyn_cast_or_null<StringAttr>()) {
1746       MLIRContext *context = parser.getContext();
1747       auto arrayType = LLVM::LLVMArrayType::get(IntegerType::get(context, 8),
1748                                                 strAttr.getValue().size());
1749       types.push_back(arrayType);
1750     } else {
1751       return parser.emitError(parser.getNameLoc(),
1752                               "type can only be omitted for string globals");
1753     }
1754   } else {
1755     OptionalParseResult parseResult =
1756         parser.parseOptionalRegion(initRegion, /*arguments=*/{},
1757                                    /*argTypes=*/{});
1758     if (parseResult.hasValue() && failed(*parseResult))
1759       return failure();
1760   }
1761 
1762   result.addAttribute("global_type", TypeAttr::get(types[0]));
1763   return success();
1764 }
1765 
1766 static bool isZeroAttribute(Attribute value) {
1767   if (auto intValue = value.dyn_cast<IntegerAttr>())
1768     return intValue.getValue().isNullValue();
1769   if (auto fpValue = value.dyn_cast<FloatAttr>())
1770     return fpValue.getValue().isZero();
1771   if (auto splatValue = value.dyn_cast<SplatElementsAttr>())
1772     return isZeroAttribute(splatValue.getSplatValue<Attribute>());
1773   if (auto elementsValue = value.dyn_cast<ElementsAttr>())
1774     return llvm::all_of(elementsValue.getValues<Attribute>(), isZeroAttribute);
1775   if (auto arrayValue = value.dyn_cast<ArrayAttr>())
1776     return llvm::all_of(arrayValue.getValue(), isZeroAttribute);
1777   return false;
1778 }
1779 
1780 LogicalResult GlobalOp::verify() {
1781   if (!LLVMPointerType::isValidElementType(getType()))
1782     return emitOpError(
1783         "expects type to be a valid element type for an LLVM pointer");
1784   if ((*this)->getParentOp() && !satisfiesLLVMModule((*this)->getParentOp()))
1785     return emitOpError("must appear at the module level");
1786 
1787   if (auto strAttr = getValueOrNull().dyn_cast_or_null<StringAttr>()) {
1788     auto type = getType().dyn_cast<LLVMArrayType>();
1789     IntegerType elementType =
1790         type ? type.getElementType().dyn_cast<IntegerType>() : nullptr;
1791     if (!elementType || elementType.getWidth() != 8 ||
1792         type.getNumElements() != strAttr.getValue().size())
1793       return emitOpError(
1794           "requires an i8 array type of the length equal to that of the string "
1795           "attribute");
1796   }
1797 
1798   if (getLinkage() == Linkage::Common) {
1799     if (Attribute value = getValueOrNull()) {
1800       if (!isZeroAttribute(value)) {
1801         return emitOpError()
1802                << "expected zero value for '"
1803                << stringifyLinkage(Linkage::Common) << "' linkage";
1804       }
1805     }
1806   }
1807 
1808   if (getLinkage() == Linkage::Appending) {
1809     if (!getType().isa<LLVMArrayType>()) {
1810       return emitOpError() << "expected array type for '"
1811                            << stringifyLinkage(Linkage::Appending)
1812                            << "' linkage";
1813     }
1814   }
1815 
1816   Optional<uint64_t> alignAttr = getAlignment();
1817   if (alignAttr.hasValue()) {
1818     uint64_t value = alignAttr.getValue();
1819     if (!llvm::isPowerOf2_64(value))
1820       return emitError() << "alignment attribute is not a power of 2";
1821   }
1822 
1823   return success();
1824 }
1825 
1826 LogicalResult GlobalOp::verifyRegions() {
1827   if (Block *b = getInitializerBlock()) {
1828     ReturnOp ret = cast<ReturnOp>(b->getTerminator());
1829     if (ret.operand_type_begin() == ret.operand_type_end())
1830       return emitOpError("initializer region cannot return void");
1831     if (*ret.operand_type_begin() != getType())
1832       return emitOpError("initializer region type ")
1833              << *ret.operand_type_begin() << " does not match global type "
1834              << getType();
1835 
1836     for (Operation &op : *b) {
1837       auto iface = dyn_cast<MemoryEffectOpInterface>(op);
1838       if (!iface || !iface.hasNoEffect())
1839         return op.emitError()
1840                << "ops with side effects not allowed in global initializers";
1841     }
1842 
1843     if (getValueOrNull())
1844       return emitOpError("cannot have both initializer value and region");
1845   }
1846 
1847   return success();
1848 }
1849 
1850 //===----------------------------------------------------------------------===//
1851 // LLVM::GlobalCtorsOp
1852 //===----------------------------------------------------------------------===//
1853 
1854 LogicalResult
1855 GlobalCtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1856   for (Attribute ctor : getCtors()) {
1857     if (failed(verifySymbolAttrUse(ctor.cast<FlatSymbolRefAttr>(), *this,
1858                                    symbolTable)))
1859       return failure();
1860   }
1861   return success();
1862 }
1863 
1864 LogicalResult GlobalCtorsOp::verify() {
1865   if (getCtors().size() != getPriorities().size())
1866     return emitError(
1867         "mismatch between the number of ctors and the number of priorities");
1868   return success();
1869 }
1870 
1871 //===----------------------------------------------------------------------===//
1872 // LLVM::GlobalDtorsOp
1873 //===----------------------------------------------------------------------===//
1874 
1875 LogicalResult
1876 GlobalDtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1877   for (Attribute dtor : getDtors()) {
1878     if (failed(verifySymbolAttrUse(dtor.cast<FlatSymbolRefAttr>(), *this,
1879                                    symbolTable)))
1880       return failure();
1881   }
1882   return success();
1883 }
1884 
1885 LogicalResult GlobalDtorsOp::verify() {
1886   if (getDtors().size() != getPriorities().size())
1887     return emitError(
1888         "mismatch between the number of dtors and the number of priorities");
1889   return success();
1890 }
1891 
1892 //===----------------------------------------------------------------------===//
1893 // Printing/parsing for LLVM::ShuffleVectorOp.
1894 //===----------------------------------------------------------------------===//
1895 // Expects vector to be of wrapped LLVM vector type and position to be of
1896 // wrapped LLVM i32 type.
1897 void LLVM::ShuffleVectorOp::build(OpBuilder &b, OperationState &result,
1898                                   Value v1, Value v2, ArrayAttr mask,
1899                                   ArrayRef<NamedAttribute> attrs) {
1900   auto containerType = v1.getType();
1901   auto vType = LLVM::getVectorType(
1902       LLVM::getVectorElementType(containerType), mask.size(),
1903       containerType.cast<VectorType>().isScalable());
1904   build(b, result, vType, v1, v2, mask);
1905   result.addAttributes(attrs);
1906 }
1907 
1908 void ShuffleVectorOp::print(OpAsmPrinter &p) {
1909   p << ' ' << getV1() << ", " << getV2() << " " << getMask();
1910   p.printOptionalAttrDict((*this)->getAttrs(), {"mask"});
1911   p << " : " << getV1().getType() << ", " << getV2().getType();
1912 }
1913 
1914 // <operation> ::= `llvm.shufflevector` ssa-use `, ` ssa-use
1915 //                 `[` integer-literal (`,` integer-literal)* `]`
1916 //                 attribute-dict? `:` type
1917 ParseResult ShuffleVectorOp::parse(OpAsmParser &parser,
1918                                    OperationState &result) {
1919   SMLoc loc;
1920   OpAsmParser::OperandType v1, v2;
1921   ArrayAttr maskAttr;
1922   Type typeV1, typeV2;
1923   if (parser.getCurrentLocation(&loc) || parser.parseOperand(v1) ||
1924       parser.parseComma() || parser.parseOperand(v2) ||
1925       parser.parseAttribute(maskAttr, "mask", result.attributes) ||
1926       parser.parseOptionalAttrDict(result.attributes) ||
1927       parser.parseColonType(typeV1) || parser.parseComma() ||
1928       parser.parseType(typeV2) ||
1929       parser.resolveOperand(v1, typeV1, result.operands) ||
1930       parser.resolveOperand(v2, typeV2, result.operands))
1931     return failure();
1932   if (!LLVM::isCompatibleVectorType(typeV1))
1933     return parser.emitError(
1934         loc, "expected LLVM IR dialect vector type for operand #1");
1935   auto vType =
1936       LLVM::getVectorType(LLVM::getVectorElementType(typeV1), maskAttr.size(),
1937                           typeV1.cast<VectorType>().isScalable());
1938   result.addTypes(vType);
1939   return success();
1940 }
1941 
1942 LogicalResult ShuffleVectorOp::verify() {
1943   Type type1 = getV1().getType();
1944   Type type2 = getV2().getType();
1945   if (LLVM::getVectorElementType(type1) != LLVM::getVectorElementType(type2))
1946     return emitOpError("expected matching LLVM IR Dialect element types");
1947   if (LLVM::isScalableVectorType(type1))
1948     if (llvm::any_of(getMask(), [](Attribute attr) {
1949           return attr.cast<IntegerAttr>().getInt() != 0;
1950         }))
1951       return emitOpError("expected a splat operation for scalable vectors");
1952   return success();
1953 }
1954 
1955 //===----------------------------------------------------------------------===//
1956 // Implementations for LLVM::LLVMFuncOp.
1957 //===----------------------------------------------------------------------===//
1958 
1959 // Add the entry block to the function.
1960 Block *LLVMFuncOp::addEntryBlock() {
1961   assert(empty() && "function already has an entry block");
1962   assert(!isVarArg() && "unimplemented: non-external variadic functions");
1963 
1964   auto *entry = new Block;
1965   push_back(entry);
1966 
1967   // FIXME: Allow passing in proper locations for the entry arguments.
1968   LLVMFunctionType type = getType();
1969   for (unsigned i = 0, e = type.getNumParams(); i < e; ++i)
1970     entry->addArgument(type.getParamType(i), getLoc());
1971   return entry;
1972 }
1973 
1974 void LLVMFuncOp::build(OpBuilder &builder, OperationState &result,
1975                        StringRef name, Type type, LLVM::Linkage linkage,
1976                        bool dsoLocal, ArrayRef<NamedAttribute> attrs,
1977                        ArrayRef<DictionaryAttr> argAttrs) {
1978   result.addRegion();
1979   result.addAttribute(SymbolTable::getSymbolAttrName(),
1980                       builder.getStringAttr(name));
1981   result.addAttribute("type", TypeAttr::get(type));
1982   result.addAttribute(::getLinkageAttrName(),
1983                       LinkageAttr::get(builder.getContext(), linkage));
1984   result.attributes.append(attrs.begin(), attrs.end());
1985   if (dsoLocal)
1986     result.addAttribute("dso_local", builder.getUnitAttr());
1987   if (argAttrs.empty())
1988     return;
1989 
1990   assert(type.cast<LLVMFunctionType>().getNumParams() == argAttrs.size() &&
1991          "expected as many argument attribute lists as arguments");
1992   function_interface_impl::addArgAndResultAttrs(builder, result, argAttrs,
1993                                                 /*resultAttrs=*/llvm::None);
1994 }
1995 
1996 // Builds an LLVM function type from the given lists of input and output types.
1997 // Returns a null type if any of the types provided are non-LLVM types, or if
1998 // there is more than one output type.
1999 static Type
2000 buildLLVMFunctionType(OpAsmParser &parser, SMLoc loc, ArrayRef<Type> inputs,
2001                       ArrayRef<Type> outputs,
2002                       function_interface_impl::VariadicFlag variadicFlag) {
2003   Builder &b = parser.getBuilder();
2004   if (outputs.size() > 1) {
2005     parser.emitError(loc, "failed to construct function type: expected zero or "
2006                           "one function result");
2007     return {};
2008   }
2009 
2010   // Convert inputs to LLVM types, exit early on error.
2011   SmallVector<Type, 4> llvmInputs;
2012   for (auto t : inputs) {
2013     if (!isCompatibleType(t)) {
2014       parser.emitError(loc, "failed to construct function type: expected LLVM "
2015                             "type for function arguments");
2016       return {};
2017     }
2018     llvmInputs.push_back(t);
2019   }
2020 
2021   // No output is denoted as "void" in LLVM type system.
2022   Type llvmOutput =
2023       outputs.empty() ? LLVMVoidType::get(b.getContext()) : outputs.front();
2024   if (!isCompatibleType(llvmOutput)) {
2025     parser.emitError(loc, "failed to construct function type: expected LLVM "
2026                           "type for function results")
2027         << llvmOutput;
2028     return {};
2029   }
2030   return LLVMFunctionType::get(llvmOutput, llvmInputs,
2031                                variadicFlag.isVariadic());
2032 }
2033 
2034 // Parses an LLVM function.
2035 //
2036 // operation ::= `llvm.func` linkage? function-signature function-attributes?
2037 //               function-body
2038 //
2039 ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) {
2040   // Default to external linkage if no keyword is provided.
2041   result.addAttribute(
2042       ::getLinkageAttrName(),
2043       LinkageAttr::get(parser.getContext(),
2044                        parseOptionalLLVMKeyword<Linkage>(
2045                            parser, result, LLVM::Linkage::External)));
2046 
2047   StringAttr nameAttr;
2048   SmallVector<OpAsmParser::OperandType> entryArgs;
2049   SmallVector<NamedAttrList> argAttrs;
2050   SmallVector<NamedAttrList> resultAttrs;
2051   SmallVector<Type> argTypes;
2052   SmallVector<Type> resultTypes;
2053   SmallVector<Location> argLocations;
2054   bool isVariadic;
2055 
2056   auto signatureLocation = parser.getCurrentLocation();
2057   if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
2058                              result.attributes) ||
2059       function_interface_impl::parseFunctionSignature(
2060           parser, /*allowVariadic=*/true, entryArgs, argTypes, argAttrs,
2061           argLocations, isVariadic, resultTypes, resultAttrs))
2062     return failure();
2063 
2064   auto type =
2065       buildLLVMFunctionType(parser, signatureLocation, argTypes, resultTypes,
2066                             function_interface_impl::VariadicFlag(isVariadic));
2067   if (!type)
2068     return failure();
2069   result.addAttribute(FunctionOpInterface::getTypeAttrName(),
2070                       TypeAttr::get(type));
2071 
2072   if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
2073     return failure();
2074   function_interface_impl::addArgAndResultAttrs(parser.getBuilder(), result,
2075                                                 argAttrs, resultAttrs);
2076 
2077   auto *body = result.addRegion();
2078   OptionalParseResult parseResult = parser.parseOptionalRegion(
2079       *body, entryArgs, entryArgs.empty() ? ArrayRef<Type>() : argTypes);
2080   return failure(parseResult.hasValue() && failed(*parseResult));
2081 }
2082 
2083 // Print the LLVMFuncOp. Collects argument and result types and passes them to
2084 // helper functions. Drops "void" result since it cannot be parsed back. Skips
2085 // the external linkage since it is the default value.
2086 void LLVMFuncOp::print(OpAsmPrinter &p) {
2087   p << ' ';
2088   if (getLinkage() != LLVM::Linkage::External)
2089     p << stringifyLinkage(getLinkage()) << ' ';
2090   p.printSymbolName(getName());
2091 
2092   LLVMFunctionType fnType = getType();
2093   SmallVector<Type, 8> argTypes;
2094   SmallVector<Type, 1> resTypes;
2095   argTypes.reserve(fnType.getNumParams());
2096   for (unsigned i = 0, e = fnType.getNumParams(); i < e; ++i)
2097     argTypes.push_back(fnType.getParamType(i));
2098 
2099   Type returnType = fnType.getReturnType();
2100   if (!returnType.isa<LLVMVoidType>())
2101     resTypes.push_back(returnType);
2102 
2103   function_interface_impl::printFunctionSignature(p, *this, argTypes,
2104                                                   isVarArg(), resTypes);
2105   function_interface_impl::printFunctionAttributes(
2106       p, *this, argTypes.size(), resTypes.size(), {getLinkageAttrName()});
2107 
2108   // Print the body if this is not an external function.
2109   Region &body = getBody();
2110   if (!body.empty()) {
2111     p << ' ';
2112     p.printRegion(body, /*printEntryBlockArgs=*/false,
2113                   /*printBlockTerminators=*/true);
2114   }
2115 }
2116 
2117 LogicalResult LLVMFuncOp::verifyType() {
2118   auto llvmType = getTypeAttr().getValue().dyn_cast_or_null<LLVMFunctionType>();
2119   if (!llvmType)
2120     return emitOpError("requires '" + getTypeAttrName() +
2121                        "' attribute of wrapped LLVM function type");
2122 
2123   return success();
2124 }
2125 
2126 // Verifies LLVM- and implementation-specific properties of the LLVM func Op:
2127 // - functions don't have 'common' linkage
2128 // - external functions have 'external' or 'extern_weak' linkage;
2129 // - vararg is (currently) only supported for external functions;
2130 LogicalResult LLVMFuncOp::verify() {
2131   if (getLinkage() == LLVM::Linkage::Common)
2132     return emitOpError() << "functions cannot have '"
2133                          << stringifyLinkage(LLVM::Linkage::Common)
2134                          << "' linkage";
2135 
2136   // Check to see if this function has a void return with a result attribute to
2137   // it. It isn't clear what semantics we would assign to that.
2138   if (getType().getReturnType().isa<LLVMVoidType>() &&
2139       !getResultAttrs(0).empty()) {
2140     return emitOpError()
2141            << "cannot attach result attributes to functions with a void return";
2142   }
2143 
2144   if (isExternal()) {
2145     if (getLinkage() != LLVM::Linkage::External &&
2146         getLinkage() != LLVM::Linkage::ExternWeak)
2147       return emitOpError() << "external functions must have '"
2148                            << stringifyLinkage(LLVM::Linkage::External)
2149                            << "' or '"
2150                            << stringifyLinkage(LLVM::Linkage::ExternWeak)
2151                            << "' linkage";
2152     return success();
2153   }
2154 
2155   if (isVarArg())
2156     return emitOpError("only external functions can be variadic");
2157 
2158   return success();
2159 }
2160 
2161 /// Verifies LLVM- and implementation-specific properties of the LLVM func Op:
2162 /// - entry block arguments are of LLVM types.
2163 LogicalResult LLVMFuncOp::verifyRegions() {
2164   if (isExternal())
2165     return success();
2166 
2167   unsigned numArguments = getType().getNumParams();
2168   Block &entryBlock = front();
2169   for (unsigned i = 0; i < numArguments; ++i) {
2170     Type argType = entryBlock.getArgument(i).getType();
2171     if (!isCompatibleType(argType))
2172       return emitOpError("entry block argument #")
2173              << i << " is not of LLVM type";
2174   }
2175 
2176   return success();
2177 }
2178 
2179 //===----------------------------------------------------------------------===//
2180 // Verification for LLVM::ConstantOp.
2181 //===----------------------------------------------------------------------===//
2182 
2183 LogicalResult LLVM::ConstantOp::verify() {
2184   if (StringAttr sAttr = getValue().dyn_cast<StringAttr>()) {
2185     auto arrayType = getType().dyn_cast<LLVMArrayType>();
2186     if (!arrayType || arrayType.getNumElements() != sAttr.getValue().size() ||
2187         !arrayType.getElementType().isInteger(8)) {
2188       return emitOpError() << "expected array type of "
2189                            << sAttr.getValue().size()
2190                            << " i8 elements for the string constant";
2191     }
2192     return success();
2193   }
2194   if (auto structType = getType().dyn_cast<LLVMStructType>()) {
2195     if (structType.getBody().size() != 2 ||
2196         structType.getBody()[0] != structType.getBody()[1]) {
2197       return emitError() << "expected struct type with two elements of the "
2198                             "same type, the type of a complex constant";
2199     }
2200 
2201     auto arrayAttr = getValue().dyn_cast<ArrayAttr>();
2202     if (!arrayAttr || arrayAttr.size() != 2 ||
2203         arrayAttr[0].getType() != arrayAttr[1].getType()) {
2204       return emitOpError() << "expected array attribute with two elements, "
2205                               "representing a complex constant";
2206     }
2207 
2208     Type elementType = structType.getBody()[0];
2209     if (!elementType
2210              .isa<IntegerType, Float16Type, Float32Type, Float64Type>()) {
2211       return emitError()
2212              << "expected struct element types to be floating point type or "
2213                 "integer type";
2214     }
2215     return success();
2216   }
2217   if (!getValue().isa<IntegerAttr, ArrayAttr, FloatAttr, ElementsAttr>())
2218     return emitOpError()
2219            << "only supports integer, float, string or elements attributes";
2220   return success();
2221 }
2222 
2223 // Constant op constant-folds to its value.
2224 OpFoldResult LLVM::ConstantOp::fold(ArrayRef<Attribute>) { return getValue(); }
2225 
2226 //===----------------------------------------------------------------------===//
2227 // Utility functions for parsing atomic ops
2228 //===----------------------------------------------------------------------===//
2229 
2230 // Helper function to parse a keyword into the specified attribute named by
2231 // `attrName`. The keyword must match one of the string values defined by the
2232 // AtomicBinOp enum. The resulting I64 attribute is added to the `result`
2233 // state.
2234 static ParseResult parseAtomicBinOp(OpAsmParser &parser, OperationState &result,
2235                                     StringRef attrName) {
2236   SMLoc loc;
2237   StringRef keyword;
2238   if (parser.getCurrentLocation(&loc) || parser.parseKeyword(&keyword))
2239     return failure();
2240 
2241   // Replace the keyword `keyword` with an integer attribute.
2242   auto kind = symbolizeAtomicBinOp(keyword);
2243   if (!kind) {
2244     return parser.emitError(loc)
2245            << "'" << keyword << "' is an incorrect value of the '" << attrName
2246            << "' attribute";
2247   }
2248 
2249   auto value = static_cast<int64_t>(kind.getValue());
2250   auto attr = parser.getBuilder().getI64IntegerAttr(value);
2251   result.addAttribute(attrName, attr);
2252 
2253   return success();
2254 }
2255 
2256 // Helper function to parse a keyword into the specified attribute named by
2257 // `attrName`. The keyword must match one of the string values defined by the
2258 // AtomicOrdering enum. The resulting I64 attribute is added to the `result`
2259 // state.
2260 static ParseResult parseAtomicOrdering(OpAsmParser &parser,
2261                                        OperationState &result,
2262                                        StringRef attrName) {
2263   SMLoc loc;
2264   StringRef ordering;
2265   if (parser.getCurrentLocation(&loc) || parser.parseKeyword(&ordering))
2266     return failure();
2267 
2268   // Replace the keyword `ordering` with an integer attribute.
2269   auto kind = symbolizeAtomicOrdering(ordering);
2270   if (!kind) {
2271     return parser.emitError(loc)
2272            << "'" << ordering << "' is an incorrect value of the '" << attrName
2273            << "' attribute";
2274   }
2275 
2276   auto value = static_cast<int64_t>(kind.getValue());
2277   auto attr = parser.getBuilder().getI64IntegerAttr(value);
2278   result.addAttribute(attrName, attr);
2279 
2280   return success();
2281 }
2282 
2283 //===----------------------------------------------------------------------===//
2284 // Printer, parser and verifier for LLVM::AtomicRMWOp.
2285 //===----------------------------------------------------------------------===//
2286 
2287 void AtomicRMWOp::print(OpAsmPrinter &p) {
2288   p << ' ' << stringifyAtomicBinOp(getBinOp()) << ' ' << getPtr() << ", "
2289     << getVal() << ' ' << stringifyAtomicOrdering(getOrdering()) << ' ';
2290   p.printOptionalAttrDict((*this)->getAttrs(), {"bin_op", "ordering"});
2291   p << " : " << getRes().getType();
2292 }
2293 
2294 // <operation> ::= `llvm.atomicrmw` keyword ssa-use `,` ssa-use keyword
2295 //                 attribute-dict? `:` type
2296 ParseResult AtomicRMWOp::parse(OpAsmParser &parser, OperationState &result) {
2297   Type type;
2298   OpAsmParser::OperandType ptr, val;
2299   if (parseAtomicBinOp(parser, result, "bin_op") || parser.parseOperand(ptr) ||
2300       parser.parseComma() || parser.parseOperand(val) ||
2301       parseAtomicOrdering(parser, result, "ordering") ||
2302       parser.parseOptionalAttrDict(result.attributes) ||
2303       parser.parseColonType(type) ||
2304       parser.resolveOperand(ptr, LLVM::LLVMPointerType::get(type),
2305                             result.operands) ||
2306       parser.resolveOperand(val, type, result.operands))
2307     return failure();
2308 
2309   result.addTypes(type);
2310   return success();
2311 }
2312 
2313 LogicalResult AtomicRMWOp::verify() {
2314   auto ptrType = getPtr().getType().cast<LLVM::LLVMPointerType>();
2315   auto valType = getVal().getType();
2316   if (valType != ptrType.getElementType())
2317     return emitOpError("expected LLVM IR element type for operand #0 to "
2318                        "match type for operand #1");
2319   auto resType = getRes().getType();
2320   if (resType != valType)
2321     return emitOpError(
2322         "expected LLVM IR result type to match type for operand #1");
2323   if (getBinOp() == AtomicBinOp::fadd || getBinOp() == AtomicBinOp::fsub) {
2324     if (!mlir::LLVM::isCompatibleFloatingPointType(valType))
2325       return emitOpError("expected LLVM IR floating point type");
2326   } else if (getBinOp() == AtomicBinOp::xchg) {
2327     auto intType = valType.dyn_cast<IntegerType>();
2328     unsigned intBitWidth = intType ? intType.getWidth() : 0;
2329     if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 &&
2330         intBitWidth != 64 && !valType.isa<BFloat16Type>() &&
2331         !valType.isa<Float16Type>() && !valType.isa<Float32Type>() &&
2332         !valType.isa<Float64Type>())
2333       return emitOpError("unexpected LLVM IR type for 'xchg' bin_op");
2334   } else {
2335     auto intType = valType.dyn_cast<IntegerType>();
2336     unsigned intBitWidth = intType ? intType.getWidth() : 0;
2337     if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 &&
2338         intBitWidth != 64)
2339       return emitOpError("expected LLVM IR integer type");
2340   }
2341 
2342   if (static_cast<unsigned>(getOrdering()) <
2343       static_cast<unsigned>(AtomicOrdering::monotonic))
2344     return emitOpError() << "expected at least '"
2345                          << stringifyAtomicOrdering(AtomicOrdering::monotonic)
2346                          << "' ordering";
2347 
2348   return success();
2349 }
2350 
2351 //===----------------------------------------------------------------------===//
2352 // Printer, parser and verifier for LLVM::AtomicCmpXchgOp.
2353 //===----------------------------------------------------------------------===//
2354 
2355 void AtomicCmpXchgOp::print(OpAsmPrinter &p) {
2356   p << ' ' << getPtr() << ", " << getCmp() << ", " << getVal() << ' '
2357     << stringifyAtomicOrdering(getSuccessOrdering()) << ' '
2358     << stringifyAtomicOrdering(getFailureOrdering());
2359   p.printOptionalAttrDict((*this)->getAttrs(),
2360                           {"success_ordering", "failure_ordering"});
2361   p << " : " << getVal().getType();
2362 }
2363 
2364 // <operation> ::= `llvm.cmpxchg` ssa-use `,` ssa-use `,` ssa-use
2365 //                 keyword keyword attribute-dict? `:` type
2366 ParseResult AtomicCmpXchgOp::parse(OpAsmParser &parser,
2367                                    OperationState &result) {
2368   auto &builder = parser.getBuilder();
2369   Type type;
2370   OpAsmParser::OperandType ptr, cmp, val;
2371   if (parser.parseOperand(ptr) || parser.parseComma() ||
2372       parser.parseOperand(cmp) || parser.parseComma() ||
2373       parser.parseOperand(val) ||
2374       parseAtomicOrdering(parser, result, "success_ordering") ||
2375       parseAtomicOrdering(parser, result, "failure_ordering") ||
2376       parser.parseOptionalAttrDict(result.attributes) ||
2377       parser.parseColonType(type) ||
2378       parser.resolveOperand(ptr, LLVM::LLVMPointerType::get(type),
2379                             result.operands) ||
2380       parser.resolveOperand(cmp, type, result.operands) ||
2381       parser.resolveOperand(val, type, result.operands))
2382     return failure();
2383 
2384   auto boolType = IntegerType::get(builder.getContext(), 1);
2385   auto resultType =
2386       LLVMStructType::getLiteral(builder.getContext(), {type, boolType});
2387   result.addTypes(resultType);
2388 
2389   return success();
2390 }
2391 
2392 LogicalResult AtomicCmpXchgOp::verify() {
2393   auto ptrType = getPtr().getType().cast<LLVM::LLVMPointerType>();
2394   if (!ptrType)
2395     return emitOpError("expected LLVM IR pointer type for operand #0");
2396   auto cmpType = getCmp().getType();
2397   auto valType = getVal().getType();
2398   if (cmpType != ptrType.getElementType() || cmpType != valType)
2399     return emitOpError("expected LLVM IR element type for operand #0 to "
2400                        "match type for all other operands");
2401   auto intType = valType.dyn_cast<IntegerType>();
2402   unsigned intBitWidth = intType ? intType.getWidth() : 0;
2403   if (!valType.isa<LLVMPointerType>() && intBitWidth != 8 &&
2404       intBitWidth != 16 && intBitWidth != 32 && intBitWidth != 64 &&
2405       !valType.isa<BFloat16Type>() && !valType.isa<Float16Type>() &&
2406       !valType.isa<Float32Type>() && !valType.isa<Float64Type>())
2407     return emitOpError("unexpected LLVM IR type");
2408   if (getSuccessOrdering() < AtomicOrdering::monotonic ||
2409       getFailureOrdering() < AtomicOrdering::monotonic)
2410     return emitOpError("ordering must be at least 'monotonic'");
2411   if (getFailureOrdering() == AtomicOrdering::release ||
2412       getFailureOrdering() == AtomicOrdering::acq_rel)
2413     return emitOpError("failure ordering cannot be 'release' or 'acq_rel'");
2414   return success();
2415 }
2416 
2417 //===----------------------------------------------------------------------===//
2418 // Printer, parser and verifier for LLVM::FenceOp.
2419 //===----------------------------------------------------------------------===//
2420 
2421 // <operation> ::= `llvm.fence` (`syncscope(`strAttr`)`)? keyword
2422 // attribute-dict?
2423 ParseResult FenceOp::parse(OpAsmParser &parser, OperationState &result) {
2424   StringAttr sScope;
2425   StringRef syncscopeKeyword = "syncscope";
2426   if (!failed(parser.parseOptionalKeyword(syncscopeKeyword))) {
2427     if (parser.parseLParen() ||
2428         parser.parseAttribute(sScope, syncscopeKeyword, result.attributes) ||
2429         parser.parseRParen())
2430       return failure();
2431   } else {
2432     result.addAttribute(syncscopeKeyword,
2433                         parser.getBuilder().getStringAttr(""));
2434   }
2435   if (parseAtomicOrdering(parser, result, "ordering") ||
2436       parser.parseOptionalAttrDict(result.attributes))
2437     return failure();
2438   return success();
2439 }
2440 
2441 void FenceOp::print(OpAsmPrinter &p) {
2442   StringRef syncscopeKeyword = "syncscope";
2443   p << ' ';
2444   if (!(*this)->getAttr(syncscopeKeyword).cast<StringAttr>().getValue().empty())
2445     p << "syncscope(" << (*this)->getAttr(syncscopeKeyword) << ") ";
2446   p << stringifyAtomicOrdering(getOrdering());
2447 }
2448 
2449 LogicalResult FenceOp::verify() {
2450   if (getOrdering() == AtomicOrdering::not_atomic ||
2451       getOrdering() == AtomicOrdering::unordered ||
2452       getOrdering() == AtomicOrdering::monotonic)
2453     return emitOpError("can be given only acquire, release, acq_rel, "
2454                        "and seq_cst orderings");
2455   return success();
2456 }
2457 
2458 //===----------------------------------------------------------------------===//
2459 // Folder for LLVM::BitcastOp
2460 //===----------------------------------------------------------------------===//
2461 
2462 OpFoldResult LLVM::BitcastOp::fold(ArrayRef<Attribute> operands) {
2463   // bitcast(x : T0, T0) -> x
2464   if (getArg().getType() == getType())
2465     return getArg();
2466   // bitcast(bitcast(x : T0, T1), T0) -> x
2467   if (auto prev = getArg().getDefiningOp<BitcastOp>())
2468     if (prev.getArg().getType() == getType())
2469       return prev.getArg();
2470   return {};
2471 }
2472 
2473 //===----------------------------------------------------------------------===//
2474 // Folder for LLVM::AddrSpaceCastOp
2475 //===----------------------------------------------------------------------===//
2476 
2477 OpFoldResult LLVM::AddrSpaceCastOp::fold(ArrayRef<Attribute> operands) {
2478   // addrcast(x : T0, T0) -> x
2479   if (getArg().getType() == getType())
2480     return getArg();
2481   // addrcast(addrcast(x : T0, T1), T0) -> x
2482   if (auto prev = getArg().getDefiningOp<AddrSpaceCastOp>())
2483     if (prev.getArg().getType() == getType())
2484       return prev.getArg();
2485   return {};
2486 }
2487 
2488 //===----------------------------------------------------------------------===//
2489 // Folder for LLVM::GEPOp
2490 //===----------------------------------------------------------------------===//
2491 
2492 OpFoldResult LLVM::GEPOp::fold(ArrayRef<Attribute> operands) {
2493   // gep %x:T, 0 -> %x
2494   if (getBase().getType() == getType() && getIndices().size() == 1 &&
2495       matchPattern(getIndices()[0], m_Zero()))
2496     return getBase();
2497   return {};
2498 }
2499 
2500 //===----------------------------------------------------------------------===//
2501 // LLVMDialect initialization, type parsing, and registration.
2502 //===----------------------------------------------------------------------===//
2503 
/// Registers the dialect's attributes, types and operations, and allows
/// operations that are not registered with the dialect.
void LLVMDialect::initialize() {
  // Attributes defined by this dialect.
  addAttributes<FMFAttr, LinkageAttr, LoopOptionsAttr>();

  // Types defined by this dialect.
  // clang-format off
  addTypes<LLVMVoidType,
           LLVMPPCFP128Type,
           LLVMX86MMXType,
           LLVMTokenType,
           LLVMLabelType,
           LLVMMetadataType,
           LLVMFunctionType,
           LLVMPointerType,
           LLVMFixedVectorType,
           LLVMScalableVectorType,
           LLVMArrayType,
           LLVMStructType>();
  // clang-format on
  // Operations: the template argument list is produced by expanding
  // LLVMOps.cpp.inc with GET_OP_LIST defined.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc"
      >();

  // Support unknown operations because not all LLVM operations are registered.
  allowUnknownOperations();
}
2529 
2530 #define GET_OP_CLASSES
2531 #include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc"
2532 
2533 /// Parse a type registered to this dialect.
2534 Type LLVMDialect::parseType(DialectAsmParser &parser) const {
2535   return detail::parseType(parser);
2536 }
2537 
2538 /// Print a type registered to this dialect.
2539 void LLVMDialect::printType(Type type, DialectAsmPrinter &os) const {
2540   return detail::printType(type, os);
2541 }
2542 
2543 LogicalResult LLVMDialect::verifyDataLayoutString(
2544     StringRef descr, llvm::function_ref<void(const Twine &)> reportError) {
2545   llvm::Expected<llvm::DataLayout> maybeDataLayout =
2546       llvm::DataLayout::parse(descr);
2547   if (maybeDataLayout)
2548     return success();
2549 
2550   std::string message;
2551   llvm::raw_string_ostream messageStream(message);
2552   llvm::logAllUnhandledErrors(maybeDataLayout.takeError(), messageStream);
2553   reportError("invalid data layout descriptor: " + messageStream.str());
2554   return failure();
2555 }
2556 
2557 /// Verify LLVM dialect attributes.
2558 LogicalResult LLVMDialect::verifyOperationAttribute(Operation *op,
2559                                                     NamedAttribute attr) {
2560   // If the `llvm.loop` attribute is present, enforce the following structure,
2561   // which the module translation can assume.
2562   if (attr.getName() == LLVMDialect::getLoopAttrName()) {
2563     auto loopAttr = attr.getValue().dyn_cast<DictionaryAttr>();
2564     if (!loopAttr)
2565       return op->emitOpError() << "expected '" << LLVMDialect::getLoopAttrName()
2566                                << "' to be a dictionary attribute";
2567     Optional<NamedAttribute> parallelAccessGroup =
2568         loopAttr.getNamed(LLVMDialect::getParallelAccessAttrName());
2569     if (parallelAccessGroup.hasValue()) {
2570       auto accessGroups = parallelAccessGroup->getValue().dyn_cast<ArrayAttr>();
2571       if (!accessGroups)
2572         return op->emitOpError()
2573                << "expected '" << LLVMDialect::getParallelAccessAttrName()
2574                << "' to be an array attribute";
2575       for (Attribute attr : accessGroups) {
2576         auto accessGroupRef = attr.dyn_cast<SymbolRefAttr>();
2577         if (!accessGroupRef)
2578           return op->emitOpError()
2579                  << "expected '" << attr << "' to be a symbol reference";
2580         StringAttr metadataName = accessGroupRef.getRootReference();
2581         auto metadataOp =
2582             SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>(
2583                 op->getParentOp(), metadataName);
2584         if (!metadataOp)
2585           return op->emitOpError()
2586                  << "expected '" << attr << "' to reference a metadata op";
2587         StringAttr accessGroupName = accessGroupRef.getLeafReference();
2588         Operation *accessGroupOp =
2589             SymbolTable::lookupNearestSymbolFrom(metadataOp, accessGroupName);
2590         if (!accessGroupOp)
2591           return op->emitOpError()
2592                  << "expected '" << attr << "' to reference an access_group op";
2593       }
2594     }
2595 
2596     Optional<NamedAttribute> loopOptions =
2597         loopAttr.getNamed(LLVMDialect::getLoopOptionsAttrName());
2598     if (loopOptions.hasValue() &&
2599         !loopOptions->getValue().isa<LoopOptionsAttr>())
2600       return op->emitOpError()
2601              << "expected '" << LLVMDialect::getLoopOptionsAttrName()
2602              << "' to be a `loopopts` attribute";
2603   }
2604 
2605   if (attr.getName() == LLVMDialect::getStructAttrsAttrName()) {
2606     return op->emitOpError()
2607            << "'" << LLVM::LLVMDialect::getStructAttrsAttrName()
2608            << "' is permitted only in argument or result attributes";
2609   }
2610 
2611   // If the data layout attribute is present, it must use the LLVM data layout
2612   // syntax. Try parsing it and report errors in case of failure. Users of this
2613   // attribute may assume it is well-formed and can pass it to the (asserting)
2614   // llvm::DataLayout constructor.
2615   if (attr.getName() != LLVM::LLVMDialect::getDataLayoutAttrName())
2616     return success();
2617   if (auto stringAttr = attr.getValue().dyn_cast<StringAttr>())
2618     return verifyDataLayoutString(
2619         stringAttr.getValue(),
2620         [op](const Twine &message) { op->emitOpError() << message.str(); });
2621 
2622   return op->emitOpError() << "expected '"
2623                            << LLVM::LLVMDialect::getDataLayoutAttrName()
2624                            << "' to be a string attribute";
2625 }
2626 
2627 LogicalResult LLVMDialect::verifyStructAttr(Operation *op, Attribute attr,
2628                                             Type annotatedType) {
2629   auto structType = annotatedType.dyn_cast<LLVMStructType>();
2630   if (!structType) {
2631     const auto emitIncorrectAnnotatedType = [&op]() {
2632       return op->emitError()
2633              << "expected '" << LLVMDialect::getStructAttrsAttrName()
2634              << "' to annotate '!llvm.struct' or '!llvm.ptr<struct<...>>'";
2635     };
2636     const auto ptrType = annotatedType.dyn_cast<LLVMPointerType>();
2637     if (!ptrType)
2638       return emitIncorrectAnnotatedType();
2639     structType = ptrType.getElementType().dyn_cast<LLVMStructType>();
2640     if (!structType)
2641       return emitIncorrectAnnotatedType();
2642   }
2643 
2644   const auto arrAttrs = attr.dyn_cast<ArrayAttr>();
2645   if (!arrAttrs)
2646     return op->emitError() << "expected '"
2647                            << LLVMDialect::getStructAttrsAttrName()
2648                            << "' to be an array attribute";
2649 
2650   if (structType.getBody().size() != arrAttrs.size())
2651     return op->emitError()
2652            << "size of '" << LLVMDialect::getStructAttrsAttrName()
2653            << "' must match the size of the annotated '!llvm.struct'";
2654   return success();
2655 }
2656 
2657 static LogicalResult verifyFuncOpInterfaceStructAttr(
2658     Operation *op, Attribute attr,
2659     std::function<Type(FunctionOpInterface)> getAnnotatedType) {
2660   if (auto funcOp = dyn_cast<FunctionOpInterface>(op))
2661     return LLVMDialect::verifyStructAttr(op, attr, getAnnotatedType(funcOp));
2662   return op->emitError() << "expected '"
2663                          << LLVMDialect::getStructAttrsAttrName()
2664                          << "' to be used on function-like operations";
2665 }
2666 
2667 /// Verify LLVMIR function argument attributes.
2668 LogicalResult LLVMDialect::verifyRegionArgAttribute(Operation *op,
2669                                                     unsigned regionIdx,
2670                                                     unsigned argIdx,
2671                                                     NamedAttribute argAttr) {
2672   // Check that llvm.noalias is a unit attribute.
2673   if (argAttr.getName() == LLVMDialect::getNoAliasAttrName() &&
2674       !argAttr.getValue().isa<UnitAttr>())
2675     return op->emitError()
2676            << "expected llvm.noalias argument attribute to be a unit attribute";
2677   // Check that llvm.align is an integer attribute.
2678   if (argAttr.getName() == LLVMDialect::getAlignAttrName() &&
2679       !argAttr.getValue().isa<IntegerAttr>())
2680     return op->emitError()
2681            << "llvm.align argument attribute of non integer type";
2682   if (argAttr.getName() == LLVMDialect::getStructAttrsAttrName()) {
2683     return verifyFuncOpInterfaceStructAttr(
2684         op, argAttr.getValue(), [argIdx](FunctionOpInterface funcOp) {
2685           return funcOp.getArgumentTypes()[argIdx];
2686         });
2687   }
2688   return success();
2689 }
2690 
2691 LogicalResult LLVMDialect::verifyRegionResultAttribute(Operation *op,
2692                                                        unsigned regionIdx,
2693                                                        unsigned resIdx,
2694                                                        NamedAttribute resAttr) {
2695   if (resAttr.getName() == LLVMDialect::getStructAttrsAttrName()) {
2696     return verifyFuncOpInterfaceStructAttr(
2697         op, resAttr.getValue(), [resIdx](FunctionOpInterface funcOp) {
2698           return funcOp.getResultTypes()[resIdx];
2699         });
2700   }
2701   return success();
2702 }
2703 
2704 //===----------------------------------------------------------------------===//
2705 // Utility functions.
2706 //===----------------------------------------------------------------------===//
2707 
2708 Value mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder,
2709                                      StringRef name, StringRef value,
2710                                      LLVM::Linkage linkage) {
2711   assert(builder.getInsertionBlock() &&
2712          builder.getInsertionBlock()->getParentOp() &&
2713          "expected builder to point to a block constrained in an op");
2714   auto module =
2715       builder.getInsertionBlock()->getParentOp()->getParentOfType<ModuleOp>();
2716   assert(module && "builder points to an op outside of a module");
2717 
2718   // Create the global at the entry of the module.
2719   OpBuilder moduleBuilder(module.getBodyRegion(), builder.getListener());
2720   MLIRContext *ctx = builder.getContext();
2721   auto type = LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), value.size());
2722   auto global = moduleBuilder.create<LLVM::GlobalOp>(
2723       loc, type, /*isConstant=*/true, linkage, name,
2724       builder.getStringAttr(value), /*alignment=*/0);
2725 
2726   // Get the pointer to the first character in the global string.
2727   Value globalPtr = builder.create<LLVM::AddressOfOp>(loc, global);
2728   Value cst0 = builder.create<LLVM::ConstantOp>(
2729       loc, IntegerType::get(ctx, 64),
2730       builder.getIntegerAttr(builder.getIndexType(), 0));
2731   return builder.create<LLVM::GEPOp>(
2732       loc, LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)), globalPtr,
2733       ValueRange{cst0, cst0});
2734 }
2735 
2736 bool mlir::LLVM::satisfiesLLVMModule(Operation *op) {
2737   return op->hasTrait<OpTrait::SymbolTable>() &&
2738          op->hasTrait<OpTrait::IsIsolatedFromAbove>();
2739 }
2740 
2741 static constexpr const FastmathFlags fastmathFlagsList[] = {
2742     // clang-format off
2743     FastmathFlags::nnan,
2744     FastmathFlags::ninf,
2745     FastmathFlags::nsz,
2746     FastmathFlags::arcp,
2747     FastmathFlags::contract,
2748     FastmathFlags::afn,
2749     FastmathFlags::reassoc,
2750     FastmathFlags::fast,
2751     // clang-format on
2752 };
2753 
2754 void FMFAttr::print(AsmPrinter &printer) const {
2755   printer << "<";
2756   auto flags = llvm::make_filter_range(fastmathFlagsList, [&](auto flag) {
2757     return bitEnumContains(this->getFlags(), flag);
2758   });
2759   llvm::interleaveComma(flags, printer,
2760                         [&](auto flag) { printer << stringifyEnum(flag); });
2761   printer << ">";
2762 }
2763 
2764 Attribute FMFAttr::parse(AsmParser &parser, Type type) {
2765   if (failed(parser.parseLess()))
2766     return {};
2767 
2768   FastmathFlags flags = {};
2769   if (failed(parser.parseOptionalGreater())) {
2770     do {
2771       StringRef elemName;
2772       if (failed(parser.parseKeyword(&elemName)))
2773         return {};
2774 
2775       auto elem = symbolizeFastmathFlags(elemName);
2776       if (!elem) {
2777         parser.emitError(parser.getNameLoc(), "Unknown fastmath flag: ")
2778             << elemName;
2779         return {};
2780       }
2781 
2782       flags = flags | *elem;
2783     } while (succeeded(parser.parseOptionalComma()));
2784 
2785     if (failed(parser.parseGreater()))
2786       return {};
2787   }
2788 
2789   return FMFAttr::get(parser.getContext(), flags);
2790 }
2791 
2792 void LinkageAttr::print(AsmPrinter &printer) const {
2793   printer << "<";
2794   if (static_cast<uint64_t>(getLinkage()) <= getMaxEnumValForLinkage())
2795     printer << stringifyEnum(getLinkage());
2796   else
2797     printer << static_cast<uint64_t>(getLinkage());
2798   printer << ">";
2799 }
2800 
2801 Attribute LinkageAttr::parse(AsmParser &parser, Type type) {
2802   StringRef elemName;
2803   if (parser.parseLess() || parser.parseKeyword(&elemName) ||
2804       parser.parseGreater())
2805     return {};
2806   auto elem = linkage::symbolizeLinkage(elemName);
2807   if (!elem) {
2808     parser.emitError(parser.getNameLoc(), "Unknown linkage: ") << elemName;
2809     return {};
2810   }
2811   Linkage linkage = *elem;
2812   return LinkageAttr::get(parser.getContext(), linkage);
2813 }
2814 
/// Creates a builder pre-populated with a copy of the options already stored
/// in `attr`, so they can be modified with the set* methods and rebuilt into a
/// new attribute.
LoopOptionsAttrBuilder::LoopOptionsAttrBuilder(LoopOptionsAttr attr)
    : options(attr.getOptions().begin(), attr.getOptions().end()) {}
2817 
2818 template <typename T>
2819 LoopOptionsAttrBuilder &LoopOptionsAttrBuilder::setOption(LoopOptionCase tag,
2820                                                           Optional<T> value) {
2821   auto option = llvm::find_if(
2822       options, [tag](auto option) { return option.first == tag; });
2823   if (option != options.end()) {
2824     if (value.hasValue())
2825       option->second = *value;
2826     else
2827       options.erase(option);
2828   } else {
2829     options.push_back(LoopOptionsAttr::OptionValuePair(tag, *value));
2830   }
2831   return *this;
2832 }
2833 
2834 LoopOptionsAttrBuilder &
2835 LoopOptionsAttrBuilder::setDisableLICM(Optional<bool> value) {
2836   return setOption(LoopOptionCase::disable_licm, value);
2837 }
2838 
2839 /// Set the `interleave_count` option to the provided value. If no value
2840 /// is provided the option is deleted.
2841 LoopOptionsAttrBuilder &
2842 LoopOptionsAttrBuilder::setInterleaveCount(Optional<uint64_t> count) {
2843   return setOption(LoopOptionCase::interleave_count, count);
2844 }
2845 
2846 /// Set the `disable_unroll` option to the provided value. If no value
2847 /// is provided the option is deleted.
2848 LoopOptionsAttrBuilder &
2849 LoopOptionsAttrBuilder::setDisableUnroll(Optional<bool> value) {
2850   return setOption(LoopOptionCase::disable_unroll, value);
2851 }
2852 
2853 /// Set the `disable_pipeline` option to the provided value. If no value
2854 /// is provided the option is deleted.
2855 LoopOptionsAttrBuilder &
2856 LoopOptionsAttrBuilder::setDisablePipeline(Optional<bool> value) {
2857   return setOption(LoopOptionCase::disable_pipeline, value);
2858 }
2859 
2860 /// Set the `pipeline_initiation_interval` option to the provided value.
2861 /// If no value is provided the option is deleted.
2862 LoopOptionsAttrBuilder &LoopOptionsAttrBuilder::setPipelineInitiationInterval(
2863     Optional<uint64_t> count) {
2864   return setOption(LoopOptionCase::pipeline_initiation_interval, count);
2865 }
2866 
2867 template <typename T>
2868 static Optional<T>
2869 getOption(ArrayRef<std::pair<LoopOptionCase, int64_t>> options,
2870           LoopOptionCase option) {
2871   auto it =
2872       lower_bound(options, option, [](auto optionPair, LoopOptionCase option) {
2873         return optionPair.first < option;
2874       });
2875   if (it == options.end())
2876     return {};
2877   return static_cast<T>(it->second);
2878 }
2879 
2880 Optional<bool> LoopOptionsAttr::disableUnroll() {
2881   return getOption<bool>(getOptions(), LoopOptionCase::disable_unroll);
2882 }
2883 
2884 Optional<bool> LoopOptionsAttr::disableLICM() {
2885   return getOption<bool>(getOptions(), LoopOptionCase::disable_licm);
2886 }
2887 
2888 Optional<int64_t> LoopOptionsAttr::interleaveCount() {
2889   return getOption<int64_t>(getOptions(), LoopOptionCase::interleave_count);
2890 }
2891 
/// Build the LoopOptions Attribute from an array of individual options that is
/// already sorted by option tag (the `first` member of each pair). Sortedness
/// is only checked by an assertion; the array is passed to the storage as is.
LoopOptionsAttr LoopOptionsAttr::get(
    MLIRContext *context,
    ArrayRef<std::pair<LoopOptionCase, int64_t>> sortedOptions) {
  assert(llvm::is_sorted(sortedOptions, llvm::less_first()) &&
         "LoopOptionsAttr ctor expects a sorted options array");
  return Base::get(context, sortedOptions);
}
2900 
2901 /// Build the LoopOptions Attribute from a sorted array of individual options.
2902 LoopOptionsAttr LoopOptionsAttr::get(MLIRContext *context,
2903                                      LoopOptionsAttrBuilder &optionBuilders) {
2904   llvm::sort(optionBuilders.options, llvm::less_first());
2905   return Base::get(context, optionBuilders.options);
2906 }
2907 
2908 void LoopOptionsAttr::print(AsmPrinter &printer) const {
2909   printer << "<";
2910   llvm::interleaveComma(getOptions(), printer, [&](auto option) {
2911     printer << stringifyEnum(option.first) << " = ";
2912     switch (option.first) {
2913     case LoopOptionCase::disable_licm:
2914     case LoopOptionCase::disable_unroll:
2915     case LoopOptionCase::disable_pipeline:
2916       printer << (option.second ? "true" : "false");
2917       break;
2918     case LoopOptionCase::interleave_count:
2919     case LoopOptionCase::pipeline_initiation_interval:
2920       printer << option.second;
2921       break;
2922     }
2923   });
2924   printer << ">";
2925 }
2926 
2927 Attribute LoopOptionsAttr::parse(AsmParser &parser, Type type) {
2928   if (failed(parser.parseLess()))
2929     return {};
2930 
2931   SmallVector<std::pair<LoopOptionCase, int64_t>> options;
2932   llvm::SmallDenseSet<LoopOptionCase> seenOptions;
2933   do {
2934     StringRef optionName;
2935     if (parser.parseKeyword(&optionName))
2936       return {};
2937 
2938     auto option = symbolizeLoopOptionCase(optionName);
2939     if (!option) {
2940       parser.emitError(parser.getNameLoc(), "unknown loop option: ")
2941           << optionName;
2942       return {};
2943     }
2944     if (!seenOptions.insert(*option).second) {
2945       parser.emitError(parser.getNameLoc(), "loop option present twice");
2946       return {};
2947     }
2948     if (failed(parser.parseEqual()))
2949       return {};
2950 
2951     int64_t value;
2952     switch (*option) {
2953     case LoopOptionCase::disable_licm:
2954     case LoopOptionCase::disable_unroll:
2955     case LoopOptionCase::disable_pipeline:
2956       if (succeeded(parser.parseOptionalKeyword("true")))
2957         value = 1;
2958       else if (succeeded(parser.parseOptionalKeyword("false")))
2959         value = 0;
2960       else {
2961         parser.emitError(parser.getNameLoc(),
2962                          "expected boolean value 'true' or 'false'");
2963         return {};
2964       }
2965       break;
2966     case LoopOptionCase::interleave_count:
2967     case LoopOptionCase::pipeline_initiation_interval:
2968       if (failed(parser.parseInteger(value))) {
2969         parser.emitError(parser.getNameLoc(), "expected integer value");
2970         return {};
2971       }
2972       break;
2973     }
2974     options.push_back(std::make_pair(*option, value));
2975   } while (succeeded(parser.parseOptionalComma()));
2976   if (failed(parser.parseGreater()))
2977     return {};
2978 
2979   llvm::sort(options, llvm::less_first());
2980   return get(parser.getContext(), options);
2981 }
2982