1 //===- LLVMDialect.cpp - LLVM IR Ops and Dialect registration -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the types and operation details for the LLVM IR dialect in
10 // MLIR, and the LLVM IR dialect.  It also registers the dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
15 #include "mlir/IR/Builders.h"
16 #include "mlir/IR/BuiltinOps.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "mlir/IR/DialectImplementation.h"
19 #include "mlir/IR/FunctionImplementation.h"
20 #include "mlir/IR/MLIRContext.h"
21 
22 #include "llvm/ADT/StringSwitch.h"
23 #include "llvm/AsmParser/Parser.h"
24 #include "llvm/Bitcode/BitcodeReader.h"
25 #include "llvm/Bitcode/BitcodeWriter.h"
26 #include "llvm/IR/Attributes.h"
27 #include "llvm/IR/Function.h"
28 #include "llvm/IR/Type.h"
29 #include "llvm/Support/Mutex.h"
30 #include "llvm/Support/SourceMgr.h"
31 
32 using namespace mlir;
33 using namespace mlir::LLVM;
34 
// Names of the unit attributes marking load/store operations as volatile or
// non-temporal. Note the trailing underscore in the volatile attribute name.
static constexpr char kVolatileAttrName[] = "volatile_";
static constexpr char kNonTemporalAttrName[] = "nontemporal";
37 
38 #include "mlir/Dialect/LLVMIR/LLVMOpsEnums.cpp.inc"
39 #include "mlir/Dialect/LLVMIR/LLVMOpsInterfaces.cpp.inc"
40 
41 namespace mlir {
42 namespace LLVM {
43 namespace detail {
44 struct BitmaskEnumStorage : public AttributeStorage {
45   using KeyTy = uint64_t;
46 
47   BitmaskEnumStorage(KeyTy val) : value(val) {}
48 
49   bool operator==(const KeyTy &key) const { return value == key; }
50 
51   static BitmaskEnumStorage *construct(AttributeStorageAllocator &allocator,
52                                        const KeyTy &key) {
53     return new (allocator.allocate<BitmaskEnumStorage>())
54         BitmaskEnumStorage(key);
55   }
56 
57   KeyTy value = 0;
58 };
59 
60 struct LoopOptionAttrStorage : public AttributeStorage {
61   using KeyTy = std::pair<uint64_t, int32_t>;
62 
63   explicit LoopOptionAttrStorage(uint64_t option, int32_t value)
64       : option(option), value(value) {}
65 
66   bool operator==(const KeyTy &key) const {
67     return key == KeyTy(option, value);
68   }
69 
70   static LoopOptionAttrStorage *
71   construct(mlir::AttributeStorageAllocator &allocator, const KeyTy &key) {
72     return new (allocator.allocate<LoopOptionAttrStorage>())
73         LoopOptionAttrStorage(key.first, key.second);
74   }
75 
76   uint64_t option;
77   int32_t value;
78 };
79 } // namespace detail
80 } // namespace LLVM
81 } // namespace mlir
82 
83 static auto processFMFAttr(ArrayRef<NamedAttribute> attrs) {
84   SmallVector<NamedAttribute, 8> filteredAttrs(
85       llvm::make_filter_range(attrs, [&](NamedAttribute attr) {
86         if (attr.first == "fastmathFlags") {
87           auto defAttr = FMFAttr::get({}, attr.second.getContext());
88           return defAttr != attr.second;
89         }
90         return true;
91       }));
92   return filteredAttrs;
93 }
94 
95 static ParseResult parseLLVMOpAttrs(OpAsmParser &parser,
96                                     NamedAttrList &result) {
97   return parser.parseOptionalAttrDict(result);
98 }
99 
100 static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op,
101                              DictionaryAttr attrs) {
102   printer.printOptionalAttrDict(processFMFAttr(attrs.getValue()));
103 }
104 
105 //===----------------------------------------------------------------------===//
// Printing/parsing for LLVM::ICmpOp and LLVM::FCmpOp.
107 //===----------------------------------------------------------------------===//
108 static void printICmpOp(OpAsmPrinter &p, ICmpOp &op) {
109   p << op.getOperationName() << " \"" << stringifyICmpPredicate(op.predicate())
110     << "\" " << op.getOperand(0) << ", " << op.getOperand(1);
111   p.printOptionalAttrDict(op->getAttrs(), {"predicate"});
112   p << " : " << op.lhs().getType();
113 }
114 
115 static void printFCmpOp(OpAsmPrinter &p, FCmpOp &op) {
116   p << op.getOperationName() << " \"" << stringifyFCmpPredicate(op.predicate())
117     << "\" " << op.getOperand(0) << ", " << op.getOperand(1);
118   p.printOptionalAttrDict(processFMFAttr(op->getAttrs()), {"predicate"});
119   p << " : " << op.lhs().getType();
120 }
121 
122 // <operation> ::= `llvm.icmp` string-literal ssa-use `,` ssa-use
123 //                 attribute-dict? `:` type
124 // <operation> ::= `llvm.fcmp` string-literal ssa-use `,` ssa-use
125 //                 attribute-dict? `:` type
126 template <typename CmpPredicateType>
127 static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) {
128   Builder &builder = parser.getBuilder();
129 
130   StringAttr predicateAttr;
131   OpAsmParser::OperandType lhs, rhs;
132   Type type;
133   llvm::SMLoc predicateLoc, trailingTypeLoc;
134   if (parser.getCurrentLocation(&predicateLoc) ||
135       parser.parseAttribute(predicateAttr, "predicate", result.attributes) ||
136       parser.parseOperand(lhs) || parser.parseComma() ||
137       parser.parseOperand(rhs) ||
138       parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
139       parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) ||
140       parser.resolveOperand(lhs, type, result.operands) ||
141       parser.resolveOperand(rhs, type, result.operands))
142     return failure();
143 
144   // Replace the string attribute `predicate` with an integer attribute.
145   int64_t predicateValue = 0;
146   if (std::is_same<CmpPredicateType, ICmpPredicate>()) {
147     Optional<ICmpPredicate> predicate =
148         symbolizeICmpPredicate(predicateAttr.getValue());
149     if (!predicate)
150       return parser.emitError(predicateLoc)
151              << "'" << predicateAttr.getValue()
152              << "' is an incorrect value of the 'predicate' attribute";
153     predicateValue = static_cast<int64_t>(predicate.getValue());
154   } else {
155     Optional<FCmpPredicate> predicate =
156         symbolizeFCmpPredicate(predicateAttr.getValue());
157     if (!predicate)
158       return parser.emitError(predicateLoc)
159              << "'" << predicateAttr.getValue()
160              << "' is an incorrect value of the 'predicate' attribute";
161     predicateValue = static_cast<int64_t>(predicate.getValue());
162   }
163 
164   result.attributes.set("predicate",
165                         parser.getBuilder().getI64IntegerAttr(predicateValue));
166 
167   // The result type is either i1 or a vector type <? x i1> if the inputs are
168   // vectors.
169   Type resultType = IntegerType::get(builder.getContext(), 1);
170   if (!isCompatibleType(type))
171     return parser.emitError(trailingTypeLoc,
172                             "expected LLVM dialect-compatible type");
173   if (LLVM::isCompatibleVectorType(type))
174     resultType = LLVM::getFixedVectorType(
175         resultType, LLVM::getVectorNumElements(type).getFixedValue());
176   assert(!type.isa<LLVM::LLVMScalableVectorType>() &&
177          "unhandled scalable vector");
178 
179   result.addTypes({resultType});
180   return success();
181 }
182 
183 //===----------------------------------------------------------------------===//
184 // Printing/parsing for LLVM::AllocaOp.
185 //===----------------------------------------------------------------------===//
186 
187 static void printAllocaOp(OpAsmPrinter &p, AllocaOp &op) {
188   auto elemTy = op.getType().cast<LLVM::LLVMPointerType>().getElementType();
189 
190   auto funcTy = FunctionType::get(op.getContext(), {op.arraySize().getType()},
191                                   {op.getType()});
192 
193   p << op.getOperationName() << ' ' << op.arraySize() << " x " << elemTy;
194   if (op.alignment().hasValue() && *op.alignment() != 0)
195     p.printOptionalAttrDict(op->getAttrs());
196   else
197     p.printOptionalAttrDict(op->getAttrs(), {"alignment"});
198   p << " : " << funcTy;
199 }
200 
201 // <operation> ::= `llvm.alloca` ssa-use `x` type attribute-dict?
202 //                 `:` type `,` type
203 static ParseResult parseAllocaOp(OpAsmParser &parser, OperationState &result) {
204   OpAsmParser::OperandType arraySize;
205   Type type, elemType;
206   llvm::SMLoc trailingTypeLoc;
207   if (parser.parseOperand(arraySize) || parser.parseKeyword("x") ||
208       parser.parseType(elemType) ||
209       parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
210       parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type))
211     return failure();
212 
213   Optional<NamedAttribute> alignmentAttr =
214       result.attributes.getNamed("alignment");
215   if (alignmentAttr.hasValue()) {
216     auto alignmentInt = alignmentAttr.getValue().second.dyn_cast<IntegerAttr>();
217     if (!alignmentInt)
218       return parser.emitError(parser.getNameLoc(),
219                               "expected integer alignment");
220     if (alignmentInt.getValue().isNullValue())
221       result.attributes.erase("alignment");
222   }
223 
224   // Extract the result type from the trailing function type.
225   auto funcType = type.dyn_cast<FunctionType>();
226   if (!funcType || funcType.getNumInputs() != 1 ||
227       funcType.getNumResults() != 1)
228     return parser.emitError(
229         trailingTypeLoc,
230         "expected trailing function type with one argument and one result");
231 
232   if (parser.resolveOperand(arraySize, funcType.getInput(0), result.operands))
233     return failure();
234 
235   result.addTypes({funcType.getResult(0)});
236   return success();
237 }
238 
239 //===----------------------------------------------------------------------===//
240 // LLVM::BrOp
241 //===----------------------------------------------------------------------===//
242 
243 Optional<MutableOperandRange>
244 BrOp::getMutableSuccessorOperands(unsigned index) {
245   assert(index == 0 && "invalid successor index");
246   return destOperandsMutable();
247 }
248 
249 //===----------------------------------------------------------------------===//
250 // LLVM::CondBrOp
251 //===----------------------------------------------------------------------===//
252 
253 Optional<MutableOperandRange>
254 CondBrOp::getMutableSuccessorOperands(unsigned index) {
255   assert(index < getNumSuccessors() && "invalid successor index");
256   return index == 0 ? trueDestOperandsMutable() : falseDestOperandsMutable();
257 }
258 
259 //===----------------------------------------------------------------------===//
260 // LLVM::SwitchOp
261 //===----------------------------------------------------------------------===//
262 
263 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
264                      Block *defaultDestination, ValueRange defaultOperands,
265                      ArrayRef<int32_t> caseValues, BlockRange caseDestinations,
266                      ArrayRef<ValueRange> caseOperands,
267                      ArrayRef<int32_t> branchWeights) {
268   SmallVector<Value> flattenedCaseOperands;
269   SmallVector<int32_t> caseOperandOffsets;
270   int32_t offset = 0;
271   for (ValueRange operands : caseOperands) {
272     flattenedCaseOperands.append(operands.begin(), operands.end());
273     caseOperandOffsets.push_back(offset);
274     offset += operands.size();
275   }
276   ElementsAttr caseValuesAttr;
277   if (!caseValues.empty())
278     caseValuesAttr = builder.getI32VectorAttr(caseValues);
279   ElementsAttr caseOperandOffsetsAttr;
280   if (!caseOperandOffsets.empty())
281     caseOperandOffsetsAttr = builder.getI32VectorAttr(caseOperandOffsets);
282 
283   ElementsAttr weightsAttr;
284   if (!branchWeights.empty())
285     weightsAttr = builder.getI32VectorAttr(llvm::to_vector<4>(branchWeights));
286 
287   build(builder, result, value, defaultOperands, flattenedCaseOperands,
288         caseValuesAttr, caseOperandOffsetsAttr, weightsAttr, defaultDestination,
289         caseDestinations);
290 }
291 
292 /// <cases> ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)?
293 ///             ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )?
294 static ParseResult
295 parseSwitchOpCases(OpAsmParser &parser, ElementsAttr &caseValues,
296                    SmallVectorImpl<Block *> &caseDestinations,
297                    SmallVectorImpl<OpAsmParser::OperandType> &caseOperands,
298                    SmallVectorImpl<Type> &caseOperandTypes,
299                    ElementsAttr &caseOperandOffsets) {
300   SmallVector<int32_t> values;
301   SmallVector<int32_t> offsets;
302   int32_t value, offset = 0;
303   do {
304     OptionalParseResult integerParseResult = parser.parseOptionalInteger(value);
305     if (values.empty() && !integerParseResult.hasValue())
306       return success();
307 
308     if (!integerParseResult.hasValue() || integerParseResult.getValue())
309       return failure();
310     values.push_back(value);
311 
312     Block *destination;
313     SmallVector<OpAsmParser::OperandType> operands;
314     if (parser.parseColon() || parser.parseSuccessor(destination))
315       return failure();
316     if (!parser.parseOptionalLParen()) {
317       if (parser.parseRegionArgumentList(operands) ||
318           parser.parseColonTypeList(caseOperandTypes) || parser.parseRParen())
319         return failure();
320     }
321     caseDestinations.push_back(destination);
322     caseOperands.append(operands.begin(), operands.end());
323     offsets.push_back(offset);
324     offset += operands.size();
325   } while (!parser.parseOptionalComma());
326 
327   Builder &builder = parser.getBuilder();
328   caseValues = builder.getI32VectorAttr(values);
329   caseOperandOffsets = builder.getI32VectorAttr(offsets);
330 
331   return success();
332 }
333 
334 static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op,
335                                ElementsAttr caseValues,
336                                SuccessorRange caseDestinations,
337                                OperandRange caseOperands,
338                                TypeRange caseOperandTypes,
339                                ElementsAttr caseOperandOffsets) {
340   if (!caseValues)
341     return;
342 
343   size_t index = 0;
344   llvm::interleave(
345       llvm::zip(caseValues.cast<DenseIntElementsAttr>(), caseDestinations),
346       [&](auto i) {
347         p << "  ";
348         p << std::get<0>(i).getLimitedValue();
349         p << ": ";
350         p.printSuccessorAndUseList(std::get<1>(i), op.getCaseOperands(index++));
351       },
352       [&] {
353         p << ',';
354         p.printNewline();
355       });
356   p.printNewline();
357 }
358 
359 static LogicalResult verify(SwitchOp op) {
360   if ((!op.case_values() && !op.caseDestinations().empty()) ||
361       (op.case_values() &&
362        op.case_values()->size() !=
363            static_cast<int64_t>(op.caseDestinations().size())))
364     return op.emitOpError("expects number of case values to match number of "
365                           "case destinations");
366   if (op.branch_weights() &&
367       op.branch_weights()->size() != op.getNumSuccessors())
368     return op.emitError("expects number of branch weights to match number of "
369                         "successors: ")
370            << op.branch_weights()->size() << " vs " << op.getNumSuccessors();
371   return success();
372 }
373 
374 OperandRange SwitchOp::getCaseOperands(unsigned index) {
375   return getCaseOperandsMutable(index);
376 }
377 
378 MutableOperandRange SwitchOp::getCaseOperandsMutable(unsigned index) {
379   MutableOperandRange caseOperands = caseOperandsMutable();
380   if (!case_operand_offsets()) {
381     assert(caseOperands.size() == 0 &&
382            "non-empty case operands must have offsets");
383     return caseOperands;
384   }
385 
386   ElementsAttr offsets = case_operand_offsets().getValue();
387   assert(index < offsets.size() && "invalid case operand offset index");
388 
389   int64_t begin = offsets.getValue(index).cast<IntegerAttr>().getInt();
390   int64_t end = index + 1 == offsets.size()
391                     ? caseOperands.size()
392                     : offsets.getValue(index + 1).cast<IntegerAttr>().getInt();
393   return caseOperandsMutable().slice(begin, end - begin);
394 }
395 
396 Optional<MutableOperandRange>
397 SwitchOp::getMutableSuccessorOperands(unsigned index) {
398   assert(index < getNumSuccessors() && "invalid successor index");
399   return index == 0 ? defaultOperandsMutable()
400                     : getCaseOperandsMutable(index - 1);
401 }
402 
403 //===----------------------------------------------------------------------===//
// Builder, printer and parser for LLVM::LoadOp.
405 //===----------------------------------------------------------------------===//
406 
407 static LogicalResult verifyAccessGroups(Operation *op) {
408   if (Attribute attribute =
409           op->getAttr(LLVMDialect::getAccessGroupsAttrName())) {
410     // The attribute is already verified to be a symbol ref array attribute via
411     // a constraint in the operation definition.
412     for (SymbolRefAttr accessGroupRef :
413          attribute.cast<ArrayAttr>().getAsRange<SymbolRefAttr>()) {
414       StringRef metadataName = accessGroupRef.getRootReference();
415       auto metadataOp = SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>(
416           op->getParentOp(), metadataName);
417       if (!metadataOp)
418         return op->emitOpError() << "expected '" << accessGroupRef
419                                  << "' to reference a metadata op";
420       StringRef accessGroupName = accessGroupRef.getLeafReference();
421       Operation *accessGroupOp =
422           SymbolTable::lookupNearestSymbolFrom(metadataOp, accessGroupName);
423       if (!accessGroupOp)
424         return op->emitOpError() << "expected '" << accessGroupRef
425                                  << "' to reference an access_group op";
426     }
427   }
428   return success();
429 }
430 
431 static LogicalResult verify(LoadOp op) {
432   return verifyAccessGroups(op.getOperation());
433 }
434 
435 void LoadOp::build(OpBuilder &builder, OperationState &result, Type t,
436                    Value addr, unsigned alignment, bool isVolatile,
437                    bool isNonTemporal) {
438   result.addOperands(addr);
439   result.addTypes(t);
440   if (isVolatile)
441     result.addAttribute(kVolatileAttrName, builder.getUnitAttr());
442   if (isNonTemporal)
443     result.addAttribute(kNonTemporalAttrName, builder.getUnitAttr());
444   if (alignment != 0)
445     result.addAttribute("alignment", builder.getI64IntegerAttr(alignment));
446 }
447 
448 static void printLoadOp(OpAsmPrinter &p, LoadOp &op) {
449   p << op.getOperationName() << ' ';
450   if (op.volatile_())
451     p << "volatile ";
452   p << op.addr();
453   p.printOptionalAttrDict(op->getAttrs(), {kVolatileAttrName});
454   p << " : " << op.addr().getType();
455 }
456 
457 // Extract the pointee type from the LLVM pointer type wrapped in MLIR.  Return
458 // the resulting type wrapped in MLIR, or nullptr on error.
459 static Type getLoadStoreElementType(OpAsmParser &parser, Type type,
460                                     llvm::SMLoc trailingTypeLoc) {
461   auto llvmTy = type.dyn_cast<LLVM::LLVMPointerType>();
462   if (!llvmTy)
463     return parser.emitError(trailingTypeLoc, "expected LLVM pointer type"),
464            nullptr;
465   return llvmTy.getElementType();
466 }
467 
468 // <operation> ::= `llvm.load` `volatile` ssa-use attribute-dict? `:` type
469 static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) {
470   OpAsmParser::OperandType addr;
471   Type type;
472   llvm::SMLoc trailingTypeLoc;
473 
474   if (succeeded(parser.parseOptionalKeyword("volatile")))
475     result.addAttribute(kVolatileAttrName, parser.getBuilder().getUnitAttr());
476 
477   if (parser.parseOperand(addr) ||
478       parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
479       parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) ||
480       parser.resolveOperand(addr, type, result.operands))
481     return failure();
482 
483   Type elemTy = getLoadStoreElementType(parser, type, trailingTypeLoc);
484 
485   result.addTypes(elemTy);
486   return success();
487 }
488 
489 //===----------------------------------------------------------------------===//
490 // Builder, printer and parser for LLVM::StoreOp.
491 //===----------------------------------------------------------------------===//
492 
493 static LogicalResult verify(StoreOp op) {
494   return verifyAccessGroups(op.getOperation());
495 }
496 
497 void StoreOp::build(OpBuilder &builder, OperationState &result, Value value,
498                     Value addr, unsigned alignment, bool isVolatile,
499                     bool isNonTemporal) {
500   result.addOperands({value, addr});
501   result.addTypes({});
502   if (isVolatile)
503     result.addAttribute(kVolatileAttrName, builder.getUnitAttr());
504   if (isNonTemporal)
505     result.addAttribute(kNonTemporalAttrName, builder.getUnitAttr());
506   if (alignment != 0)
507     result.addAttribute("alignment", builder.getI64IntegerAttr(alignment));
508 }
509 
510 static void printStoreOp(OpAsmPrinter &p, StoreOp &op) {
511   p << op.getOperationName() << ' ';
512   if (op.volatile_())
513     p << "volatile ";
514   p << op.value() << ", " << op.addr();
515   p.printOptionalAttrDict(op->getAttrs(), {kVolatileAttrName});
516   p << " : " << op.addr().getType();
517 }
518 
519 // <operation> ::= `llvm.store` `volatile` ssa-use `,` ssa-use
520 //                 attribute-dict? `:` type
521 static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
522   OpAsmParser::OperandType addr, value;
523   Type type;
524   llvm::SMLoc trailingTypeLoc;
525 
526   if (succeeded(parser.parseOptionalKeyword("volatile")))
527     result.addAttribute(kVolatileAttrName, parser.getBuilder().getUnitAttr());
528 
529   if (parser.parseOperand(value) || parser.parseComma() ||
530       parser.parseOperand(addr) ||
531       parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
532       parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type))
533     return failure();
534 
535   Type elemTy = getLoadStoreElementType(parser, type, trailingTypeLoc);
536   if (!elemTy)
537     return failure();
538 
539   if (parser.resolveOperand(value, elemTy, result.operands) ||
540       parser.resolveOperand(addr, type, result.operands))
541     return failure();
542 
543   return success();
544 }
545 
546 ///===---------------------------------------------------------------------===//
547 /// LLVM::InvokeOp
548 ///===---------------------------------------------------------------------===//
549 
550 Optional<MutableOperandRange>
551 InvokeOp::getMutableSuccessorOperands(unsigned index) {
552   assert(index < getNumSuccessors() && "invalid successor index");
553   return index == 0 ? normalDestOperandsMutable() : unwindDestOperandsMutable();
554 }
555 
556 static LogicalResult verify(InvokeOp op) {
557   if (op.getNumResults() > 1)
558     return op.emitOpError("must have 0 or 1 result");
559 
560   Block *unwindDest = op.unwindDest();
561   if (unwindDest->empty())
562     return op.emitError(
563         "must have at least one operation in unwind destination");
564 
565   // In unwind destination, first operation must be LandingpadOp
566   if (!isa<LandingpadOp>(unwindDest->front()))
567     return op.emitError("first operation in unwind destination should be a "
568                         "llvm.landingpad operation");
569 
570   return success();
571 }
572 
573 static void printInvokeOp(OpAsmPrinter &p, InvokeOp op) {
574   auto callee = op.callee();
575   bool isDirect = callee.hasValue();
576 
577   p << op.getOperationName() << ' ';
578 
579   // Either function name or pointer
580   if (isDirect)
581     p.printSymbolName(callee.getValue());
582   else
583     p << op.getOperand(0);
584 
585   p << '(' << op.getOperands().drop_front(isDirect ? 0 : 1) << ')';
586   p << " to ";
587   p.printSuccessorAndUseList(op.normalDest(), op.normalDestOperands());
588   p << " unwind ";
589   p.printSuccessorAndUseList(op.unwindDest(), op.unwindDestOperands());
590 
591   p.printOptionalAttrDict(op->getAttrs(),
592                           {InvokeOp::getOperandSegmentSizeAttr(), "callee"});
593   p << " : ";
594   p.printFunctionalType(
595       llvm::drop_begin(op.getOperandTypes(), isDirect ? 0 : 1),
596       op.getResultTypes());
597 }
598 
599 /// <operation> ::= `llvm.invoke` (function-id | ssa-use) `(` ssa-use-list `)`
600 ///                  `to` bb-id (`[` ssa-use-and-type-list `]`)?
601 ///                  `unwind` bb-id (`[` ssa-use-and-type-list `]`)?
602 ///                  attribute-dict? `:` function-type
603 static ParseResult parseInvokeOp(OpAsmParser &parser, OperationState &result) {
604   SmallVector<OpAsmParser::OperandType, 8> operands;
605   FunctionType funcType;
606   SymbolRefAttr funcAttr;
607   llvm::SMLoc trailingTypeLoc;
608   Block *normalDest, *unwindDest;
609   SmallVector<Value, 4> normalOperands, unwindOperands;
610   Builder &builder = parser.getBuilder();
611 
612   // Parse an operand list that will, in practice, contain 0 or 1 operand.  In
613   // case of an indirect call, there will be 1 operand before `(`.  In case of a
614   // direct call, there will be no operands and the parser will stop at the
615   // function identifier without complaining.
616   if (parser.parseOperandList(operands))
617     return failure();
618   bool isDirect = operands.empty();
619 
620   // Optionally parse a function identifier.
621   if (isDirect && parser.parseAttribute(funcAttr, "callee", result.attributes))
622     return failure();
623 
624   if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
625       parser.parseKeyword("to") ||
626       parser.parseSuccessorAndUseList(normalDest, normalOperands) ||
627       parser.parseKeyword("unwind") ||
628       parser.parseSuccessorAndUseList(unwindDest, unwindOperands) ||
629       parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
630       parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(funcType))
631     return failure();
632 
633   if (isDirect) {
634     // Make sure types match.
635     if (parser.resolveOperands(operands, funcType.getInputs(),
636                                parser.getNameLoc(), result.operands))
637       return failure();
638     result.addTypes(funcType.getResults());
639   } else {
640     // Construct the LLVM IR Dialect function type that the first operand
641     // should match.
642     if (funcType.getNumResults() > 1)
643       return parser.emitError(trailingTypeLoc,
644                               "expected function with 0 or 1 result");
645 
646     Type llvmResultType;
647     if (funcType.getNumResults() == 0) {
648       llvmResultType = LLVM::LLVMVoidType::get(builder.getContext());
649     } else {
650       llvmResultType = funcType.getResult(0);
651       if (!isCompatibleType(llvmResultType))
652         return parser.emitError(trailingTypeLoc,
653                                 "expected result to have LLVM type");
654     }
655 
656     SmallVector<Type, 8> argTypes;
657     argTypes.reserve(funcType.getNumInputs());
658     for (Type ty : funcType.getInputs()) {
659       if (isCompatibleType(ty))
660         argTypes.push_back(ty);
661       else
662         return parser.emitError(trailingTypeLoc,
663                                 "expected LLVM types as inputs");
664     }
665 
666     auto llvmFuncType = LLVM::LLVMFunctionType::get(llvmResultType, argTypes);
667     auto wrappedFuncType = LLVM::LLVMPointerType::get(llvmFuncType);
668 
669     auto funcArguments = llvm::makeArrayRef(operands).drop_front();
670 
671     // Make sure that the first operand (indirect callee) matches the wrapped
672     // LLVM IR function type, and that the types of the other call operands
673     // match the types of the function arguments.
674     if (parser.resolveOperand(operands[0], wrappedFuncType, result.operands) ||
675         parser.resolveOperands(funcArguments, funcType.getInputs(),
676                                parser.getNameLoc(), result.operands))
677       return failure();
678 
679     result.addTypes(llvmResultType);
680   }
681   result.addSuccessors({normalDest, unwindDest});
682   result.addOperands(normalOperands);
683   result.addOperands(unwindOperands);
684 
685   result.addAttribute(
686       InvokeOp::getOperandSegmentSizeAttr(),
687       builder.getI32VectorAttr({static_cast<int32_t>(operands.size()),
688                                 static_cast<int32_t>(normalOperands.size()),
689                                 static_cast<int32_t>(unwindOperands.size())}));
690   return success();
691 }
692 
693 ///===----------------------------------------------------------------------===//
694 /// Verifying/Printing/Parsing for LLVM::LandingpadOp.
695 ///===----------------------------------------------------------------------===//
696 
697 static LogicalResult verify(LandingpadOp op) {
698   Value value;
699   if (LLVMFuncOp func = op->getParentOfType<LLVMFuncOp>()) {
700     if (!func.personality().hasValue())
701       return op.emitError(
702           "llvm.landingpad needs to be in a function with a personality");
703   }
704 
705   if (!op.cleanup() && op.getOperands().empty())
706     return op.emitError("landingpad instruction expects at least one clause or "
707                         "cleanup attribute");
708 
709   for (unsigned idx = 0, ie = op.getNumOperands(); idx < ie; idx++) {
710     value = op.getOperand(idx);
711     bool isFilter = value.getType().isa<LLVMArrayType>();
712     if (isFilter) {
713       // FIXME: Verify filter clauses when arrays are appropriately handled
714     } else {
715       // catch - global addresses only.
716       // Bitcast ops should have global addresses as their args.
717       if (auto bcOp = value.getDefiningOp<BitcastOp>()) {
718         if (auto addrOp = bcOp.arg().getDefiningOp<AddressOfOp>())
719           continue;
720         return op.emitError("constant clauses expected")
721                    .attachNote(bcOp.getLoc())
722                << "global addresses expected as operand to "
723                   "bitcast used in clauses for landingpad";
724       }
725       // NullOp and AddressOfOp allowed
726       if (value.getDefiningOp<NullOp>())
727         continue;
728       if (value.getDefiningOp<AddressOfOp>())
729         continue;
730       return op.emitError("clause #")
731              << idx << " is not a known constant - null, addressof, bitcast";
732     }
733   }
734   return success();
735 }
736 
737 static void printLandingpadOp(OpAsmPrinter &p, LandingpadOp &op) {
738   p << op.getOperationName() << (op.cleanup() ? " cleanup " : " ");
739 
740   // Clauses
741   for (auto value : op.getOperands()) {
742     // Similar to llvm - if clause is an array type then it is filter
743     // clause else catch clause
744     bool isArrayTy = value.getType().isa<LLVMArrayType>();
745     p << '(' << (isArrayTy ? "filter " : "catch ") << value << " : "
746       << value.getType() << ") ";
747   }
748 
749   p.printOptionalAttrDict(op->getAttrs(), {"cleanup"});
750 
751   p << ": " << op.getType();
752 }
753 
754 /// <operation> ::= `llvm.landingpad` `cleanup`?
755 ///                 ((`catch` | `filter`) operand-type ssa-use)* attribute-dict?
756 static ParseResult parseLandingpadOp(OpAsmParser &parser,
757                                      OperationState &result) {
758   // Check for cleanup
759   if (succeeded(parser.parseOptionalKeyword("cleanup")))
760     result.addAttribute("cleanup", parser.getBuilder().getUnitAttr());
761 
762   // Parse clauses with types
763   while (succeeded(parser.parseOptionalLParen()) &&
764          (succeeded(parser.parseOptionalKeyword("filter")) ||
765           succeeded(parser.parseOptionalKeyword("catch")))) {
766     OpAsmParser::OperandType operand;
767     Type ty;
768     if (parser.parseOperand(operand) || parser.parseColon() ||
769         parser.parseType(ty) ||
770         parser.resolveOperand(operand, ty, result.operands) ||
771         parser.parseRParen())
772       return failure();
773   }
774 
775   Type type;
776   if (parser.parseColon() || parser.parseType(type))
777     return failure();
778 
779   result.addTypes(type);
780   return success();
781 }
782 
783 //===----------------------------------------------------------------------===//
784 // Verifying/Printing/parsing for LLVM::CallOp.
785 //===----------------------------------------------------------------------===//
786 
787 static LogicalResult verify(CallOp &op) {
788   if (op.getNumResults() > 1)
789     return op.emitOpError("must have 0 or 1 result");
790 
791   // Type for the callee, we'll get it differently depending if it is a direct
792   // or indirect call.
793   Type fnType;
794 
795   bool isIndirect = false;
796 
797   // If this is an indirect call, the callee attribute is missing.
798   Optional<StringRef> calleeName = op.callee();
799   if (!calleeName) {
800     isIndirect = true;
801     if (!op.getNumOperands())
802       return op.emitOpError(
803           "must have either a `callee` attribute or at least an operand");
804     auto ptrType = op.getOperand(0).getType().dyn_cast<LLVMPointerType>();
805     if (!ptrType)
806       return op.emitOpError("indirect call expects a pointer as callee: ")
807              << ptrType;
808     fnType = ptrType.getElementType();
809   } else {
810     Operation *callee = SymbolTable::lookupNearestSymbolFrom(op, *calleeName);
811     if (!callee)
812       return op.emitOpError()
813              << "'" << *calleeName
814              << "' does not reference a symbol in the current scope";
815     auto fn = dyn_cast<LLVMFuncOp>(callee);
816     if (!fn)
817       return op.emitOpError() << "'" << *calleeName
818                               << "' does not reference a valid LLVM function";
819 
820     fnType = fn.getType();
821   }
822 
823   LLVMFunctionType funcType = fnType.dyn_cast<LLVMFunctionType>();
824   if (!funcType)
825     return op.emitOpError("callee does not have a functional type: ") << fnType;
826 
827   // Verify that the operand and result types match the callee.
828 
829   if (!funcType.isVarArg() &&
830       funcType.getNumParams() != (op.getNumOperands() - isIndirect))
831     return op.emitOpError()
832            << "incorrect number of operands ("
833            << (op.getNumOperands() - isIndirect)
834            << ") for callee (expecting: " << funcType.getNumParams() << ")";
835 
836   if (funcType.getNumParams() > (op.getNumOperands() - isIndirect))
837     return op.emitOpError() << "incorrect number of operands ("
838                             << (op.getNumOperands() - isIndirect)
839                             << ") for varargs callee (expecting at least: "
840                             << funcType.getNumParams() << ")";
841 
842   for (unsigned i = 0, e = funcType.getNumParams(); i != e; ++i)
843     if (op.getOperand(i + isIndirect).getType() != funcType.getParamType(i))
844       return op.emitOpError() << "operand type mismatch for operand " << i
845                               << ": " << op.getOperand(i + isIndirect).getType()
846                               << " != " << funcType.getParamType(i);
847 
848   if (op.getNumResults() &&
849       op.getResult(0).getType() != funcType.getReturnType())
850     return op.emitOpError()
851            << "result type mismatch: " << op.getResult(0).getType()
852            << " != " << funcType.getReturnType();
853 
854   return success();
855 }
856 
857 static void printCallOp(OpAsmPrinter &p, CallOp &op) {
858   auto callee = op.callee();
859   bool isDirect = callee.hasValue();
860 
861   // Print the direct callee if present as a function attribute, or an indirect
862   // callee (first operand) otherwise.
863   p << op.getOperationName() << ' ';
864   if (isDirect)
865     p.printSymbolName(callee.getValue());
866   else
867     p << op.getOperand(0);
868 
869   auto args = op.getOperands().drop_front(isDirect ? 0 : 1);
870   p << '(' << args << ')';
871   p.printOptionalAttrDict(processFMFAttr(op->getAttrs()), {"callee"});
872 
873   // Reconstruct the function MLIR function type from operand and result types.
874   p << " : "
875     << FunctionType::get(op.getContext(), args.getTypes(), op.getResultTypes());
876 }
877 
878 // <operation> ::= `llvm.call` (function-id | ssa-use) `(` ssa-use-list `)`
879 //                 attribute-dict? `:` function-type
880 static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) {
881   SmallVector<OpAsmParser::OperandType, 8> operands;
882   Type type;
883   SymbolRefAttr funcAttr;
884   llvm::SMLoc trailingTypeLoc;
885 
886   // Parse an operand list that will, in practice, contain 0 or 1 operand.  In
887   // case of an indirect call, there will be 1 operand before `(`.  In case of a
888   // direct call, there will be no operands and the parser will stop at the
889   // function identifier without complaining.
890   if (parser.parseOperandList(operands))
891     return failure();
892   bool isDirect = operands.empty();
893 
894   // Optionally parse a function identifier.
895   if (isDirect)
896     if (parser.parseAttribute(funcAttr, "callee", result.attributes))
897       return failure();
898 
899   if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
900       parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
901       parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type))
902     return failure();
903 
904   auto funcType = type.dyn_cast<FunctionType>();
905   if (!funcType)
906     return parser.emitError(trailingTypeLoc, "expected function type");
907   if (isDirect) {
908     // Make sure types match.
909     if (parser.resolveOperands(operands, funcType.getInputs(),
910                                parser.getNameLoc(), result.operands))
911       return failure();
912     result.addTypes(funcType.getResults());
913   } else {
914     // Construct the LLVM IR Dialect function type that the first operand
915     // should match.
916     if (funcType.getNumResults() > 1)
917       return parser.emitError(trailingTypeLoc,
918                               "expected function with 0 or 1 result");
919 
920     Builder &builder = parser.getBuilder();
921     Type llvmResultType;
922     if (funcType.getNumResults() == 0) {
923       llvmResultType = LLVM::LLVMVoidType::get(builder.getContext());
924     } else {
925       llvmResultType = funcType.getResult(0);
926       if (!isCompatibleType(llvmResultType))
927         return parser.emitError(trailingTypeLoc,
928                                 "expected result to have LLVM type");
929     }
930 
931     SmallVector<Type, 8> argTypes;
932     argTypes.reserve(funcType.getNumInputs());
933     for (int i = 0, e = funcType.getNumInputs(); i < e; ++i) {
934       auto argType = funcType.getInput(i);
935       if (!isCompatibleType(argType))
936         return parser.emitError(trailingTypeLoc,
937                                 "expected LLVM types as inputs");
938       argTypes.push_back(argType);
939     }
940     auto llvmFuncType = LLVM::LLVMFunctionType::get(llvmResultType, argTypes);
941     auto wrappedFuncType = LLVM::LLVMPointerType::get(llvmFuncType);
942 
943     auto funcArguments =
944         ArrayRef<OpAsmParser::OperandType>(operands).drop_front();
945 
946     // Make sure that the first operand (indirect callee) matches the wrapped
947     // LLVM IR function type, and that the types of the other call operands
948     // match the types of the function arguments.
949     if (parser.resolveOperand(operands[0], wrappedFuncType, result.operands) ||
950         parser.resolveOperands(funcArguments, funcType.getInputs(),
951                                parser.getNameLoc(), result.operands))
952       return failure();
953 
954     result.addTypes(llvmResultType);
955   }
956 
957   return success();
958 }
959 
960 //===----------------------------------------------------------------------===//
961 // Printing/parsing for LLVM::ExtractElementOp.
962 //===----------------------------------------------------------------------===//
963 // Expects vector to be of wrapped LLVM vector type and position to be of
964 // wrapped LLVM i32 type.
965 void LLVM::ExtractElementOp::build(OpBuilder &b, OperationState &result,
966                                    Value vector, Value position,
967                                    ArrayRef<NamedAttribute> attrs) {
968   auto vectorType = vector.getType();
969   auto llvmType = LLVM::getVectorElementType(vectorType);
970   build(b, result, llvmType, vector, position);
971   result.addAttributes(attrs);
972 }
973 
974 static void printExtractElementOp(OpAsmPrinter &p, ExtractElementOp &op) {
975   p << op.getOperationName() << ' ' << op.vector() << "[" << op.position()
976     << " : " << op.position().getType() << "]";
977   p.printOptionalAttrDict(op->getAttrs());
978   p << " : " << op.vector().getType();
979 }
980 
981 // <operation> ::= `llvm.extractelement` ssa-use `, ` ssa-use
982 //                 attribute-dict? `:` type
983 static ParseResult parseExtractElementOp(OpAsmParser &parser,
984                                          OperationState &result) {
985   llvm::SMLoc loc;
986   OpAsmParser::OperandType vector, position;
987   Type type, positionType;
988   if (parser.getCurrentLocation(&loc) || parser.parseOperand(vector) ||
989       parser.parseLSquare() || parser.parseOperand(position) ||
990       parser.parseColonType(positionType) || parser.parseRSquare() ||
991       parser.parseOptionalAttrDict(result.attributes) ||
992       parser.parseColonType(type) ||
993       parser.resolveOperand(vector, type, result.operands) ||
994       parser.resolveOperand(position, positionType, result.operands))
995     return failure();
996   if (!LLVM::isCompatibleVectorType(type))
997     return parser.emitError(
998         loc, "expected LLVM dialect-compatible vector type for operand #1");
999   result.addTypes(LLVM::getVectorElementType(type));
1000   return success();
1001 }
1002 
1003 //===----------------------------------------------------------------------===//
1004 // Printing/parsing for LLVM::ExtractValueOp.
1005 //===----------------------------------------------------------------------===//
1006 
1007 static void printExtractValueOp(OpAsmPrinter &p, ExtractValueOp &op) {
1008   p << op.getOperationName() << ' ' << op.container() << op.position();
1009   p.printOptionalAttrDict(op->getAttrs(), {"position"});
1010   p << " : " << op.container().getType();
1011 }
1012 
1013 // Extract the type at `position` in the wrapped LLVM IR aggregate type
1014 // `containerType`.  Position is an integer array attribute where each value
1015 // is a zero-based position of the element in the aggregate type.  Return the
1016 // resulting type wrapped in MLIR, or nullptr on error.
1017 static Type getInsertExtractValueElementType(OpAsmParser &parser,
1018                                              Type containerType,
1019                                              ArrayAttr positionAttr,
1020                                              llvm::SMLoc attributeLoc,
1021                                              llvm::SMLoc typeLoc) {
1022   Type llvmType = containerType;
1023   if (!isCompatibleType(containerType))
1024     return parser.emitError(typeLoc, "expected LLVM IR Dialect type"), nullptr;
1025 
1026   // Infer the element type from the structure type: iteratively step inside the
1027   // type by taking the element type, indexed by the position attribute for
1028   // structures.  Check the position index before accessing, it is supposed to
1029   // be in bounds.
1030   for (Attribute subAttr : positionAttr) {
1031     auto positionElementAttr = subAttr.dyn_cast<IntegerAttr>();
1032     if (!positionElementAttr)
1033       return parser.emitError(attributeLoc,
1034                               "expected an array of integer literals"),
1035              nullptr;
1036     int position = positionElementAttr.getInt();
1037     if (auto arrayType = llvmType.dyn_cast<LLVMArrayType>()) {
1038       if (position < 0 ||
1039           static_cast<unsigned>(position) >= arrayType.getNumElements())
1040         return parser.emitError(attributeLoc, "position out of bounds"),
1041                nullptr;
1042       llvmType = arrayType.getElementType();
1043     } else if (auto structType = llvmType.dyn_cast<LLVMStructType>()) {
1044       if (position < 0 ||
1045           static_cast<unsigned>(position) >= structType.getBody().size())
1046         return parser.emitError(attributeLoc, "position out of bounds"),
1047                nullptr;
1048       llvmType = structType.getBody()[position];
1049     } else {
1050       return parser.emitError(typeLoc, "expected LLVM IR structure/array type"),
1051              nullptr;
1052     }
1053   }
1054   return llvmType;
1055 }
1056 
1057 // <operation> ::= `llvm.extractvalue` ssa-use
1058 //                 `[` integer-literal (`,` integer-literal)* `]`
1059 //                 attribute-dict? `:` type
1060 static ParseResult parseExtractValueOp(OpAsmParser &parser,
1061                                        OperationState &result) {
1062   OpAsmParser::OperandType container;
1063   Type containerType;
1064   ArrayAttr positionAttr;
1065   llvm::SMLoc attributeLoc, trailingTypeLoc;
1066 
1067   if (parser.parseOperand(container) ||
1068       parser.getCurrentLocation(&attributeLoc) ||
1069       parser.parseAttribute(positionAttr, "position", result.attributes) ||
1070       parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
1071       parser.getCurrentLocation(&trailingTypeLoc) ||
1072       parser.parseType(containerType) ||
1073       parser.resolveOperand(container, containerType, result.operands))
1074     return failure();
1075 
1076   auto elementType = getInsertExtractValueElementType(
1077       parser, containerType, positionAttr, attributeLoc, trailingTypeLoc);
1078   if (!elementType)
1079     return failure();
1080 
1081   result.addTypes(elementType);
1082   return success();
1083 }
1084 
1085 //===----------------------------------------------------------------------===//
1086 // Printing/parsing for LLVM::InsertElementOp.
1087 //===----------------------------------------------------------------------===//
1088 
1089 static void printInsertElementOp(OpAsmPrinter &p, InsertElementOp &op) {
1090   p << op.getOperationName() << ' ' << op.value() << ", " << op.vector() << "["
1091     << op.position() << " : " << op.position().getType() << "]";
1092   p.printOptionalAttrDict(op->getAttrs());
1093   p << " : " << op.vector().getType();
1094 }
1095 
1096 // <operation> ::= `llvm.insertelement` ssa-use `,` ssa-use `,` ssa-use
1097 //                 attribute-dict? `:` type
1098 static ParseResult parseInsertElementOp(OpAsmParser &parser,
1099                                         OperationState &result) {
1100   llvm::SMLoc loc;
1101   OpAsmParser::OperandType vector, value, position;
1102   Type vectorType, positionType;
1103   if (parser.getCurrentLocation(&loc) || parser.parseOperand(value) ||
1104       parser.parseComma() || parser.parseOperand(vector) ||
1105       parser.parseLSquare() || parser.parseOperand(position) ||
1106       parser.parseColonType(positionType) || parser.parseRSquare() ||
1107       parser.parseOptionalAttrDict(result.attributes) ||
1108       parser.parseColonType(vectorType))
1109     return failure();
1110 
1111   if (!LLVM::isCompatibleVectorType(vectorType))
1112     return parser.emitError(
1113         loc, "expected LLVM dialect-compatible vector type for operand #1");
1114   Type valueType = LLVM::getVectorElementType(vectorType);
1115   if (!valueType)
1116     return failure();
1117 
1118   if (parser.resolveOperand(vector, vectorType, result.operands) ||
1119       parser.resolveOperand(value, valueType, result.operands) ||
1120       parser.resolveOperand(position, positionType, result.operands))
1121     return failure();
1122 
1123   result.addTypes(vectorType);
1124   return success();
1125 }
1126 
1127 //===----------------------------------------------------------------------===//
1128 // Printing/parsing for LLVM::InsertValueOp.
1129 //===----------------------------------------------------------------------===//
1130 
1131 static void printInsertValueOp(OpAsmPrinter &p, InsertValueOp &op) {
1132   p << op.getOperationName() << ' ' << op.value() << ", " << op.container()
1133     << op.position();
1134   p.printOptionalAttrDict(op->getAttrs(), {"position"});
1135   p << " : " << op.container().getType();
1136 }
1137 
1138 // <operation> ::= `llvm.insertvaluevalue` ssa-use `,` ssa-use
1139 //                 `[` integer-literal (`,` integer-literal)* `]`
1140 //                 attribute-dict? `:` type
1141 static ParseResult parseInsertValueOp(OpAsmParser &parser,
1142                                       OperationState &result) {
1143   OpAsmParser::OperandType container, value;
1144   Type containerType;
1145   ArrayAttr positionAttr;
1146   llvm::SMLoc attributeLoc, trailingTypeLoc;
1147 
1148   if (parser.parseOperand(value) || parser.parseComma() ||
1149       parser.parseOperand(container) ||
1150       parser.getCurrentLocation(&attributeLoc) ||
1151       parser.parseAttribute(positionAttr, "position", result.attributes) ||
1152       parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
1153       parser.getCurrentLocation(&trailingTypeLoc) ||
1154       parser.parseType(containerType))
1155     return failure();
1156 
1157   auto valueType = getInsertExtractValueElementType(
1158       parser, containerType, positionAttr, attributeLoc, trailingTypeLoc);
1159   if (!valueType)
1160     return failure();
1161 
1162   if (parser.resolveOperand(container, containerType, result.operands) ||
1163       parser.resolveOperand(value, valueType, result.operands))
1164     return failure();
1165 
1166   result.addTypes(containerType);
1167   return success();
1168 }
1169 
1170 //===----------------------------------------------------------------------===//
1171 // Printing, parsing and verification for LLVM::ReturnOp.
1172 //===----------------------------------------------------------------------===//
1173 
1174 static void printReturnOp(OpAsmPrinter &p, ReturnOp op) {
1175   p << op.getOperationName();
1176   p.printOptionalAttrDict(op->getAttrs());
1177   assert(op.getNumOperands() <= 1);
1178 
1179   if (op.getNumOperands() == 0)
1180     return;
1181 
1182   p << ' ' << op.getOperand(0) << " : " << op.getOperand(0).getType();
1183 }
1184 
1185 // <operation> ::= `llvm.return` ssa-use-list attribute-dict? `:`
1186 //                 type-list-no-parens
1187 static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &result) {
1188   SmallVector<OpAsmParser::OperandType, 1> operands;
1189   Type type;
1190 
1191   if (parser.parseOperandList(operands) ||
1192       parser.parseOptionalAttrDict(result.attributes))
1193     return failure();
1194   if (operands.empty())
1195     return success();
1196 
1197   if (parser.parseColonType(type) ||
1198       parser.resolveOperand(operands[0], type, result.operands))
1199     return failure();
1200   return success();
1201 }
1202 
1203 static LogicalResult verify(ReturnOp op) {
1204   if (op->getNumOperands() > 1)
1205     return op->emitOpError("expected at most 1 operand");
1206 
1207   if (auto parent = op->getParentOfType<LLVMFuncOp>()) {
1208     Type expectedType = parent.getType().getReturnType();
1209     if (expectedType.isa<LLVMVoidType>()) {
1210       if (op->getNumOperands() == 0)
1211         return success();
1212       InFlightDiagnostic diag = op->emitOpError("expected no operands");
1213       diag.attachNote(parent->getLoc()) << "when returning from function";
1214       return diag;
1215     }
1216     if (op->getNumOperands() == 0) {
1217       if (expectedType.isa<LLVMVoidType>())
1218         return success();
1219       InFlightDiagnostic diag = op->emitOpError("expected 1 operand");
1220       diag.attachNote(parent->getLoc()) << "when returning from function";
1221       return diag;
1222     }
1223     if (expectedType != op->getOperand(0).getType()) {
1224       InFlightDiagnostic diag = op->emitOpError("mismatching result types");
1225       diag.attachNote(parent->getLoc()) << "when returning from function";
1226       return diag;
1227     }
1228   }
1229   return success();
1230 }
1231 
1232 //===----------------------------------------------------------------------===//
1233 // Verifier for LLVM::AddressOfOp.
1234 //===----------------------------------------------------------------------===//
1235 
1236 template <typename OpTy>
1237 static OpTy lookupSymbolInModule(Operation *parent, StringRef name) {
1238   Operation *module = parent;
1239   while (module && !satisfiesLLVMModule(module))
1240     module = module->getParentOp();
1241   assert(module && "unexpected operation outside of a module");
1242   return dyn_cast_or_null<OpTy>(
1243       mlir::SymbolTable::lookupSymbolIn(module, name));
1244 }
1245 
1246 GlobalOp AddressOfOp::getGlobal() {
1247   return lookupSymbolInModule<LLVM::GlobalOp>((*this)->getParentOp(),
1248                                               global_name());
1249 }
1250 
1251 LLVMFuncOp AddressOfOp::getFunction() {
1252   return lookupSymbolInModule<LLVM::LLVMFuncOp>((*this)->getParentOp(),
1253                                                 global_name());
1254 }
1255 
1256 static LogicalResult verify(AddressOfOp op) {
1257   auto global = op.getGlobal();
1258   auto function = op.getFunction();
1259   if (!global && !function)
1260     return op.emitOpError(
1261         "must reference a global defined by 'llvm.mlir.global' or 'llvm.func'");
1262 
1263   if (global &&
1264       LLVM::LLVMPointerType::get(global.getType(), global.addr_space()) !=
1265           op.getResult().getType())
1266     return op.emitOpError(
1267         "the type must be a pointer to the type of the referenced global");
1268 
1269   if (function && LLVM::LLVMPointerType::get(function.getType()) !=
1270                       op.getResult().getType())
1271     return op.emitOpError(
1272         "the type must be a pointer to the type of the referenced function");
1273 
1274   return success();
1275 }
1276 
1277 //===----------------------------------------------------------------------===//
1278 // Builder, printer and verifier for LLVM::GlobalOp.
1279 //===----------------------------------------------------------------------===//
1280 
1281 /// Returns the name used for the linkage attribute. This *must* correspond to
1282 /// the name of the attribute in ODS.
1283 static StringRef getLinkageAttrName() { return "linkage"; }
1284 
1285 void GlobalOp::build(OpBuilder &builder, OperationState &result, Type type,
1286                      bool isConstant, Linkage linkage, StringRef name,
1287                      Attribute value, unsigned addrSpace,
1288                      ArrayRef<NamedAttribute> attrs) {
1289   result.addAttribute(SymbolTable::getSymbolAttrName(),
1290                       builder.getStringAttr(name));
1291   result.addAttribute("type", TypeAttr::get(type));
1292   if (isConstant)
1293     result.addAttribute("constant", builder.getUnitAttr());
1294   if (value)
1295     result.addAttribute("value", value);
1296   result.addAttribute(getLinkageAttrName(),
1297                       builder.getI64IntegerAttr(static_cast<int64_t>(linkage)));
1298   if (addrSpace != 0)
1299     result.addAttribute("addr_space", builder.getI32IntegerAttr(addrSpace));
1300   result.attributes.append(attrs.begin(), attrs.end());
1301   result.addRegion();
1302 }
1303 
1304 static void printGlobalOp(OpAsmPrinter &p, GlobalOp op) {
1305   p << op.getOperationName() << ' ' << stringifyLinkage(op.linkage()) << ' ';
1306   if (op.constant())
1307     p << "constant ";
1308   p.printSymbolName(op.sym_name());
1309   p << '(';
1310   if (auto value = op.getValueOrNull())
1311     p.printAttribute(value);
1312   p << ')';
1313   p.printOptionalAttrDict(op->getAttrs(),
1314                           {SymbolTable::getSymbolAttrName(), "type", "constant",
1315                            "value", getLinkageAttrName()});
1316 
1317   // Print the trailing type unless it's a string global.
1318   if (op.getValueOrNull().dyn_cast_or_null<StringAttr>())
1319     return;
1320   p << " : " << op.type();
1321 
1322   Region &initializer = op.getInitializerRegion();
1323   if (!initializer.empty())
1324     p.printRegion(initializer, /*printEntryBlockArgs=*/false);
1325 }
1326 
1327 //===----------------------------------------------------------------------===//
1328 // Verifier for LLVM::DialectCastOp.
1329 //===----------------------------------------------------------------------===//
1330 
1331 /// Checks if `llvmType` is dialect cast-compatible with `index` type. Does not
1332 /// report the error, the user is expected to produce an appropriate message.
1333 // TODO: make the size depend on data layout rather than on the conversion
1334 // pass option, and pull that information here.
1335 static LogicalResult verifyCastWithIndex(Type llvmType) {
1336   return success(llvmType.isa<IntegerType>());
1337 }
1338 
1339 /// Checks if `llvmType` is dialect cast-compatible with built-in `type` and
1340 /// reports errors to the location of `op`. `isElement` indicates whether the
1341 /// verification is performed for types that are element types inside a
1342 /// container; we don't want casts from X to X at the top level, but c1<X> to
1343 /// c2<X> may be fine.
1344 static LogicalResult verifyCast(DialectCastOp op, Type llvmType, Type type,
1345                                 bool isElement = false) {
1346   // Equal element types are directly compatible.
1347   if (isElement && llvmType == type)
1348     return success();
1349 
1350   // Index is compatible with any integer.
1351   if (type.isIndex()) {
1352     if (succeeded(verifyCastWithIndex(llvmType)))
1353       return success();
1354 
1355     return op.emitOpError("invalid cast between index and non-integer type");
1356   }
1357 
1358   if (type.isa<IntegerType>()) {
1359     auto llvmIntegerType = llvmType.dyn_cast<IntegerType>();
1360     if (!llvmIntegerType)
1361       return op->emitOpError("invalid cast between integer and non-integer");
1362     if (llvmIntegerType.getWidth() != type.getIntOrFloatBitWidth())
1363       return op.emitOpError("invalid cast changing integer width");
1364     return success();
1365   }
1366 
1367   // Vectors are compatible if they are 1D non-scalable, and their element types
1368   // are compatible. nD vectors are compatible with (n-1)D arrays containing 1D
1369   // vector.
1370   if (auto vectorType = type.dyn_cast<VectorType>()) {
1371     if (vectorType == llvmType && !isElement)
1372       return op.emitOpError("vector types should not be casted");
1373 
1374     if (vectorType.getRank() == 1) {
1375       auto llvmVectorType = llvmType.dyn_cast<VectorType>();
1376       if (!llvmVectorType || llvmVectorType.getRank() != 1)
1377         return op.emitOpError("invalid cast for vector types");
1378 
1379       return verifyCast(op, llvmVectorType.getElementType(),
1380                         vectorType.getElementType(), /*isElement=*/true);
1381     }
1382 
1383     auto arrayType = llvmType.dyn_cast<LLVM::LLVMArrayType>();
1384     if (!arrayType ||
1385         arrayType.getNumElements() != vectorType.getShape().front())
1386       return op.emitOpError("invalid cast for vector, expected array");
1387     return verifyCast(op, arrayType.getElementType(),
1388                       VectorType::get(vectorType.getShape().drop_front(),
1389                                       vectorType.getElementType()),
1390                       /*isElement=*/true);
1391   }
1392 
1393   if (auto memrefType = type.dyn_cast<MemRefType>()) {
1394     // Bare pointer convention: statically-shaped memref is compatible with an
1395     // LLVM pointer to the element type.
1396     if (auto ptrType = llvmType.dyn_cast<LLVMPointerType>()) {
1397       if (!memrefType.hasStaticShape())
1398         return op->emitOpError(
1399             "unexpected bare pointer for dynamically shaped memref");
1400       if (memrefType.getMemorySpaceAsInt() != ptrType.getAddressSpace())
1401         return op->emitError("invalid conversion between memref and pointer in "
1402                              "different memory spaces");
1403 
1404       return verifyCast(op, ptrType.getElementType(),
1405                         memrefType.getElementType(), /*isElement=*/true);
1406     }
1407 
1408     // Otherwise, memrefs are convertible to a descriptor, which is a structure
1409     // type.
1410     auto structType = llvmType.dyn_cast<LLVMStructType>();
1411     if (!structType)
1412       return op->emitOpError("invalid cast between a memref and a type other "
1413                              "than pointer or memref descriptor");
1414 
1415     unsigned expectedNumElements = memrefType.getRank() == 0 ? 3 : 5;
1416     if (structType.getBody().size() != expectedNumElements) {
1417       return op->emitOpError() << "expected memref descriptor with "
1418                                << expectedNumElements << " elements";
1419     }
1420 
1421     // The first two elements are pointers to the element type.
1422     auto allocatedPtr = structType.getBody()[0].dyn_cast<LLVMPointerType>();
1423     if (!allocatedPtr ||
1424         allocatedPtr.getAddressSpace() != memrefType.getMemorySpaceAsInt())
1425       return op->emitOpError("expected first element of a memref descriptor to "
1426                              "be a pointer in the address space of the memref");
1427     if (failed(verifyCast(op, allocatedPtr.getElementType(),
1428                           memrefType.getElementType(), /*isElement=*/true)))
1429       return failure();
1430 
1431     auto alignedPtr = structType.getBody()[1].dyn_cast<LLVMPointerType>();
1432     if (!alignedPtr ||
1433         alignedPtr.getAddressSpace() != memrefType.getMemorySpaceAsInt())
1434       return op->emitOpError(
1435           "expected second element of a memref descriptor to "
1436           "be a pointer in the address space of the memref");
1437     if (failed(verifyCast(op, alignedPtr.getElementType(),
1438                           memrefType.getElementType(), /*isElement=*/true)))
1439       return failure();
1440 
1441     // The second element (offset) is an equivalent of index.
1442     if (failed(verifyCastWithIndex(structType.getBody()[2])))
1443       return op->emitOpError("expected third element of a memref descriptor to "
1444                              "be index-compatible integers");
1445 
1446     // 0D memrefs don't have sizes/strides.
1447     if (memrefType.getRank() == 0)
1448       return success();
1449 
1450     // Sizes and strides are rank-sized arrays of `index` equivalents.
1451     auto sizes = structType.getBody()[3].dyn_cast<LLVMArrayType>();
1452     if (!sizes || failed(verifyCastWithIndex(sizes.getElementType())) ||
1453         sizes.getNumElements() != memrefType.getRank())
1454       return op->emitOpError(
1455           "expected fourth element of a memref descriptor "
1456           "to be an array of <rank> index-compatible integers");
1457 
1458     auto strides = structType.getBody()[4].dyn_cast<LLVMArrayType>();
1459     if (!strides || failed(verifyCastWithIndex(strides.getElementType())) ||
1460         strides.getNumElements() != memrefType.getRank())
1461       return op->emitOpError(
1462           "expected fifth element of a memref descriptor "
1463           "to be an array of <rank> index-compatible integers");
1464 
1465     return success();
1466   }
1467 
1468   // Unranked memrefs are compatible with their descriptors.
1469   if (auto unrankedMemrefType = type.dyn_cast<UnrankedMemRefType>()) {
1470     auto structType = llvmType.dyn_cast<LLVMStructType>();
1471     if (!structType || structType.getBody().size() != 2)
1472       return op->emitOpError(
1473           "expected descriptor to be a struct with two elements");
1474 
1475     if (failed(verifyCastWithIndex(structType.getBody()[0])))
1476       return op->emitOpError("expected first element of a memref descriptor to "
1477                              "be an index-compatible integer");
1478 
1479     auto ptrType = structType.getBody()[1].dyn_cast<LLVMPointerType>();
1480     auto ptrElementType =
1481         ptrType ? ptrType.getElementType().dyn_cast<IntegerType>() : nullptr;
1482     if (!ptrElementType || ptrElementType.getWidth() != 8)
1483       return op->emitOpError("expected second element of a memref descriptor "
1484                              "to be an !llvm.ptr<i8>");
1485 
1486     return success();
1487   }
1488 
1489   // Complex types are compatible with the two-element structs.
1490   if (auto complexType = type.dyn_cast<ComplexType>()) {
1491     auto structType = llvmType.dyn_cast<LLVMStructType>();
1492     if (!structType || structType.getBody().size() != 2 ||
1493         structType.getBody()[0] != structType.getBody()[1] ||
1494         structType.getBody()[0] != complexType.getElementType())
1495       return op->emitOpError("expected 'complex' to map to two-element struct "
1496                              "with identical element types");
1497     return success();
1498   }
1499 
1500   // Everything else is not supported.
1501   return op->emitError("unsupported cast");
1502 }
1503 
1504 static LogicalResult verify(DialectCastOp op) {
1505   if (isCompatibleType(op.getType()))
1506     return verifyCast(op, op.getType(), op.in().getType());
1507 
1508   if (!isCompatibleType(op.in().getType()))
1509     return op->emitOpError("expected one LLVM type and one built-in type");
1510 
1511   return verifyCast(op, op.in().getType(), op.getType());
1512 }
1513 
1514 // Parses one of the keywords provided in the list `keywords` and returns the
1515 // position of the parsed keyword in the list. If none of the keywords from the
1516 // list is parsed, returns -1.
1517 static int parseOptionalKeywordAlternative(OpAsmParser &parser,
1518                                            ArrayRef<StringRef> keywords) {
1519   for (auto en : llvm::enumerate(keywords)) {
1520     if (succeeded(parser.parseOptionalKeyword(en.value())))
1521       return en.index();
1522   }
1523   return -1;
1524 }
1525 
namespace {
// Per-enum helper table used by generic keyword parsing. The primary template
// is intentionally empty; only types registered with REGISTER_ENUM_TYPE below
// provide `stringify` and `getMaxEnumVal`.
template <typename Ty> struct EnumTraits {};

// Specializes EnumTraits<Ty> by forwarding to the free functions
// `stringify<Ty>(Ty)` and `getMaxEnumValFor<Ty>()` (presumably those from the
// generated enum definitions included at the top of the file — verify when
// registering a new type).
#define REGISTER_ENUM_TYPE(Ty)                                                 \
  template <> struct EnumTraits<Ty> {                                          \
    static StringRef stringify(Ty value) { return stringify##Ty(value); }      \
    static unsigned getMaxEnumVal() { return getMaxEnumValFor##Ty(); }         \
  }

REGISTER_ENUM_TYPE(Linkage);
} // end namespace
1537 
1538 template <typename EnumTy>
1539 static ParseResult parseOptionalLLVMKeyword(OpAsmParser &parser,
1540                                             OperationState &result,
1541                                             StringRef name) {
1542   SmallVector<StringRef, 10> names;
1543   for (unsigned i = 0, e = getMaxEnumValForLinkage(); i <= e; ++i)
1544     names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i)));
1545 
1546   int index = parseOptionalKeywordAlternative(parser, names);
1547   if (index == -1)
1548     return failure();
1549   result.addAttribute(name, parser.getBuilder().getI64IntegerAttr(index));
1550   return success();
1551 }
1552 
1553 // operation ::= `llvm.mlir.global` linkage? `constant`? `@` identifier
1554 //               `(` attribute? `)` attribute-list? (`:` type)? region?
1555 //
1556 // The type can be omitted for string attributes, in which case it will be
1557 // inferred from the value of the string as [strlen(value) x i8].
1558 static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) {
1559   if (failed(parseOptionalLLVMKeyword<Linkage>(parser, result,
1560                                                getLinkageAttrName())))
1561     result.addAttribute(getLinkageAttrName(),
1562                         parser.getBuilder().getI64IntegerAttr(
1563                             static_cast<int64_t>(LLVM::Linkage::External)));
1564 
1565   if (succeeded(parser.parseOptionalKeyword("constant")))
1566     result.addAttribute("constant", parser.getBuilder().getUnitAttr());
1567 
1568   StringAttr name;
1569   if (parser.parseSymbolName(name, SymbolTable::getSymbolAttrName(),
1570                              result.attributes) ||
1571       parser.parseLParen())
1572     return failure();
1573 
1574   Attribute value;
1575   if (parser.parseOptionalRParen()) {
1576     if (parser.parseAttribute(value, "value", result.attributes) ||
1577         parser.parseRParen())
1578       return failure();
1579   }
1580 
1581   SmallVector<Type, 1> types;
1582   if (parser.parseOptionalAttrDict(result.attributes) ||
1583       parser.parseOptionalColonTypeList(types))
1584     return failure();
1585 
1586   if (types.size() > 1)
1587     return parser.emitError(parser.getNameLoc(), "expected zero or one type");
1588 
1589   Region &initRegion = *result.addRegion();
1590   if (types.empty()) {
1591     if (auto strAttr = value.dyn_cast_or_null<StringAttr>()) {
1592       MLIRContext *context = parser.getBuilder().getContext();
1593       auto arrayType = LLVM::LLVMArrayType::get(IntegerType::get(context, 8),
1594                                                 strAttr.getValue().size());
1595       types.push_back(arrayType);
1596     } else {
1597       return parser.emitError(parser.getNameLoc(),
1598                               "type can only be omitted for string globals");
1599     }
1600   } else {
1601     OptionalParseResult parseResult =
1602         parser.parseOptionalRegion(initRegion, /*arguments=*/{},
1603                                    /*argTypes=*/{});
1604     if (parseResult.hasValue() && failed(*parseResult))
1605       return failure();
1606   }
1607 
1608   result.addAttribute("type", TypeAttr::get(types[0]));
1609   return success();
1610 }
1611 
1612 static bool isZeroAttribute(Attribute value) {
1613   if (auto intValue = value.dyn_cast<IntegerAttr>())
1614     return intValue.getValue().isNullValue();
1615   if (auto fpValue = value.dyn_cast<FloatAttr>())
1616     return fpValue.getValue().isZero();
1617   if (auto splatValue = value.dyn_cast<SplatElementsAttr>())
1618     return isZeroAttribute(splatValue.getSplatValue());
1619   if (auto elementsValue = value.dyn_cast<ElementsAttr>())
1620     return llvm::all_of(elementsValue.getValues<Attribute>(), isZeroAttribute);
1621   if (auto arrayValue = value.dyn_cast<ArrayAttr>())
1622     return llvm::all_of(arrayValue.getValue(), isZeroAttribute);
1623   return false;
1624 }
1625 
1626 static LogicalResult verify(GlobalOp op) {
1627   if (!LLVMPointerType::isValidElementType(op.getType()))
1628     return op.emitOpError(
1629         "expects type to be a valid element type for an LLVM pointer");
1630   if (op->getParentOp() && !satisfiesLLVMModule(op->getParentOp()))
1631     return op.emitOpError("must appear at the module level");
1632 
1633   if (auto strAttr = op.getValueOrNull().dyn_cast_or_null<StringAttr>()) {
1634     auto type = op.getType().dyn_cast<LLVMArrayType>();
1635     IntegerType elementType =
1636         type ? type.getElementType().dyn_cast<IntegerType>() : nullptr;
1637     if (!elementType || elementType.getWidth() != 8 ||
1638         type.getNumElements() != strAttr.getValue().size())
1639       return op.emitOpError(
1640           "requires an i8 array type of the length equal to that of the string "
1641           "attribute");
1642   }
1643 
1644   if (Block *b = op.getInitializerBlock()) {
1645     ReturnOp ret = cast<ReturnOp>(b->getTerminator());
1646     if (ret.operand_type_begin() == ret.operand_type_end())
1647       return op.emitOpError("initializer region cannot return void");
1648     if (*ret.operand_type_begin() != op.getType())
1649       return op.emitOpError("initializer region type ")
1650              << *ret.operand_type_begin() << " does not match global type "
1651              << op.getType();
1652 
1653     if (op.getValueOrNull())
1654       return op.emitOpError("cannot have both initializer value and region");
1655   }
1656 
1657   if (op.linkage() == Linkage::Common) {
1658     if (Attribute value = op.getValueOrNull()) {
1659       if (!isZeroAttribute(value)) {
1660         return op.emitOpError()
1661                << "expected zero value for '"
1662                << stringifyLinkage(Linkage::Common) << "' linkage";
1663       }
1664     }
1665   }
1666 
1667   if (op.linkage() == Linkage::Appending) {
1668     if (!op.getType().isa<LLVMArrayType>()) {
1669       return op.emitOpError()
1670              << "expected array type for '"
1671              << stringifyLinkage(Linkage::Appending) << "' linkage";
1672     }
1673   }
1674 
1675   return success();
1676 }
1677 
1678 //===----------------------------------------------------------------------===//
1679 // Printing/parsing for LLVM::ShuffleVectorOp.
1680 //===----------------------------------------------------------------------===//
1681 // Expects vector to be of wrapped LLVM vector type and position to be of
1682 // wrapped LLVM i32 type.
1683 void LLVM::ShuffleVectorOp::build(OpBuilder &b, OperationState &result,
1684                                   Value v1, Value v2, ArrayAttr mask,
1685                                   ArrayRef<NamedAttribute> attrs) {
1686   auto containerType = v1.getType();
1687   auto vType = LLVM::getFixedVectorType(
1688       LLVM::getVectorElementType(containerType), mask.size());
1689   build(b, result, vType, v1, v2, mask);
1690   result.addAttributes(attrs);
1691 }
1692 
1693 static void printShuffleVectorOp(OpAsmPrinter &p, ShuffleVectorOp &op) {
1694   p << op.getOperationName() << ' ' << op.v1() << ", " << op.v2() << " "
1695     << op.mask();
1696   p.printOptionalAttrDict(op->getAttrs(), {"mask"});
1697   p << " : " << op.v1().getType() << ", " << op.v2().getType();
1698 }
1699 
1700 // <operation> ::= `llvm.shufflevector` ssa-use `, ` ssa-use
1701 //                 `[` integer-literal (`,` integer-literal)* `]`
1702 //                 attribute-dict? `:` type
1703 static ParseResult parseShuffleVectorOp(OpAsmParser &parser,
1704                                         OperationState &result) {
1705   llvm::SMLoc loc;
1706   OpAsmParser::OperandType v1, v2;
1707   ArrayAttr maskAttr;
1708   Type typeV1, typeV2;
1709   if (parser.getCurrentLocation(&loc) || parser.parseOperand(v1) ||
1710       parser.parseComma() || parser.parseOperand(v2) ||
1711       parser.parseAttribute(maskAttr, "mask", result.attributes) ||
1712       parser.parseOptionalAttrDict(result.attributes) ||
1713       parser.parseColonType(typeV1) || parser.parseComma() ||
1714       parser.parseType(typeV2) ||
1715       parser.resolveOperand(v1, typeV1, result.operands) ||
1716       parser.resolveOperand(v2, typeV2, result.operands))
1717     return failure();
1718   if (!LLVM::isCompatibleVectorType(typeV1))
1719     return parser.emitError(
1720         loc, "expected LLVM IR dialect vector type for operand #1");
1721   auto vType = LLVM::getFixedVectorType(LLVM::getVectorElementType(typeV1),
1722                                         maskAttr.size());
1723   result.addTypes(vType);
1724   return success();
1725 }
1726 
1727 //===----------------------------------------------------------------------===//
1728 // Implementations for LLVM::LLVMFuncOp.
1729 //===----------------------------------------------------------------------===//
1730 
1731 // Add the entry block to the function.
1732 Block *LLVMFuncOp::addEntryBlock() {
1733   assert(empty() && "function already has an entry block");
1734   assert(!isVarArg() && "unimplemented: non-external variadic functions");
1735 
1736   auto *entry = new Block;
1737   push_back(entry);
1738 
1739   LLVMFunctionType type = getType();
1740   for (unsigned i = 0, e = type.getNumParams(); i < e; ++i)
1741     entry->addArgument(type.getParamType(i));
1742   return entry;
1743 }
1744 
1745 void LLVMFuncOp::build(OpBuilder &builder, OperationState &result,
1746                        StringRef name, Type type, LLVM::Linkage linkage,
1747                        ArrayRef<NamedAttribute> attrs,
1748                        ArrayRef<DictionaryAttr> argAttrs) {
1749   result.addRegion();
1750   result.addAttribute(SymbolTable::getSymbolAttrName(),
1751                       builder.getStringAttr(name));
1752   result.addAttribute("type", TypeAttr::get(type));
1753   result.addAttribute(getLinkageAttrName(),
1754                       builder.getI64IntegerAttr(static_cast<int64_t>(linkage)));
1755   result.attributes.append(attrs.begin(), attrs.end());
1756   if (argAttrs.empty())
1757     return;
1758 
1759   unsigned numInputs = type.cast<LLVMFunctionType>().getNumParams();
1760   assert(numInputs == argAttrs.size() &&
1761          "expected as many argument attribute lists as arguments");
1762   SmallString<8> argAttrName;
1763   for (unsigned i = 0; i < numInputs; ++i)
1764     if (DictionaryAttr argDict = argAttrs[i])
1765       result.addAttribute(getArgAttrName(i, argAttrName), argDict);
1766 }
1767 
1768 // Builds an LLVM function type from the given lists of input and output types.
1769 // Returns a null type if any of the types provided are non-LLVM types, or if
1770 // there is more than one output type.
1771 static Type buildLLVMFunctionType(OpAsmParser &parser, llvm::SMLoc loc,
1772                                   ArrayRef<Type> inputs, ArrayRef<Type> outputs,
1773                                   impl::VariadicFlag variadicFlag) {
1774   Builder &b = parser.getBuilder();
1775   if (outputs.size() > 1) {
1776     parser.emitError(loc, "failed to construct function type: expected zero or "
1777                           "one function result");
1778     return {};
1779   }
1780 
1781   // Convert inputs to LLVM types, exit early on error.
1782   SmallVector<Type, 4> llvmInputs;
1783   for (auto t : inputs) {
1784     if (!isCompatibleType(t)) {
1785       parser.emitError(loc, "failed to construct function type: expected LLVM "
1786                             "type for function arguments");
1787       return {};
1788     }
1789     llvmInputs.push_back(t);
1790   }
1791 
1792   // No output is denoted as "void" in LLVM type system.
1793   Type llvmOutput =
1794       outputs.empty() ? LLVMVoidType::get(b.getContext()) : outputs.front();
1795   if (!isCompatibleType(llvmOutput)) {
1796     parser.emitError(loc, "failed to construct function type: expected LLVM "
1797                           "type for function results")
1798         << llvmOutput;
1799     return {};
1800   }
1801   return LLVMFunctionType::get(llvmOutput, llvmInputs,
1802                                variadicFlag.isVariadic());
1803 }
1804 
1805 // Parses an LLVM function.
1806 //
1807 // operation ::= `llvm.func` linkage? function-signature function-attributes?
1808 //               function-body
1809 //
1810 static ParseResult parseLLVMFuncOp(OpAsmParser &parser,
1811                                    OperationState &result) {
1812   // Default to external linkage if no keyword is provided.
1813   if (failed(parseOptionalLLVMKeyword<Linkage>(parser, result,
1814                                                getLinkageAttrName())))
1815     result.addAttribute(getLinkageAttrName(),
1816                         parser.getBuilder().getI64IntegerAttr(
1817                             static_cast<int64_t>(LLVM::Linkage::External)));
1818 
1819   StringAttr nameAttr;
1820   SmallVector<OpAsmParser::OperandType, 8> entryArgs;
1821   SmallVector<NamedAttrList, 1> argAttrs;
1822   SmallVector<NamedAttrList, 1> resultAttrs;
1823   SmallVector<Type, 8> argTypes;
1824   SmallVector<Type, 4> resultTypes;
1825   bool isVariadic;
1826 
1827   auto signatureLocation = parser.getCurrentLocation();
1828   if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1829                              result.attributes) ||
1830       impl::parseFunctionSignature(parser, /*allowVariadic=*/true, entryArgs,
1831                                    argTypes, argAttrs, isVariadic, resultTypes,
1832                                    resultAttrs))
1833     return failure();
1834 
1835   auto type =
1836       buildLLVMFunctionType(parser, signatureLocation, argTypes, resultTypes,
1837                             impl::VariadicFlag(isVariadic));
1838   if (!type)
1839     return failure();
1840   result.addAttribute(impl::getTypeAttrName(), TypeAttr::get(type));
1841 
1842   if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
1843     return failure();
1844   impl::addArgAndResultAttrs(parser.getBuilder(), result, argAttrs,
1845                              resultAttrs);
1846 
1847   auto *body = result.addRegion();
1848   OptionalParseResult parseResult = parser.parseOptionalRegion(
1849       *body, entryArgs, entryArgs.empty() ? ArrayRef<Type>() : argTypes);
1850   return failure(parseResult.hasValue() && failed(*parseResult));
1851 }
1852 
1853 // Print the LLVMFuncOp. Collects argument and result types and passes them to
1854 // helper functions. Drops "void" result since it cannot be parsed back. Skips
1855 // the external linkage since it is the default value.
1856 static void printLLVMFuncOp(OpAsmPrinter &p, LLVMFuncOp op) {
1857   p << op.getOperationName() << ' ';
1858   if (op.linkage() != LLVM::Linkage::External)
1859     p << stringifyLinkage(op.linkage()) << ' ';
1860   p.printSymbolName(op.getName());
1861 
1862   LLVMFunctionType fnType = op.getType();
1863   SmallVector<Type, 8> argTypes;
1864   SmallVector<Type, 1> resTypes;
1865   argTypes.reserve(fnType.getNumParams());
1866   for (unsigned i = 0, e = fnType.getNumParams(); i < e; ++i)
1867     argTypes.push_back(fnType.getParamType(i));
1868 
1869   Type returnType = fnType.getReturnType();
1870   if (!returnType.isa<LLVMVoidType>())
1871     resTypes.push_back(returnType);
1872 
1873   impl::printFunctionSignature(p, op, argTypes, op.isVarArg(), resTypes);
1874   impl::printFunctionAttributes(p, op, argTypes.size(), resTypes.size(),
1875                                 {getLinkageAttrName()});
1876 
1877   // Print the body if this is not an external function.
1878   Region &body = op.body();
1879   if (!body.empty())
1880     p.printRegion(body, /*printEntryBlockArgs=*/false,
1881                   /*printBlockTerminators=*/true);
1882 }
1883 
1884 // Hook for OpTrait::FunctionLike, called after verifying that the 'type'
1885 // attribute is present.  This can check for preconditions of the
1886 // getNumArguments hook not failing.
1887 LogicalResult LLVMFuncOp::verifyType() {
1888   auto llvmType = getTypeAttr().getValue().dyn_cast_or_null<LLVMFunctionType>();
1889   if (!llvmType)
1890     return emitOpError("requires '" + getTypeAttrName() +
1891                        "' attribute of wrapped LLVM function type");
1892 
1893   return success();
1894 }
1895 
1896 // Hook for OpTrait::FunctionLike, returns the number of function arguments.
1897 // Depends on the type attribute being correct as checked by verifyType
1898 unsigned LLVMFuncOp::getNumFuncArguments() { return getType().getNumParams(); }
1899 
1900 // Hook for OpTrait::FunctionLike, returns the number of function results.
1901 // Depends on the type attribute being correct as checked by verifyType
1902 unsigned LLVMFuncOp::getNumFuncResults() {
1903   // We model LLVM functions that return void as having zero results,
1904   // and all others as having one result.
1905   // If we modeled a void return as one result, then it would be possible to
1906   // attach an MLIR result attribute to it, and it isn't clear what semantics we
1907   // would assign to that.
1908   if (getType().getReturnType().isa<LLVMVoidType>())
1909     return 0;
1910   return 1;
1911 }
1912 
1913 // Verifies LLVM- and implementation-specific properties of the LLVM func Op:
1914 // - functions don't have 'common' linkage
1915 // - external functions have 'external' or 'extern_weak' linkage;
1916 // - vararg is (currently) only supported for external functions;
1917 // - entry block arguments are of LLVM types and match the function signature.
1918 static LogicalResult verify(LLVMFuncOp op) {
1919   if (op.linkage() == LLVM::Linkage::Common)
1920     return op.emitOpError()
1921            << "functions cannot have '"
1922            << stringifyLinkage(LLVM::Linkage::Common) << "' linkage";
1923 
1924   if (op.isExternal()) {
1925     if (op.linkage() != LLVM::Linkage::External &&
1926         op.linkage() != LLVM::Linkage::ExternWeak)
1927       return op.emitOpError()
1928              << "external functions must have '"
1929              << stringifyLinkage(LLVM::Linkage::External) << "' or '"
1930              << stringifyLinkage(LLVM::Linkage::ExternWeak) << "' linkage";
1931     return success();
1932   }
1933 
1934   if (op.isVarArg())
1935     return op.emitOpError("only external functions can be variadic");
1936 
1937   unsigned numArguments = op.getType().getNumParams();
1938   Block &entryBlock = op.front();
1939   for (unsigned i = 0; i < numArguments; ++i) {
1940     Type argType = entryBlock.getArgument(i).getType();
1941     if (!isCompatibleType(argType))
1942       return op.emitOpError("entry block argument #")
1943              << i << " is not of LLVM type";
1944     if (op.getType().getParamType(i) != argType)
1945       return op.emitOpError("the type of entry block argument #")
1946              << i << " does not match the function signature";
1947   }
1948 
1949   return success();
1950 }
1951 
1952 //===----------------------------------------------------------------------===//
1953 // Verification for LLVM::ConstantOp.
1954 //===----------------------------------------------------------------------===//
1955 
1956 static LogicalResult verify(LLVM::ConstantOp op) {
1957   if (StringAttr sAttr = op.value().dyn_cast<StringAttr>()) {
1958     auto arrayType = op.getType().dyn_cast<LLVMArrayType>();
1959     if (!arrayType || arrayType.getNumElements() != sAttr.getValue().size() ||
1960         !arrayType.getElementType().isInteger(8)) {
1961       return op->emitOpError()
1962              << "expected array type of " << sAttr.getValue().size()
1963              << " i8 elements for the string constant";
1964     }
1965     return success();
1966   }
1967   if (!op.value().isa<IntegerAttr, FloatAttr, ElementsAttr>())
1968     return op.emitOpError()
1969            << "only supports integer, float, string or elements attributes";
1970   return success();
1971 }
1972 
1973 //===----------------------------------------------------------------------===//
1974 // Utility functions for parsing atomic ops
1975 //===----------------------------------------------------------------------===//
1976 
1977 // Helper function to parse a keyword into the specified attribute named by
1978 // `attrName`. The keyword must match one of the string values defined by the
1979 // AtomicBinOp enum. The resulting I64 attribute is added to the `result`
1980 // state.
1981 static ParseResult parseAtomicBinOp(OpAsmParser &parser, OperationState &result,
1982                                     StringRef attrName) {
1983   llvm::SMLoc loc;
1984   StringRef keyword;
1985   if (parser.getCurrentLocation(&loc) || parser.parseKeyword(&keyword))
1986     return failure();
1987 
1988   // Replace the keyword `keyword` with an integer attribute.
1989   auto kind = symbolizeAtomicBinOp(keyword);
1990   if (!kind) {
1991     return parser.emitError(loc)
1992            << "'" << keyword << "' is an incorrect value of the '" << attrName
1993            << "' attribute";
1994   }
1995 
1996   auto value = static_cast<int64_t>(kind.getValue());
1997   auto attr = parser.getBuilder().getI64IntegerAttr(value);
1998   result.addAttribute(attrName, attr);
1999 
2000   return success();
2001 }
2002 
2003 // Helper function to parse a keyword into the specified attribute named by
2004 // `attrName`. The keyword must match one of the string values defined by the
2005 // AtomicOrdering enum. The resulting I64 attribute is added to the `result`
2006 // state.
2007 static ParseResult parseAtomicOrdering(OpAsmParser &parser,
2008                                        OperationState &result,
2009                                        StringRef attrName) {
2010   llvm::SMLoc loc;
2011   StringRef ordering;
2012   if (parser.getCurrentLocation(&loc) || parser.parseKeyword(&ordering))
2013     return failure();
2014 
2015   // Replace the keyword `ordering` with an integer attribute.
2016   auto kind = symbolizeAtomicOrdering(ordering);
2017   if (!kind) {
2018     return parser.emitError(loc)
2019            << "'" << ordering << "' is an incorrect value of the '" << attrName
2020            << "' attribute";
2021   }
2022 
2023   auto value = static_cast<int64_t>(kind.getValue());
2024   auto attr = parser.getBuilder().getI64IntegerAttr(value);
2025   result.addAttribute(attrName, attr);
2026 
2027   return success();
2028 }
2029 
2030 //===----------------------------------------------------------------------===//
2031 // Printer, parser and verifier for LLVM::AtomicRMWOp.
2032 //===----------------------------------------------------------------------===//
2033 
2034 static void printAtomicRMWOp(OpAsmPrinter &p, AtomicRMWOp &op) {
2035   p << op.getOperationName() << ' ' << stringifyAtomicBinOp(op.bin_op()) << ' '
2036     << op.ptr() << ", " << op.val() << ' '
2037     << stringifyAtomicOrdering(op.ordering()) << ' ';
2038   p.printOptionalAttrDict(op->getAttrs(), {"bin_op", "ordering"});
2039   p << " : " << op.res().getType();
2040 }
2041 
2042 // <operation> ::= `llvm.atomicrmw` keyword ssa-use `,` ssa-use keyword
2043 //                 attribute-dict? `:` type
2044 static ParseResult parseAtomicRMWOp(OpAsmParser &parser,
2045                                     OperationState &result) {
2046   Type type;
2047   OpAsmParser::OperandType ptr, val;
2048   if (parseAtomicBinOp(parser, result, "bin_op") || parser.parseOperand(ptr) ||
2049       parser.parseComma() || parser.parseOperand(val) ||
2050       parseAtomicOrdering(parser, result, "ordering") ||
2051       parser.parseOptionalAttrDict(result.attributes) ||
2052       parser.parseColonType(type) ||
2053       parser.resolveOperand(ptr, LLVM::LLVMPointerType::get(type),
2054                             result.operands) ||
2055       parser.resolveOperand(val, type, result.operands))
2056     return failure();
2057 
2058   result.addTypes(type);
2059   return success();
2060 }
2061 
2062 static LogicalResult verify(AtomicRMWOp op) {
2063   auto ptrType = op.ptr().getType().cast<LLVM::LLVMPointerType>();
2064   auto valType = op.val().getType();
2065   if (valType != ptrType.getElementType())
2066     return op.emitOpError("expected LLVM IR element type for operand #0 to "
2067                           "match type for operand #1");
2068   auto resType = op.res().getType();
2069   if (resType != valType)
2070     return op.emitOpError(
2071         "expected LLVM IR result type to match type for operand #1");
2072   if (op.bin_op() == AtomicBinOp::fadd || op.bin_op() == AtomicBinOp::fsub) {
2073     if (!mlir::LLVM::isCompatibleFloatingPointType(valType))
2074       return op.emitOpError("expected LLVM IR floating point type");
2075   } else if (op.bin_op() == AtomicBinOp::xchg) {
2076     auto intType = valType.dyn_cast<IntegerType>();
2077     unsigned intBitWidth = intType ? intType.getWidth() : 0;
2078     if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 &&
2079         intBitWidth != 64 && !valType.isa<BFloat16Type>() &&
2080         !valType.isa<Float16Type>() && !valType.isa<Float32Type>() &&
2081         !valType.isa<Float64Type>())
2082       return op.emitOpError("unexpected LLVM IR type for 'xchg' bin_op");
2083   } else {
2084     auto intType = valType.dyn_cast<IntegerType>();
2085     unsigned intBitWidth = intType ? intType.getWidth() : 0;
2086     if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 &&
2087         intBitWidth != 64)
2088       return op.emitOpError("expected LLVM IR integer type");
2089   }
2090 
2091   if (static_cast<unsigned>(op.ordering()) <
2092       static_cast<unsigned>(AtomicOrdering::monotonic))
2093     return op.emitOpError()
2094            << "expected at least '"
2095            << stringifyAtomicOrdering(AtomicOrdering::monotonic)
2096            << "' ordering";
2097 
2098   return success();
2099 }
2100 
2101 //===----------------------------------------------------------------------===//
2102 // Printer, parser and verifier for LLVM::AtomicCmpXchgOp.
2103 //===----------------------------------------------------------------------===//
2104 
2105 static void printAtomicCmpXchgOp(OpAsmPrinter &p, AtomicCmpXchgOp &op) {
2106   p << op.getOperationName() << ' ' << op.ptr() << ", " << op.cmp() << ", "
2107     << op.val() << ' ' << stringifyAtomicOrdering(op.success_ordering()) << ' '
2108     << stringifyAtomicOrdering(op.failure_ordering());
2109   p.printOptionalAttrDict(op->getAttrs(),
2110                           {"success_ordering", "failure_ordering"});
2111   p << " : " << op.val().getType();
2112 }
2113 
2114 // <operation> ::= `llvm.cmpxchg` ssa-use `,` ssa-use `,` ssa-use
2115 //                 keyword keyword attribute-dict? `:` type
2116 static ParseResult parseAtomicCmpXchgOp(OpAsmParser &parser,
2117                                         OperationState &result) {
2118   auto &builder = parser.getBuilder();
2119   Type type;
2120   OpAsmParser::OperandType ptr, cmp, val;
2121   if (parser.parseOperand(ptr) || parser.parseComma() ||
2122       parser.parseOperand(cmp) || parser.parseComma() ||
2123       parser.parseOperand(val) ||
2124       parseAtomicOrdering(parser, result, "success_ordering") ||
2125       parseAtomicOrdering(parser, result, "failure_ordering") ||
2126       parser.parseOptionalAttrDict(result.attributes) ||
2127       parser.parseColonType(type) ||
2128       parser.resolveOperand(ptr, LLVM::LLVMPointerType::get(type),
2129                             result.operands) ||
2130       parser.resolveOperand(cmp, type, result.operands) ||
2131       parser.resolveOperand(val, type, result.operands))
2132     return failure();
2133 
2134   auto boolType = IntegerType::get(builder.getContext(), 1);
2135   auto resultType =
2136       LLVMStructType::getLiteral(builder.getContext(), {type, boolType});
2137   result.addTypes(resultType);
2138 
2139   return success();
2140 }
2141 
2142 static LogicalResult verify(AtomicCmpXchgOp op) {
2143   auto ptrType = op.ptr().getType().cast<LLVM::LLVMPointerType>();
2144   if (!ptrType)
2145     return op.emitOpError("expected LLVM IR pointer type for operand #0");
2146   auto cmpType = op.cmp().getType();
2147   auto valType = op.val().getType();
2148   if (cmpType != ptrType.getElementType() || cmpType != valType)
2149     return op.emitOpError("expected LLVM IR element type for operand #0 to "
2150                           "match type for all other operands");
2151   auto intType = valType.dyn_cast<IntegerType>();
2152   unsigned intBitWidth = intType ? intType.getWidth() : 0;
2153   if (!valType.isa<LLVMPointerType>() && intBitWidth != 8 &&
2154       intBitWidth != 16 && intBitWidth != 32 && intBitWidth != 64 &&
2155       !valType.isa<BFloat16Type>() && !valType.isa<Float16Type>() &&
2156       !valType.isa<Float32Type>() && !valType.isa<Float64Type>())
2157     return op.emitOpError("unexpected LLVM IR type");
2158   if (op.success_ordering() < AtomicOrdering::monotonic ||
2159       op.failure_ordering() < AtomicOrdering::monotonic)
2160     return op.emitOpError("ordering must be at least 'monotonic'");
2161   if (op.failure_ordering() == AtomicOrdering::release ||
2162       op.failure_ordering() == AtomicOrdering::acq_rel)
2163     return op.emitOpError("failure ordering cannot be 'release' or 'acq_rel'");
2164   return success();
2165 }
2166 
2167 //===----------------------------------------------------------------------===//
2168 // Printer, parser and verifier for LLVM::FenceOp.
2169 //===----------------------------------------------------------------------===//
2170 
2171 // <operation> ::= `llvm.fence` (`syncscope(`strAttr`)`)? keyword
2172 // attribute-dict?
2173 static ParseResult parseFenceOp(OpAsmParser &parser, OperationState &result) {
2174   StringAttr sScope;
2175   StringRef syncscopeKeyword = "syncscope";
2176   if (!failed(parser.parseOptionalKeyword(syncscopeKeyword))) {
2177     if (parser.parseLParen() ||
2178         parser.parseAttribute(sScope, syncscopeKeyword, result.attributes) ||
2179         parser.parseRParen())
2180       return failure();
2181   } else {
2182     result.addAttribute(syncscopeKeyword,
2183                         parser.getBuilder().getStringAttr(""));
2184   }
2185   if (parseAtomicOrdering(parser, result, "ordering") ||
2186       parser.parseOptionalAttrDict(result.attributes))
2187     return failure();
2188   return success();
2189 }
2190 
2191 static void printFenceOp(OpAsmPrinter &p, FenceOp &op) {
2192   StringRef syncscopeKeyword = "syncscope";
2193   p << op.getOperationName() << ' ';
2194   if (!op->getAttr(syncscopeKeyword).cast<StringAttr>().getValue().empty())
2195     p << "syncscope(" << op->getAttr(syncscopeKeyword) << ") ";
2196   p << stringifyAtomicOrdering(op.ordering());
2197 }
2198 
2199 static LogicalResult verify(FenceOp &op) {
2200   if (op.ordering() == AtomicOrdering::not_atomic ||
2201       op.ordering() == AtomicOrdering::unordered ||
2202       op.ordering() == AtomicOrdering::monotonic)
2203     return op.emitOpError("can be given only acquire, release, acq_rel, "
2204                           "and seq_cst orderings");
2205   return success();
2206 }
2207 
2208 //===----------------------------------------------------------------------===//
2209 // LLVMDialect initialization, type parsing, and registration.
2210 //===----------------------------------------------------------------------===//
2211 
/// Registers the LLVM dialect's attributes, types and operations, and allows
/// operations of this dialect that are not registered.
void LLVMDialect::initialize() {
  // Dialect attributes: fastmath flags (`fastmath<...>`) and loop options
  // (`loopopt<...>`).
  addAttributes<FMFAttr, LoopOptionAttr>();

  // clang-format off
  addTypes<LLVMVoidType,
           LLVMPPCFP128Type,
           LLVMX86MMXType,
           LLVMTokenType,
           LLVMLabelType,
           LLVMMetadataType,
           LLVMFunctionType,
           LLVMPointerType,
           LLVMFixedVectorType,
           LLVMScalableVectorType,
           LLVMArrayType,
           LLVMStructType>();
  // clang-format on
  // The operation list is spliced in from the generated LLVMOps.cpp.inc,
  // selected by GET_OP_LIST.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc"
      >();

  // Support unknown operations because not all LLVM operations are registered.
  allowUnknownOperations();
}
2237 
// Pull in the generated operation class definitions (GET_OP_CLASSES section
// of LLVMOps.cpp.inc).
#define GET_OP_CLASSES
#include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc"
2240 
2241 /// Parse a type registered to this dialect.
2242 Type LLVMDialect::parseType(DialectAsmParser &parser) const {
2243   return detail::parseType(parser);
2244 }
2245 
2246 /// Print a type registered to this dialect.
2247 void LLVMDialect::printType(Type type, DialectAsmPrinter &os) const {
2248   return detail::printType(type, os);
2249 }
2250 
2251 LogicalResult LLVMDialect::verifyDataLayoutString(
2252     StringRef descr, llvm::function_ref<void(const Twine &)> reportError) {
2253   llvm::Expected<llvm::DataLayout> maybeDataLayout =
2254       llvm::DataLayout::parse(descr);
2255   if (maybeDataLayout)
2256     return success();
2257 
2258   std::string message;
2259   llvm::raw_string_ostream messageStream(message);
2260   llvm::logAllUnhandledErrors(maybeDataLayout.takeError(), messageStream);
2261   reportError("invalid data layout descriptor: " + messageStream.str());
2262   return failure();
2263 }
2264 
2265 /// Verify LLVM dialect attributes.
2266 LogicalResult LLVMDialect::verifyOperationAttribute(Operation *op,
2267                                                     NamedAttribute attr) {
2268   // If the `llvm.loop` attribute is present, enforce the following structure,
2269   // which the module translation can assume.
2270   if (attr.first.strref() == LLVMDialect::getLoopAttrName()) {
2271     auto loopAttr = attr.second.dyn_cast<DictionaryAttr>();
2272     if (!loopAttr)
2273       return op->emitOpError() << "expected '" << LLVMDialect::getLoopAttrName()
2274                                << "' to be a dictionary attribute";
2275     Optional<NamedAttribute> parallelAccessGroup =
2276         loopAttr.getNamed(LLVMDialect::getParallelAccessAttrName());
2277     if (parallelAccessGroup.hasValue()) {
2278       auto accessGroups = parallelAccessGroup->second.dyn_cast<ArrayAttr>();
2279       if (!accessGroups)
2280         return op->emitOpError()
2281                << "expected '" << LLVMDialect::getParallelAccessAttrName()
2282                << "' to be an array attribute";
2283       for (Attribute attr : accessGroups) {
2284         auto accessGroupRef = attr.dyn_cast<SymbolRefAttr>();
2285         if (!accessGroupRef)
2286           return op->emitOpError()
2287                  << "expected '" << attr << "' to be a symbol reference";
2288         StringRef metadataName = accessGroupRef.getRootReference();
2289         auto metadataOp =
2290             SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>(
2291                 op->getParentOp(), metadataName);
2292         if (!metadataOp)
2293           return op->emitOpError()
2294                  << "expected '" << attr << "' to reference a metadata op";
2295         StringRef accessGroupName = accessGroupRef.getLeafReference();
2296         Operation *accessGroupOp =
2297             SymbolTable::lookupNearestSymbolFrom(metadataOp, accessGroupName);
2298         if (!accessGroupOp)
2299           return op->emitOpError()
2300                  << "expected '" << attr << "' to reference an access_group op";
2301       }
2302     }
2303 
2304     Optional<NamedAttribute> loopOptions =
2305         loopAttr.getNamed(LLVMDialect::getLoopOptionsAttrName());
2306     if (loopOptions.hasValue()) {
2307       auto options = loopOptions->second.dyn_cast<ArrayAttr>();
2308       if (!options)
2309         return op->emitOpError()
2310                << "expected '" << LLVMDialect::getLoopOptionsAttrName()
2311                << "' to be an array attribute";
2312       if (!llvm::all_of(options, [](Attribute option) {
2313             return option.isa<LoopOptionAttr>();
2314           }))
2315         return op->emitOpError() << "invalid loop options list " << options;
2316     }
2317   }
2318 
2319   // If the data layout attribute is present, it must use the LLVM data layout
2320   // syntax. Try parsing it and report errors in case of failure. Users of this
2321   // attribute may assume it is well-formed and can pass it to the (asserting)
2322   // llvm::DataLayout constructor.
2323   if (attr.first.strref() != LLVM::LLVMDialect::getDataLayoutAttrName())
2324     return success();
2325   if (auto stringAttr = attr.second.dyn_cast<StringAttr>())
2326     return verifyDataLayoutString(
2327         stringAttr.getValue(),
2328         [op](const Twine &message) { op->emitOpError() << message.str(); });
2329 
2330   return op->emitOpError() << "expected '"
2331                            << LLVM::LLVMDialect::getDataLayoutAttrName()
2332                            << "' to be a string attribute";
2333 }
2334 
2335 /// Verify LLVMIR function argument attributes.
2336 LogicalResult LLVMDialect::verifyRegionArgAttribute(Operation *op,
2337                                                     unsigned regionIdx,
2338                                                     unsigned argIdx,
2339                                                     NamedAttribute argAttr) {
2340   // Check that llvm.noalias is a boolean attribute.
2341   if (argAttr.first == LLVMDialect::getNoAliasAttrName() &&
2342       !argAttr.second.isa<BoolAttr>())
2343     return op->emitError()
2344            << "llvm.noalias argument attribute of non boolean type";
2345   // Check that llvm.align is an integer attribute.
2346   if (argAttr.first == LLVMDialect::getAlignAttrName() &&
2347       !argAttr.second.isa<IntegerAttr>())
2348     return op->emitError()
2349            << "llvm.align argument attribute of non integer type";
2350   return success();
2351 }
2352 
2353 //===----------------------------------------------------------------------===//
2354 // Utility functions.
2355 //===----------------------------------------------------------------------===//
2356 
2357 Value mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder,
2358                                      StringRef name, StringRef value,
2359                                      LLVM::Linkage linkage) {
2360   assert(builder.getInsertionBlock() &&
2361          builder.getInsertionBlock()->getParentOp() &&
2362          "expected builder to point to a block constrained in an op");
2363   auto module =
2364       builder.getInsertionBlock()->getParentOp()->getParentOfType<ModuleOp>();
2365   assert(module && "builder points to an op outside of a module");
2366 
2367   // Create the global at the entry of the module.
2368   OpBuilder moduleBuilder(module.getBodyRegion(), builder.getListener());
2369   MLIRContext *ctx = builder.getContext();
2370   auto type = LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), value.size());
2371   auto global = moduleBuilder.create<LLVM::GlobalOp>(
2372       loc, type, /*isConstant=*/true, linkage, name,
2373       builder.getStringAttr(value));
2374 
2375   // Get the pointer to the first character in the global string.
2376   Value globalPtr = builder.create<LLVM::AddressOfOp>(loc, global);
2377   Value cst0 = builder.create<LLVM::ConstantOp>(
2378       loc, IntegerType::get(ctx, 64),
2379       builder.getIntegerAttr(builder.getIndexType(), 0));
2380   return builder.create<LLVM::GEPOp>(
2381       loc, LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)), globalPtr,
2382       ValueRange{cst0, cst0});
2383 }
2384 
2385 bool mlir::LLVM::satisfiesLLVMModule(Operation *op) {
2386   return op->hasTrait<OpTrait::SymbolTable>() &&
2387          op->hasTrait<OpTrait::IsIsolatedFromAbove>();
2388 }
2389 
2390 FMFAttr FMFAttr::get(FastmathFlags flags, MLIRContext *context) {
2391   return Base::get(context, static_cast<uint64_t>(flags));
2392 }
2393 
2394 FastmathFlags FMFAttr::getFlags() const {
2395   return static_cast<FastmathFlags>(getImpl()->value);
2396 }
2397 
2398 static constexpr const FastmathFlags FastmathFlagsList[] = {
2399     // clang-format off
2400     FastmathFlags::nnan,
2401     FastmathFlags::ninf,
2402     FastmathFlags::nsz,
2403     FastmathFlags::arcp,
2404     FastmathFlags::contract,
2405     FastmathFlags::afn,
2406     FastmathFlags::reassoc,
2407     FastmathFlags::fast,
2408     // clang-format on
2409 };
2410 
2411 void FMFAttr::print(DialectAsmPrinter &printer) const {
2412   printer << "fastmath<";
2413   auto flags = llvm::make_filter_range(FastmathFlagsList, [&](auto flag) {
2414     return bitEnumContains(this->getFlags(), flag);
2415   });
2416   llvm::interleaveComma(flags, printer,
2417                         [&](auto flag) { printer << stringifyEnum(flag); });
2418   printer << ">";
2419 }
2420 
2421 Attribute FMFAttr::parse(DialectAsmParser &parser) {
2422   if (failed(parser.parseLess()))
2423     return {};
2424 
2425   FastmathFlags flags = {};
2426   if (failed(parser.parseOptionalGreater())) {
2427     do {
2428       StringRef elemName;
2429       if (failed(parser.parseKeyword(&elemName)))
2430         return {};
2431 
2432       auto elem = symbolizeFastmathFlags(elemName);
2433       if (!elem) {
2434         parser.emitError(parser.getNameLoc(), "Unknown fastmath flag: ")
2435             << elemName;
2436         return {};
2437       }
2438 
2439       flags = flags | *elem;
2440     } while (succeeded(parser.parseOptionalComma()));
2441 
2442     if (failed(parser.parseGreater()))
2443       return {};
2444   }
2445 
2446   return FMFAttr::get(flags, parser.getBuilder().getContext());
2447 }
2448 
2449 LoopOptionAttr LoopOptionAttr::getDisableUnroll(MLIRContext *context,
2450                                                 bool disable) {
2451   auto option = LoopOptionCase::disable_unroll;
2452   return Base::get(context, static_cast<uint64_t>(option),
2453                    static_cast<int32_t>(disable));
2454 }
2455 
2456 LoopOptionAttr LoopOptionAttr::getDisableLICM(MLIRContext *context,
2457                                               bool disable) {
2458   auto option = LoopOptionCase::disable_licm;
2459   return Base::get(context, static_cast<uint64_t>(option),
2460                    static_cast<int32_t>(disable));
2461 }
2462 
2463 LoopOptionAttr LoopOptionAttr::getInterleaveCount(MLIRContext *context,
2464                                                   int32_t count) {
2465   auto option = LoopOptionCase::interleave_count;
2466   return Base::get(context, static_cast<uint64_t>(option),
2467                    static_cast<int32_t>(count));
2468 }
2469 
2470 LoopOptionCase LoopOptionAttr::getCase() const {
2471   return static_cast<LoopOptionCase>(getImpl()->option);
2472 }
2473 
2474 bool LoopOptionAttr::getBool() const {
2475   LoopOptionCase option = getCase();
2476   (void)option;
2477   assert(option == LoopOptionCase::disable_licm ||
2478          option == LoopOptionCase::disable_unroll &&
2479              "expected a boolean loop option");
2480   return static_cast<bool>(getImpl()->value);
2481 }
2482 
2483 int32_t LoopOptionAttr::getInt() const {
2484   LoopOptionCase option = getCase();
2485   (void)option;
2486   assert(option == LoopOptionCase::interleave_count &&
2487          "expected an integer loop option");
2488   return getImpl()->value;
2489 }
2490 
2491 void LoopOptionAttr::print(DialectAsmPrinter &printer) const {
2492   printer << "loopopt<" << stringifyEnum(getCase()) << " = ";
2493   switch (getCase()) {
2494   case LoopOptionCase::disable_licm:
2495   case LoopOptionCase::disable_unroll:
2496     printer << (getBool() ? "true" : "false");
2497     break;
2498   case LoopOptionCase::interleave_count:
2499     printer << getInt();
2500     break;
2501   }
2502   printer << ">";
2503 }
2504 
2505 Attribute LoopOptionAttr::parse(DialectAsmParser &parser) {
2506   if (failed(parser.parseLess()))
2507     return {};
2508 
2509   StringRef optionName;
2510   if (failed(parser.parseKeyword(&optionName)))
2511     return {};
2512 
2513   auto option = symbolizeLoopOptionCase(optionName);
2514   if (!option) {
2515     parser.emitError(parser.getNameLoc(), "unknown loop option: ")
2516         << optionName;
2517     return {};
2518   }
2519 
2520   if (failed(parser.parseEqual()))
2521     return {};
2522 
2523   int32_t value;
2524   switch (*option) {
2525   case LoopOptionCase::disable_licm:
2526   case LoopOptionCase::disable_unroll:
2527     if (succeeded(parser.parseOptionalKeyword("true")))
2528       value = 1;
2529     else if (succeeded(parser.parseOptionalKeyword("false")))
2530       value = 0;
2531     else {
2532       parser.emitError(parser.getNameLoc(),
2533                        "expected boolean value 'true' or 'false'");
2534       return {};
2535     }
2536     break;
2537   case LoopOptionCase::interleave_count:
2538     if (failed(parser.parseInteger(value))) {
2539       parser.emitError(parser.getNameLoc(), "expected integer value");
2540       return {};
2541     }
2542     break;
2543   }
2544 
2545   if (failed(parser.parseGreater()))
2546     return {};
2547 
2548   return Base::get(parser.getBuilder().getContext(),
2549                    static_cast<uint64_t>(*option), value);
2550 }
2551 
2552 Attribute LLVMDialect::parseAttribute(DialectAsmParser &parser,
2553                                       Type type) const {
2554   if (type) {
2555     parser.emitError(parser.getNameLoc(), "unexpected type");
2556     return {};
2557   }
2558   StringRef attrKind;
2559   if (parser.parseKeyword(&attrKind))
2560     return {};
2561 
2562   if (attrKind == "fastmath")
2563     return FMFAttr::parse(parser);
2564 
2565   if (attrKind == "loopopt")
2566     return LoopOptionAttr::parse(parser);
2567 
2568   parser.emitError(parser.getNameLoc(), "Unknown attribute type: ") << attrKind;
2569   return {};
2570 }
2571 
2572 void LLVMDialect::printAttribute(Attribute attr, DialectAsmPrinter &os) const {
2573   if (auto fmf = attr.dyn_cast<FMFAttr>())
2574     fmf.print(os);
2575   else if (auto lopt = attr.dyn_cast<LoopOptionAttr>())
2576     lopt.print(os);
2577   else
2578     llvm_unreachable("Unknown attribute type");
2579 }
2580