1 //===- LLVMDialect.cpp - LLVM IR Ops and Dialect registration -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the types and operation details for the LLVM IR dialect in
10 // MLIR, and the LLVM IR dialect.  It also registers the dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/DialectImplementation.h"
16 #include "mlir/IR/FunctionImplementation.h"
17 #include "mlir/IR/MLIRContext.h"
18 #include "mlir/IR/Module.h"
19 #include "mlir/IR/StandardTypes.h"
20 
21 #include "llvm/ADT/StringSwitch.h"
22 #include "llvm/AsmParser/Parser.h"
23 #include "llvm/IR/Attributes.h"
24 #include "llvm/IR/Function.h"
25 #include "llvm/IR/Type.h"
26 #include "llvm/Support/Mutex.h"
27 #include "llvm/Support/SourceMgr.h"
28 
29 using namespace mlir;
30 using namespace mlir::LLVM;
31 
32 #include "mlir/Dialect/LLVMIR/LLVMOpsEnums.cpp.inc"
33 
34 //===----------------------------------------------------------------------===//
35 // Printing/parsing for LLVM::CmpOp.
36 //===----------------------------------------------------------------------===//
37 static void printICmpOp(OpAsmPrinter &p, ICmpOp &op) {
38   p << op.getOperationName() << " \"" << stringifyICmpPredicate(op.predicate())
39     << "\" " << op.getOperand(0) << ", " << op.getOperand(1);
40   p.printOptionalAttrDict(op.getAttrs(), {"predicate"});
41   p << " : " << op.lhs().getType();
42 }
43 
44 static void printFCmpOp(OpAsmPrinter &p, FCmpOp &op) {
45   p << op.getOperationName() << " \"" << stringifyFCmpPredicate(op.predicate())
46     << "\" " << op.getOperand(0) << ", " << op.getOperand(1);
47   p.printOptionalAttrDict(op.getAttrs(), {"predicate"});
48   p << " : " << op.lhs().getType();
49 }
50 
51 // <operation> ::= `llvm.icmp` string-literal ssa-use `,` ssa-use
52 //                 attribute-dict? `:` type
53 // <operation> ::= `llvm.fcmp` string-literal ssa-use `,` ssa-use
54 //                 attribute-dict? `:` type
55 template <typename CmpPredicateType>
56 static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) {
57   Builder &builder = parser.getBuilder();
58 
59   StringAttr predicateAttr;
60   OpAsmParser::OperandType lhs, rhs;
61   Type type;
62   llvm::SMLoc predicateLoc, trailingTypeLoc;
63   if (parser.getCurrentLocation(&predicateLoc) ||
64       parser.parseAttribute(predicateAttr, "predicate", result.attributes) ||
65       parser.parseOperand(lhs) || parser.parseComma() ||
66       parser.parseOperand(rhs) ||
67       parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
68       parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) ||
69       parser.resolveOperand(lhs, type, result.operands) ||
70       parser.resolveOperand(rhs, type, result.operands))
71     return failure();
72 
73   // Replace the string attribute `predicate` with an integer attribute.
74   int64_t predicateValue = 0;
75   if (std::is_same<CmpPredicateType, ICmpPredicate>()) {
76     Optional<ICmpPredicate> predicate =
77         symbolizeICmpPredicate(predicateAttr.getValue());
78     if (!predicate)
79       return parser.emitError(predicateLoc)
80              << "'" << predicateAttr.getValue()
81              << "' is an incorrect value of the 'predicate' attribute";
82     predicateValue = static_cast<int64_t>(predicate.getValue());
83   } else {
84     Optional<FCmpPredicate> predicate =
85         symbolizeFCmpPredicate(predicateAttr.getValue());
86     if (!predicate)
87       return parser.emitError(predicateLoc)
88              << "'" << predicateAttr.getValue()
89              << "' is an incorrect value of the 'predicate' attribute";
90     predicateValue = static_cast<int64_t>(predicate.getValue());
91   }
92 
93   result.attributes[0].second =
94       parser.getBuilder().getI64IntegerAttr(predicateValue);
95 
96   // The result type is either i1 or a vector type <? x i1> if the inputs are
97   // vectors.
98   auto *dialect = builder.getContext()->getRegisteredDialect<LLVMDialect>();
99   auto resultType = LLVMType::getInt1Ty(dialect);
100   auto argType = type.dyn_cast<LLVM::LLVMType>();
101   if (!argType)
102     return parser.emitError(trailingTypeLoc, "expected LLVM IR dialect type");
103   if (argType.getUnderlyingType()->isVectorTy())
104     resultType = LLVMType::getVectorTy(
105         resultType, argType.getUnderlyingType()->getVectorNumElements());
106 
107   result.addTypes({resultType});
108   return success();
109 }
110 
111 //===----------------------------------------------------------------------===//
112 // Printing/parsing for LLVM::AllocaOp.
113 //===----------------------------------------------------------------------===//
114 
115 static void printAllocaOp(OpAsmPrinter &p, AllocaOp &op) {
116   auto elemTy = op.getType().cast<LLVM::LLVMType>().getPointerElementTy();
117 
118   auto funcTy = FunctionType::get({op.arraySize().getType()}, {op.getType()},
119                                   op.getContext());
120 
121   p << op.getOperationName() << ' ' << op.arraySize() << " x " << elemTy;
122   if (op.alignment().hasValue() && op.alignment()->getSExtValue() != 0)
123     p.printOptionalAttrDict(op.getAttrs());
124   else
125     p.printOptionalAttrDict(op.getAttrs(), {"alignment"});
126   p << " : " << funcTy;
127 }
128 
129 // <operation> ::= `llvm.alloca` ssa-use `x` type attribute-dict?
130 //                 `:` type `,` type
131 static ParseResult parseAllocaOp(OpAsmParser &parser, OperationState &result) {
132   OpAsmParser::OperandType arraySize;
133   Type type, elemType;
134   llvm::SMLoc trailingTypeLoc;
135   if (parser.parseOperand(arraySize) || parser.parseKeyword("x") ||
136       parser.parseType(elemType) ||
137       parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
138       parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type))
139     return failure();
140 
141   // Extract the result type from the trailing function type.
142   auto funcType = type.dyn_cast<FunctionType>();
143   if (!funcType || funcType.getNumInputs() != 1 ||
144       funcType.getNumResults() != 1)
145     return parser.emitError(
146         trailingTypeLoc,
147         "expected trailing function type with one argument and one result");
148 
149   if (parser.resolveOperand(arraySize, funcType.getInput(0), result.operands))
150     return failure();
151 
152   result.addTypes({funcType.getResult(0)});
153   return success();
154 }
155 
156 //===----------------------------------------------------------------------===//
157 // LLVM::BrOp
158 //===----------------------------------------------------------------------===//
159 
160 Optional<OperandRange> BrOp::getSuccessorOperands(unsigned index) {
161   assert(index == 0 && "invalid successor index");
162   return getOperands();
163 }
164 
165 bool BrOp::canEraseSuccessorOperand() { return true; }
166 
167 //===----------------------------------------------------------------------===//
168 // LLVM::CondBrOp
169 //===----------------------------------------------------------------------===//
170 
171 Optional<OperandRange> CondBrOp::getSuccessorOperands(unsigned index) {
172   assert(index < getNumSuccessors() && "invalid successor index");
173   return index == 0 ? trueDestOperands() : falseDestOperands();
174 }
175 
176 bool CondBrOp::canEraseSuccessorOperand() { return true; }
177 
178 //===----------------------------------------------------------------------===//
179 // Printing/parsing for LLVM::LoadOp.
180 //===----------------------------------------------------------------------===//
181 
182 static void printLoadOp(OpAsmPrinter &p, LoadOp &op) {
183   p << op.getOperationName() << ' ' << op.addr();
184   p.printOptionalAttrDict(op.getAttrs());
185   p << " : " << op.addr().getType();
186 }
187 
188 // Extract the pointee type from the LLVM pointer type wrapped in MLIR.  Return
189 // the resulting type wrapped in MLIR, or nullptr on error.
190 static Type getLoadStoreElementType(OpAsmParser &parser, Type type,
191                                     llvm::SMLoc trailingTypeLoc) {
192   auto llvmTy = type.dyn_cast<LLVM::LLVMType>();
193   if (!llvmTy)
194     return parser.emitError(trailingTypeLoc, "expected LLVM IR dialect type"),
195            nullptr;
196   if (!llvmTy.getUnderlyingType()->isPointerTy())
197     return parser.emitError(trailingTypeLoc, "expected LLVM pointer type"),
198            nullptr;
199   return llvmTy.getPointerElementTy();
200 }
201 
202 // <operation> ::= `llvm.load` ssa-use attribute-dict? `:` type
203 static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) {
204   OpAsmParser::OperandType addr;
205   Type type;
206   llvm::SMLoc trailingTypeLoc;
207 
208   if (parser.parseOperand(addr) ||
209       parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
210       parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) ||
211       parser.resolveOperand(addr, type, result.operands))
212     return failure();
213 
214   Type elemTy = getLoadStoreElementType(parser, type, trailingTypeLoc);
215 
216   result.addTypes(elemTy);
217   return success();
218 }
219 
220 //===----------------------------------------------------------------------===//
221 // Printing/parsing for LLVM::StoreOp.
222 //===----------------------------------------------------------------------===//
223 
224 static void printStoreOp(OpAsmPrinter &p, StoreOp &op) {
225   p << op.getOperationName() << ' ' << op.value() << ", " << op.addr();
226   p.printOptionalAttrDict(op.getAttrs());
227   p << " : " << op.addr().getType();
228 }
229 
230 // <operation> ::= `llvm.store` ssa-use `,` ssa-use attribute-dict? `:` type
231 static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
232   OpAsmParser::OperandType addr, value;
233   Type type;
234   llvm::SMLoc trailingTypeLoc;
235 
236   if (parser.parseOperand(value) || parser.parseComma() ||
237       parser.parseOperand(addr) ||
238       parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
239       parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type))
240     return failure();
241 
242   Type elemTy = getLoadStoreElementType(parser, type, trailingTypeLoc);
243   if (!elemTy)
244     return failure();
245 
246   if (parser.resolveOperand(value, elemTy, result.operands) ||
247       parser.resolveOperand(addr, type, result.operands))
248     return failure();
249 
250   return success();
251 }
252 
253 ///===---------------------------------------------------------------------===//
254 /// LLVM::InvokeOp
255 ///===---------------------------------------------------------------------===//
256 
257 Optional<OperandRange> InvokeOp::getSuccessorOperands(unsigned index) {
258   assert(index < getNumSuccessors() && "invalid successor index");
259   return index == 0 ? normalDestOperands() : unwindDestOperands();
260 }
261 
262 bool InvokeOp::canEraseSuccessorOperand() { return true; }
263 
264 static LogicalResult verify(InvokeOp op) {
265   if (op.getNumResults() > 1)
266     return op.emitOpError("must have 0 or 1 result");
267 
268   Block *unwindDest = op.unwindDest();
269   if (unwindDest->empty())
270     return op.emitError(
271         "must have at least one operation in unwind destination");
272 
273   // In unwind destination, first operation must be LandingpadOp
274   if (!isa<LandingpadOp>(unwindDest->front()))
275     return op.emitError("first operation in unwind destination should be a "
276                         "llvm.landingpad operation");
277 
278   return success();
279 }
280 
281 static void printInvokeOp(OpAsmPrinter &p, InvokeOp op) {
282   auto callee = op.callee();
283   bool isDirect = callee.hasValue();
284 
285   p << op.getOperationName() << ' ';
286 
287   // Either function name or pointer
288   if (isDirect)
289     p.printSymbolName(callee.getValue());
290   else
291     p << op.getOperand(0);
292 
293   p << '(' << op.getOperands().drop_front(isDirect ? 0 : 1) << ')';
294   p << " to ";
295   p.printSuccessorAndUseList(op.normalDest(), op.normalDestOperands());
296   p << " unwind ";
297   p.printSuccessorAndUseList(op.unwindDest(), op.unwindDestOperands());
298 
299   p.printOptionalAttrDict(op.getAttrs(),
300                           {InvokeOp::getOperandSegmentSizeAttr(), "callee"});
301   p << " : ";
302   p.printFunctionalType(
303       llvm::drop_begin(op.getOperandTypes(), isDirect ? 0 : 1),
304       op.getResultTypes());
305 }
306 
/// <operation> ::= `llvm.invoke` (function-id | ssa-use) `(` ssa-use-list `)`
///                  `to` bb-id (`[` ssa-use-and-type-list `]`)?
///                  `unwind` bb-id (`[` ssa-use-and-type-list `]`)?
///                  attribute-dict? `:` function-type
///
/// A direct invoke names its callee with a symbol, stored as the `callee`
/// attribute. An indirect invoke passes the callee as the first SSA operand.
/// That operand is resolved against a pointer to the LLVM function type built
/// from the trailing function type. Operands are added in the order
/// [callee?, call arguments..., normal dest operands..., unwind dest
/// operands...]. The operand segment sizes attribute records the size of each
/// of those three groups; the first group includes the indirect callee.
static ParseResult parseInvokeOp(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::OperandType, 8> operands;
  FunctionType funcType;
  SymbolRefAttr funcAttr;
  llvm::SMLoc trailingTypeLoc;
  Block *normalDest, *unwindDest;
  SmallVector<Value, 4> normalOperands, unwindOperands;
  Builder &builder = parser.getBuilder();

  // Parse an operand list that will, in practice, contain 0 or 1 operand.  In
  // case of an indirect call, there will be 1 operand before `(`.  In case of a
  // direct call, there will be no operands and the parser will stop at the
  // function identifier without complaining.
  if (parser.parseOperandList(operands))
    return failure();
  bool isDirect = operands.empty();

  // Optionally parse a function identifier.
  if (isDirect && parser.parseAttribute(funcAttr, "callee", result.attributes))
    return failure();

  // The parenthesized call arguments are appended to `operands`, after the
  // indirect callee if there is one.
  if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
      parser.parseKeyword("to") ||
      parser.parseSuccessorAndUseList(normalDest, normalOperands) ||
      parser.parseKeyword("unwind") ||
      parser.parseSuccessorAndUseList(unwindDest, unwindOperands) ||
      parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
      parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(funcType))
    return failure();

  if (isDirect) {
    // Make sure types match.
    if (parser.resolveOperands(operands, funcType.getInputs(),
                               parser.getNameLoc(), result.operands))
      return failure();
    result.addTypes(funcType.getResults());
  } else {
    // Construct the LLVM IR Dialect function type that the first operand
    // should match.
    if (funcType.getNumResults() > 1)
      return parser.emitError(trailingTypeLoc,
                              "expected function with 0 or 1 result");

    auto *llvmDialect =
        builder.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
    // No results maps to an LLVM `void` return type.
    LLVM::LLVMType llvmResultType;
    if (funcType.getNumResults() == 0) {
      llvmResultType = LLVM::LLVMType::getVoidTy(llvmDialect);
    } else {
      llvmResultType = funcType.getResult(0).dyn_cast<LLVM::LLVMType>();
      if (!llvmResultType)
        return parser.emitError(trailingTypeLoc,
                                "expected result to have LLVM type");
    }

    SmallVector<LLVM::LLVMType, 8> argTypes;
    argTypes.reserve(funcType.getNumInputs());
    for (Type ty : funcType.getInputs()) {
      if (auto argType = ty.dyn_cast<LLVM::LLVMType>())
        argTypes.push_back(argType);
      else
        return parser.emitError(trailingTypeLoc,
                                "expected LLVM types as inputs");
    }

    auto llvmFuncType = LLVM::LLVMType::getFunctionTy(llvmResultType, argTypes,
                                                      /*isVarArg=*/false);
    auto wrappedFuncType = llvmFuncType.getPointerTo();

    // Everything after the indirect callee is a call argument.
    auto funcArguments = llvm::makeArrayRef(operands).drop_front();

    // Make sure that the first operand (indirect callee) matches the wrapped
    // LLVM IR function type, and that the types of the other call operands
    // match the types of the function arguments.
    if (parser.resolveOperand(operands[0], wrappedFuncType, result.operands) ||
        parser.resolveOperands(funcArguments, funcType.getInputs(),
                               parser.getNameLoc(), result.operands))
      return failure();

    result.addTypes(llvmResultType);
  }
  // Successor operands are appended after the call operands, normal
  // destination first, then unwind destination.
  result.addSuccessors({normalDest, unwindDest});
  result.addOperands(normalOperands);
  result.addOperands(unwindOperands);

  // Segment sizes: call operands (including an indirect callee), normal
  // destination operands, unwind destination operands.
  result.addAttribute(
      InvokeOp::getOperandSegmentSizeAttr(),
      builder.getI32VectorAttr({static_cast<int32_t>(operands.size()),
                                static_cast<int32_t>(normalOperands.size()),
                                static_cast<int32_t>(unwindOperands.size())}));
  return success();
}
403 
404 ///===----------------------------------------------------------------------===//
405 /// Verifying/Printing/Parsing for LLVM::LandingpadOp.
406 ///===----------------------------------------------------------------------===//
407 
408 static LogicalResult verify(LandingpadOp op) {
409   Value value;
410   if (LLVMFuncOp func = op.getParentOfType<LLVMFuncOp>()) {
411     if (!func.personality().hasValue())
412       return op.emitError(
413           "llvm.landingpad needs to be in a function with a personality");
414   }
415 
416   if (!op.cleanup() && op.getOperands().empty())
417     return op.emitError("landingpad instruction expects at least one clause or "
418                         "cleanup attribute");
419 
420   for (unsigned idx = 0, ie = op.getNumOperands(); idx < ie; idx++) {
421     value = op.getOperand(idx);
422     bool isFilter = value.getType().cast<LLVMType>().isArrayTy();
423     if (isFilter) {
424       // FIXME: Verify filter clauses when arrays are appropriately handled
425     } else {
426       // catch - global addresses only.
427       // Bitcast ops should have global addresses as their args.
428       if (auto bcOp = dyn_cast_or_null<BitcastOp>(value.getDefiningOp())) {
429         if (auto addrOp =
430                 dyn_cast_or_null<AddressOfOp>(bcOp.arg().getDefiningOp()))
431           continue;
432         return op.emitError("constant clauses expected")
433                    .attachNote(bcOp.getLoc())
434                << "global addresses expected as operand to "
435                   "bitcast used in clauses for landingpad";
436       }
437       // NullOp and AddressOfOp allowed
438       if (dyn_cast_or_null<NullOp>(value.getDefiningOp()))
439         continue;
440       if (dyn_cast_or_null<AddressOfOp>(value.getDefiningOp()))
441         continue;
442       return op.emitError("clause #")
443              << idx << " is not a known constant - null, addressof, bitcast";
444     }
445   }
446   return success();
447 }
448 
449 static void printLandingpadOp(OpAsmPrinter &p, LandingpadOp &op) {
450   p << op.getOperationName() << (op.cleanup() ? " cleanup " : " ");
451 
452   // Clauses
453   for (auto value : op.getOperands()) {
454     // Similar to llvm - if clause is an array type then it is filter
455     // clause else catch clause
456     bool isArrayTy = value.getType().cast<LLVMType>().isArrayTy();
457     p << '(' << (isArrayTy ? "filter " : "catch ") << value << " : "
458       << value.getType() << ") ";
459   }
460 
461   p.printOptionalAttrDict(op.getAttrs(), {"cleanup"});
462 
463   p << ": " << op.getType();
464 }
465 
466 /// <operation> ::= `llvm.landingpad` `cleanup`?
467 ///                 ((`catch` | `filter`) operand-type ssa-use)* attribute-dict?
468 static ParseResult parseLandingpadOp(OpAsmParser &parser,
469                                      OperationState &result) {
470   // Check for cleanup
471   if (succeeded(parser.parseOptionalKeyword("cleanup")))
472     result.addAttribute("cleanup", parser.getBuilder().getUnitAttr());
473 
474   // Parse clauses with types
475   while (succeeded(parser.parseOptionalLParen()) &&
476          (succeeded(parser.parseOptionalKeyword("filter")) ||
477           succeeded(parser.parseOptionalKeyword("catch")))) {
478     OpAsmParser::OperandType operand;
479     Type ty;
480     if (parser.parseOperand(operand) || parser.parseColon() ||
481         parser.parseType(ty) ||
482         parser.resolveOperand(operand, ty, result.operands) ||
483         parser.parseRParen())
484       return failure();
485   }
486 
487   Type type;
488   if (parser.parseColon() || parser.parseType(type))
489     return failure();
490 
491   result.addTypes(type);
492   return success();
493 }
494 
495 //===----------------------------------------------------------------------===//
496 // Printing/parsing for LLVM::CallOp.
497 //===----------------------------------------------------------------------===//
498 
499 static void printCallOp(OpAsmPrinter &p, CallOp &op) {
500   auto callee = op.callee();
501   bool isDirect = callee.hasValue();
502 
503   // Print the direct callee if present as a function attribute, or an indirect
504   // callee (first operand) otherwise.
505   p << op.getOperationName() << ' ';
506   if (isDirect)
507     p.printSymbolName(callee.getValue());
508   else
509     p << op.getOperand(0);
510 
511   p << '(' << op.getOperands().drop_front(isDirect ? 0 : 1) << ')';
512   p.printOptionalAttrDict(op.getAttrs(), {"callee"});
513 
514   // Reconstruct the function MLIR function type from operand and result types.
515   SmallVector<Type, 8> argTypes(
516       llvm::drop_begin(op.getOperandTypes(), isDirect ? 0 : 1));
517 
518   p << " : "
519     << FunctionType::get(argTypes, op.getResultTypes(), op.getContext());
520 }
521 
// <operation> ::= `llvm.call` (function-id | ssa-use) `(` ssa-use-list `)`
//                 attribute-dict? `:` function-type
//
// A direct call names its callee with a symbol, stored as the `callee`
// attribute, and takes its results from the trailing function type. An
// indirect call passes the callee as the first SSA operand. That operand is
// resolved against a pointer to the LLVM function type built from the trailing
// function type. Operands are added in the order [callee?, call arguments...].
static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::OperandType, 8> operands;
  Type type;
  SymbolRefAttr funcAttr;
  llvm::SMLoc trailingTypeLoc;

  // Parse an operand list that will, in practice, contain 0 or 1 operand.  In
  // case of an indirect call, there will be 1 operand before `(`.  In case of a
  // direct call, there will be no operands and the parser will stop at the
  // function identifier without complaining.
  if (parser.parseOperandList(operands))
    return failure();
  bool isDirect = operands.empty();

  // Optionally parse a function identifier.
  if (isDirect)
    if (parser.parseAttribute(funcAttr, "callee", result.attributes))
      return failure();

  // The parenthesized call arguments are appended to `operands`, after the
  // indirect callee if there is one.
  if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
      parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
      parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type))
    return failure();

  auto funcType = type.dyn_cast<FunctionType>();
  if (!funcType)
    return parser.emitError(trailingTypeLoc, "expected function type");
  if (isDirect) {
    // Make sure types match.
    if (parser.resolveOperands(operands, funcType.getInputs(),
                               parser.getNameLoc(), result.operands))
      return failure();
    result.addTypes(funcType.getResults());
  } else {
    // Construct the LLVM IR Dialect function type that the first operand
    // should match.
    if (funcType.getNumResults() > 1)
      return parser.emitError(trailingTypeLoc,
                              "expected function with 0 or 1 result");

    Builder &builder = parser.getBuilder();
    auto *llvmDialect =
        builder.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
    // No results maps to an LLVM `void` return type.
    LLVM::LLVMType llvmResultType;
    if (funcType.getNumResults() == 0) {
      llvmResultType = LLVM::LLVMType::getVoidTy(llvmDialect);
    } else {
      llvmResultType = funcType.getResult(0).dyn_cast<LLVM::LLVMType>();
      if (!llvmResultType)
        return parser.emitError(trailingTypeLoc,
                                "expected result to have LLVM type");
    }

    SmallVector<LLVM::LLVMType, 8> argTypes;
    argTypes.reserve(funcType.getNumInputs());
    for (int i = 0, e = funcType.getNumInputs(); i < e; ++i) {
      auto argType = funcType.getInput(i).dyn_cast<LLVM::LLVMType>();
      if (!argType)
        return parser.emitError(trailingTypeLoc,
                                "expected LLVM types as inputs");
      argTypes.push_back(argType);
    }
    auto llvmFuncType = LLVM::LLVMType::getFunctionTy(llvmResultType, argTypes,
                                                      /*isVarArg=*/false);
    auto wrappedFuncType = llvmFuncType.getPointerTo();

    // Everything after the indirect callee is a call argument.
    auto funcArguments =
        ArrayRef<OpAsmParser::OperandType>(operands).drop_front();

    // Make sure that the first operand (indirect callee) matches the wrapped
    // LLVM IR function type, and that the types of the other call operands
    // match the types of the function arguments.
    if (parser.resolveOperand(operands[0], wrappedFuncType, result.operands) ||
        parser.resolveOperands(funcArguments, funcType.getInputs(),
                               parser.getNameLoc(), result.operands))
      return failure();

    result.addTypes(llvmResultType);
  }

  return success();
}
606 
607 //===----------------------------------------------------------------------===//
608 // Printing/parsing for LLVM::ExtractElementOp.
609 //===----------------------------------------------------------------------===//
610 // Expects vector to be of wrapped LLVM vector type and position to be of
611 // wrapped LLVM i32 type.
612 void LLVM::ExtractElementOp::build(Builder *b, OperationState &result,
613                                    Value vector, Value position,
614                                    ArrayRef<NamedAttribute> attrs) {
615   auto wrappedVectorType = vector.getType().cast<LLVM::LLVMType>();
616   auto llvmType = wrappedVectorType.getVectorElementType();
617   build(b, result, llvmType, vector, position);
618   result.addAttributes(attrs);
619 }
620 
621 static void printExtractElementOp(OpAsmPrinter &p, ExtractElementOp &op) {
622   p << op.getOperationName() << ' ' << op.vector() << "[" << op.position()
623     << " : " << op.position().getType() << "]";
624   p.printOptionalAttrDict(op.getAttrs());
625   p << " : " << op.vector().getType();
626 }
627 
628 // <operation> ::= `llvm.extractelement` ssa-use `, ` ssa-use
629 //                 attribute-dict? `:` type
630 static ParseResult parseExtractElementOp(OpAsmParser &parser,
631                                          OperationState &result) {
632   llvm::SMLoc loc;
633   OpAsmParser::OperandType vector, position;
634   Type type, positionType;
635   if (parser.getCurrentLocation(&loc) || parser.parseOperand(vector) ||
636       parser.parseLSquare() || parser.parseOperand(position) ||
637       parser.parseColonType(positionType) || parser.parseRSquare() ||
638       parser.parseOptionalAttrDict(result.attributes) ||
639       parser.parseColonType(type) ||
640       parser.resolveOperand(vector, type, result.operands) ||
641       parser.resolveOperand(position, positionType, result.operands))
642     return failure();
643   auto wrappedVectorType = type.dyn_cast<LLVM::LLVMType>();
644   if (!wrappedVectorType ||
645       !wrappedVectorType.getUnderlyingType()->isVectorTy())
646     return parser.emitError(
647         loc, "expected LLVM IR dialect vector type for operand #1");
648   result.addTypes(wrappedVectorType.getVectorElementType());
649   return success();
650 }
651 
652 //===----------------------------------------------------------------------===//
653 // Printing/parsing for LLVM::ExtractValueOp.
654 //===----------------------------------------------------------------------===//
655 
656 static void printExtractValueOp(OpAsmPrinter &p, ExtractValueOp &op) {
657   p << op.getOperationName() << ' ' << op.container() << op.position();
658   p.printOptionalAttrDict(op.getAttrs(), {"position"});
659   p << " : " << op.container().getType();
660 }
661 
662 // Extract the type at `position` in the wrapped LLVM IR aggregate type
663 // `containerType`.  Position is an integer array attribute where each value
664 // is a zero-based position of the element in the aggregate type.  Return the
665 // resulting type wrapped in MLIR, or nullptr on error.
666 static LLVM::LLVMType getInsertExtractValueElementType(OpAsmParser &parser,
667                                                        Type containerType,
668                                                        ArrayAttr positionAttr,
669                                                        llvm::SMLoc attributeLoc,
670                                                        llvm::SMLoc typeLoc) {
671   auto wrappedContainerType = containerType.dyn_cast<LLVM::LLVMType>();
672   if (!wrappedContainerType)
673     return parser.emitError(typeLoc, "expected LLVM IR Dialect type"), nullptr;
674 
675   // Infer the element type from the structure type: iteratively step inside the
676   // type by taking the element type, indexed by the position attribute for
677   // structures.  Check the position index before accessing, it is supposed to
678   // be in bounds.
679   for (Attribute subAttr : positionAttr) {
680     auto positionElementAttr = subAttr.dyn_cast<IntegerAttr>();
681     if (!positionElementAttr)
682       return parser.emitError(attributeLoc,
683                               "expected an array of integer literals"),
684              nullptr;
685     int position = positionElementAttr.getInt();
686     auto *llvmContainerType = wrappedContainerType.getUnderlyingType();
687     if (llvmContainerType->isArrayTy()) {
688       if (position < 0 || static_cast<unsigned>(position) >=
689                               llvmContainerType->getArrayNumElements())
690         return parser.emitError(attributeLoc, "position out of bounds"),
691                nullptr;
692       wrappedContainerType = wrappedContainerType.getArrayElementType();
693     } else if (llvmContainerType->isStructTy()) {
694       if (position < 0 || static_cast<unsigned>(position) >=
695                               llvmContainerType->getStructNumElements())
696         return parser.emitError(attributeLoc, "position out of bounds"),
697                nullptr;
698       wrappedContainerType =
699           wrappedContainerType.getStructElementType(position);
700     } else {
701       return parser.emitError(typeLoc,
702                               "expected wrapped LLVM IR structure/array type"),
703              nullptr;
704     }
705   }
706   return wrappedContainerType;
707 }
708 
709 // <operation> ::= `llvm.extractvalue` ssa-use
710 //                 `[` integer-literal (`,` integer-literal)* `]`
711 //                 attribute-dict? `:` type
712 static ParseResult parseExtractValueOp(OpAsmParser &parser,
713                                        OperationState &result) {
714   OpAsmParser::OperandType container;
715   Type containerType;
716   ArrayAttr positionAttr;
717   llvm::SMLoc attributeLoc, trailingTypeLoc;
718 
719   if (parser.parseOperand(container) ||
720       parser.getCurrentLocation(&attributeLoc) ||
721       parser.parseAttribute(positionAttr, "position", result.attributes) ||
722       parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
723       parser.getCurrentLocation(&trailingTypeLoc) ||
724       parser.parseType(containerType) ||
725       parser.resolveOperand(container, containerType, result.operands))
726     return failure();
727 
728   auto elementType = getInsertExtractValueElementType(
729       parser, containerType, positionAttr, attributeLoc, trailingTypeLoc);
730   if (!elementType)
731     return failure();
732 
733   result.addTypes(elementType);
734   return success();
735 }
736 
737 //===----------------------------------------------------------------------===//
738 // Printing/parsing for LLVM::InsertElementOp.
739 //===----------------------------------------------------------------------===//
740 
741 static void printInsertElementOp(OpAsmPrinter &p, InsertElementOp &op) {
742   p << op.getOperationName() << ' ' << op.value() << ", " << op.vector() << "["
743     << op.position() << " : " << op.position().getType() << "]";
744   p.printOptionalAttrDict(op.getAttrs());
745   p << " : " << op.vector().getType();
746 }
747 
748 // <operation> ::= `llvm.insertelement` ssa-use `,` ssa-use `,` ssa-use
749 //                 attribute-dict? `:` type
750 static ParseResult parseInsertElementOp(OpAsmParser &parser,
751                                         OperationState &result) {
752   llvm::SMLoc loc;
753   OpAsmParser::OperandType vector, value, position;
754   Type vectorType, positionType;
755   if (parser.getCurrentLocation(&loc) || parser.parseOperand(value) ||
756       parser.parseComma() || parser.parseOperand(vector) ||
757       parser.parseLSquare() || parser.parseOperand(position) ||
758       parser.parseColonType(positionType) || parser.parseRSquare() ||
759       parser.parseOptionalAttrDict(result.attributes) ||
760       parser.parseColonType(vectorType))
761     return failure();
762 
763   auto wrappedVectorType = vectorType.dyn_cast<LLVM::LLVMType>();
764   if (!wrappedVectorType ||
765       !wrappedVectorType.getUnderlyingType()->isVectorTy())
766     return parser.emitError(
767         loc, "expected LLVM IR dialect vector type for operand #1");
768   auto valueType = wrappedVectorType.getVectorElementType();
769   if (!valueType)
770     return failure();
771 
772   if (parser.resolveOperand(vector, vectorType, result.operands) ||
773       parser.resolveOperand(value, valueType, result.operands) ||
774       parser.resolveOperand(position, positionType, result.operands))
775     return failure();
776 
777   result.addTypes(vectorType);
778   return success();
779 }
780 
781 //===----------------------------------------------------------------------===//
782 // Printing/parsing for LLVM::InsertValueOp.
783 //===----------------------------------------------------------------------===//
784 
785 static void printInsertValueOp(OpAsmPrinter &p, InsertValueOp &op) {
786   p << op.getOperationName() << ' ' << op.value() << ", " << op.container()
787     << op.position();
788   p.printOptionalAttrDict(op.getAttrs(), {"position"});
789   p << " : " << op.container().getType();
790 }
791 
792 // <operation> ::= `llvm.insertvaluevalue` ssa-use `,` ssa-use
793 //                 `[` integer-literal (`,` integer-literal)* `]`
794 //                 attribute-dict? `:` type
795 static ParseResult parseInsertValueOp(OpAsmParser &parser,
796                                       OperationState &result) {
797   OpAsmParser::OperandType container, value;
798   Type containerType;
799   ArrayAttr positionAttr;
800   llvm::SMLoc attributeLoc, trailingTypeLoc;
801 
802   if (parser.parseOperand(value) || parser.parseComma() ||
803       parser.parseOperand(container) ||
804       parser.getCurrentLocation(&attributeLoc) ||
805       parser.parseAttribute(positionAttr, "position", result.attributes) ||
806       parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
807       parser.getCurrentLocation(&trailingTypeLoc) ||
808       parser.parseType(containerType))
809     return failure();
810 
811   auto valueType = getInsertExtractValueElementType(
812       parser, containerType, positionAttr, attributeLoc, trailingTypeLoc);
813   if (!valueType)
814     return failure();
815 
816   if (parser.resolveOperand(container, containerType, result.operands) ||
817       parser.resolveOperand(value, valueType, result.operands))
818     return failure();
819 
820   result.addTypes(containerType);
821   return success();
822 }
823 
824 //===----------------------------------------------------------------------===//
825 // Printing/parsing for LLVM::ReturnOp.
826 //===----------------------------------------------------------------------===//
827 
828 static void printReturnOp(OpAsmPrinter &p, ReturnOp &op) {
829   p << op.getOperationName();
830   p.printOptionalAttrDict(op.getAttrs());
831   assert(op.getNumOperands() <= 1);
832 
833   if (op.getNumOperands() == 0)
834     return;
835 
836   p << ' ' << op.getOperand(0) << " : " << op.getOperand(0).getType();
837 }
838 
839 // <operation> ::= `llvm.return` ssa-use-list attribute-dict? `:`
840 //                 type-list-no-parens
841 static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &result) {
842   SmallVector<OpAsmParser::OperandType, 1> operands;
843   Type type;
844 
845   if (parser.parseOperandList(operands) ||
846       parser.parseOptionalAttrDict(result.attributes))
847     return failure();
848   if (operands.empty())
849     return success();
850 
851   if (parser.parseColonType(type) ||
852       parser.resolveOperand(operands[0], type, result.operands))
853     return failure();
854   return success();
855 }
856 
857 //===----------------------------------------------------------------------===//
// Accessor and verifier for LLVM::AddressOfOp.
859 //===----------------------------------------------------------------------===//
860 
861 GlobalOp AddressOfOp::getGlobal() {
862   Operation *module = getParentOp();
863   while (module && !satisfiesLLVMModule(module))
864     module = module->getParentOp();
865   assert(module && "unexpected operation outside of a module");
866   return dyn_cast_or_null<LLVM::GlobalOp>(
867       mlir::SymbolTable::lookupSymbolIn(module, global_name()));
868 }
869 
870 static LogicalResult verify(AddressOfOp op) {
871   auto global = op.getGlobal();
872   if (!global)
873     return op.emitOpError(
874         "must reference a global defined by 'llvm.mlir.global'");
875 
876   if (global.getType().getPointerTo(global.addr_space().getZExtValue()) !=
877       op.getResult().getType())
878     return op.emitOpError(
879         "the type must be a pointer to the type of the referred global");
880 
881   return success();
882 }
883 
884 //===----------------------------------------------------------------------===//
// Builder, printer, parser and verifier for LLVM::GlobalOp.
886 //===----------------------------------------------------------------------===//
887 
888 /// Returns the name used for the linkage attribute. This *must* correspond to
889 /// the name of the attribute in ODS.
890 static StringRef getLinkageAttrName() { return "linkage"; }
891 
892 void GlobalOp::build(Builder *builder, OperationState &result, LLVMType type,
893                      bool isConstant, Linkage linkage, StringRef name,
894                      Attribute value, unsigned addrSpace,
895                      ArrayRef<NamedAttribute> attrs) {
896   result.addAttribute(SymbolTable::getSymbolAttrName(),
897                       builder->getStringAttr(name));
898   result.addAttribute("type", TypeAttr::get(type));
899   if (isConstant)
900     result.addAttribute("constant", builder->getUnitAttr());
901   if (value)
902     result.addAttribute("value", value);
903   result.addAttribute(getLinkageAttrName(), builder->getI64IntegerAttr(
904                                                 static_cast<int64_t>(linkage)));
905   if (addrSpace != 0)
906     result.addAttribute("addr_space", builder->getI32IntegerAttr(addrSpace));
907   result.attributes.append(attrs.begin(), attrs.end());
908   result.addRegion();
909 }
910 
911 static void printGlobalOp(OpAsmPrinter &p, GlobalOp op) {
912   p << op.getOperationName() << ' ' << stringifyLinkage(op.linkage()) << ' ';
913   if (op.constant())
914     p << "constant ";
915   p.printSymbolName(op.sym_name());
916   p << '(';
917   if (auto value = op.getValueOrNull())
918     p.printAttribute(value);
919   p << ')';
920   p.printOptionalAttrDict(op.getAttrs(),
921                           {SymbolTable::getSymbolAttrName(), "type", "constant",
922                            "value", getLinkageAttrName()});
923 
924   // Print the trailing type unless it's a string global.
925   if (op.getValueOrNull().dyn_cast_or_null<StringAttr>())
926     return;
927   p << " : " << op.type();
928 
929   Region &initializer = op.getInitializerRegion();
930   if (!initializer.empty())
931     p.printRegion(initializer, /*printEntryBlockArgs=*/false);
932 }
933 
934 //===----------------------------------------------------------------------===//
935 // Verifier for LLVM::DialectCastOp.
936 //===----------------------------------------------------------------------===//
937 
938 static LogicalResult verify(DialectCastOp op) {
939   auto verifyMLIRCastType = [&op](Type type) -> LogicalResult {
940     if (auto llvmType = type.dyn_cast<LLVM::LLVMType>()) {
941       if (llvmType.isVectorTy())
942         llvmType = llvmType.getVectorElementType();
943       if (llvmType.isIntegerTy() || llvmType.isHalfTy() ||
944           llvmType.isFloatTy() || llvmType.isDoubleTy()) {
945         return success();
946       }
947       return op.emitOpError("type must be non-index integer types, float "
948                             "types, or vector of mentioned types.");
949     }
950     if (auto vectorType = type.dyn_cast<VectorType>()) {
951       if (vectorType.getShape().size() > 1)
952         return op.emitOpError("only 1-d vector is allowed");
953       type = vectorType.getElementType();
954     }
955     if (type.isSignlessIntOrFloat())
956       return success();
957     // Note that memrefs are not supported. We currently don't have a use case
958     // for it, but even if we do, there are challenges:
959     // * if we allow memrefs to cast from/to memref descriptors, then the
960     // semantics of the cast op depends on the implementation detail of the
961     // descriptor.
962     // * if we allow memrefs to cast from/to bare pointers, some users might
963     // alternatively want metadata that only present in the descriptor.
964     //
965     // TODO(timshen): re-evaluate the memref cast design when it's needed.
966     return op.emitOpError("type must be non-index integer types, float types, "
967                           "or vector of mentioned types.");
968   };
969   return failure(failed(verifyMLIRCastType(op.in().getType())) ||
970                  failed(verifyMLIRCastType(op.getType())));
971 }
972 
973 // Parses one of the keywords provided in the list `keywords` and returns the
974 // position of the parsed keyword in the list. If none of the keywords from the
975 // list is parsed, returns -1.
976 static int parseOptionalKeywordAlternative(OpAsmParser &parser,
977                                            ArrayRef<StringRef> keywords) {
978   for (auto en : llvm::enumerate(keywords)) {
979     if (succeeded(parser.parseOptionalKeyword(en.value())))
980       return en.index();
981   }
982   return -1;
983 }
984 
namespace {
/// Adapter that gives generic code uniform access to an enum's helper
/// functions: `stringify` converts a value to its keyword spelling, and
/// `getMaxEnumVal` returns the largest enumerator value. The primary template
/// is empty, so only enums specialized via REGISTER_ENUM_TYPE can be used.
template <typename Ty> struct EnumTraits {};

// Specializes EnumTraits for `Ty` by forwarding to the free functions
// `stringify<Ty>` and `getMaxEnumValFor<Ty>`, whose names are formed by token
// pasting. Those functions must already exist for `Ty` (presumably generated
// from ODS alongside the enum, e.g. in LLVMOpsEnums.cpp.inc — verify).
#define REGISTER_ENUM_TYPE(Ty)                                                 \
  template <> struct EnumTraits<Ty> {                                          \
    static StringRef stringify(Ty value) { return stringify##Ty(value); }      \
    static unsigned getMaxEnumVal() { return getMaxEnumValFor##Ty(); }         \
  }

// Linkage is currently the only registered enum.
REGISTER_ENUM_TYPE(Linkage);
} // end namespace
996 
997 template <typename EnumTy>
998 static ParseResult parseOptionalLLVMKeyword(OpAsmParser &parser,
999                                             OperationState &result,
1000                                             StringRef name) {
1001   SmallVector<StringRef, 10> names;
1002   for (unsigned i = 0, e = getMaxEnumValForLinkage(); i <= e; ++i)
1003     names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i)));
1004 
1005   int index = parseOptionalKeywordAlternative(parser, names);
1006   if (index == -1)
1007     return failure();
1008   result.addAttribute(name, parser.getBuilder().getI64IntegerAttr(index));
1009   return success();
1010 }
1011 
1012 // operation ::= `llvm.mlir.global` linkage `constant`? `@` identifier
1013 //               `(` attribute? `)` attribute-list? (`:` type)? region?
1014 //
1015 // The type can be omitted for string attributes, in which case it will be
1016 // inferred from the value of the string as [strlen(value) x i8].
1017 static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) {
1018   if (failed(parseOptionalLLVMKeyword<Linkage>(parser, result,
1019                                                getLinkageAttrName())))
1020     return parser.emitError(parser.getCurrentLocation(), "expected linkage");
1021 
1022   if (succeeded(parser.parseOptionalKeyword("constant")))
1023     result.addAttribute("constant", parser.getBuilder().getUnitAttr());
1024 
1025   StringAttr name;
1026   if (parser.parseSymbolName(name, SymbolTable::getSymbolAttrName(),
1027                              result.attributes) ||
1028       parser.parseLParen())
1029     return failure();
1030 
1031   Attribute value;
1032   if (parser.parseOptionalRParen()) {
1033     if (parser.parseAttribute(value, "value", result.attributes) ||
1034         parser.parseRParen())
1035       return failure();
1036   }
1037 
1038   SmallVector<Type, 1> types;
1039   if (parser.parseOptionalAttrDict(result.attributes) ||
1040       parser.parseOptionalColonTypeList(types))
1041     return failure();
1042 
1043   if (types.size() > 1)
1044     return parser.emitError(parser.getNameLoc(), "expected zero or one type");
1045 
1046   Region &initRegion = *result.addRegion();
1047   if (types.empty()) {
1048     if (auto strAttr = value.dyn_cast_or_null<StringAttr>()) {
1049       MLIRContext *context = parser.getBuilder().getContext();
1050       auto *dialect = context->getRegisteredDialect<LLVMDialect>();
1051       auto arrayType = LLVM::LLVMType::getArrayTy(
1052           LLVM::LLVMType::getInt8Ty(dialect), strAttr.getValue().size());
1053       types.push_back(arrayType);
1054     } else {
1055       return parser.emitError(parser.getNameLoc(),
1056                               "type can only be omitted for string globals");
1057     }
1058   } else if (parser.parseOptionalRegion(initRegion, /*arguments=*/{},
1059                                         /*argTypes=*/{})) {
1060     return failure();
1061   }
1062 
1063   result.addAttribute("type", TypeAttr::get(types[0]));
1064   return success();
1065 }
1066 
1067 static LogicalResult verify(GlobalOp op) {
1068   if (!llvm::PointerType::isValidElementType(op.getType().getUnderlyingType()))
1069     return op.emitOpError(
1070         "expects type to be a valid element type for an LLVM pointer");
1071   if (op.getParentOp() && !satisfiesLLVMModule(op.getParentOp()))
1072     return op.emitOpError("must appear at the module level");
1073 
1074   if (auto strAttr = op.getValueOrNull().dyn_cast_or_null<StringAttr>()) {
1075     auto type = op.getType();
1076     if (!type.getUnderlyingType()->isArrayTy() ||
1077         !type.getArrayElementType().getUnderlyingType()->isIntegerTy(8) ||
1078         type.getArrayNumElements() != strAttr.getValue().size())
1079       return op.emitOpError(
1080           "requires an i8 array type of the length equal to that of the string "
1081           "attribute");
1082   }
1083 
1084   if (Block *b = op.getInitializerBlock()) {
1085     ReturnOp ret = cast<ReturnOp>(b->getTerminator());
1086     if (ret.operand_type_begin() == ret.operand_type_end())
1087       return op.emitOpError("initializer region cannot return void");
1088     if (*ret.operand_type_begin() != op.getType())
1089       return op.emitOpError("initializer region type ")
1090              << *ret.operand_type_begin() << " does not match global type "
1091              << op.getType();
1092 
1093     if (op.getValueOrNull())
1094       return op.emitOpError("cannot have both initializer value and region");
1095   }
1096   return success();
1097 }
1098 
1099 //===----------------------------------------------------------------------===//
// Builder, printing and parsing for LLVM::ShuffleVectorOp.
1101 //===----------------------------------------------------------------------===//
1102 // Expects vector to be of wrapped LLVM vector type and position to be of
1103 // wrapped LLVM i32 type.
1104 void LLVM::ShuffleVectorOp::build(Builder *b, OperationState &result, Value v1,
1105                                   Value v2, ArrayAttr mask,
1106                                   ArrayRef<NamedAttribute> attrs) {
1107   auto wrappedContainerType1 = v1.getType().cast<LLVM::LLVMType>();
1108   auto vType = LLVMType::getVectorTy(
1109       wrappedContainerType1.getVectorElementType(), mask.size());
1110   build(b, result, vType, v1, v2, mask);
1111   result.addAttributes(attrs);
1112 }
1113 
1114 static void printShuffleVectorOp(OpAsmPrinter &p, ShuffleVectorOp &op) {
1115   p << op.getOperationName() << ' ' << op.v1() << ", " << op.v2() << " "
1116     << op.mask();
1117   p.printOptionalAttrDict(op.getAttrs(), {"mask"});
1118   p << " : " << op.v1().getType() << ", " << op.v2().getType();
1119 }
1120 
1121 // <operation> ::= `llvm.shufflevector` ssa-use `, ` ssa-use
1122 //                 `[` integer-literal (`,` integer-literal)* `]`
1123 //                 attribute-dict? `:` type
1124 static ParseResult parseShuffleVectorOp(OpAsmParser &parser,
1125                                         OperationState &result) {
1126   llvm::SMLoc loc;
1127   OpAsmParser::OperandType v1, v2;
1128   ArrayAttr maskAttr;
1129   Type typeV1, typeV2;
1130   if (parser.getCurrentLocation(&loc) || parser.parseOperand(v1) ||
1131       parser.parseComma() || parser.parseOperand(v2) ||
1132       parser.parseAttribute(maskAttr, "mask", result.attributes) ||
1133       parser.parseOptionalAttrDict(result.attributes) ||
1134       parser.parseColonType(typeV1) || parser.parseComma() ||
1135       parser.parseType(typeV2) ||
1136       parser.resolveOperand(v1, typeV1, result.operands) ||
1137       parser.resolveOperand(v2, typeV2, result.operands))
1138     return failure();
1139   auto wrappedContainerType1 = typeV1.dyn_cast<LLVM::LLVMType>();
1140   if (!wrappedContainerType1 ||
1141       !wrappedContainerType1.getUnderlyingType()->isVectorTy())
1142     return parser.emitError(
1143         loc, "expected LLVM IR dialect vector type for operand #1");
1144   auto vType = LLVMType::getVectorTy(
1145       wrappedContainerType1.getVectorElementType(), maskAttr.size());
1146   result.addTypes(vType);
1147   return success();
1148 }
1149 
1150 //===----------------------------------------------------------------------===//
1151 // Implementations for LLVM::LLVMFuncOp.
1152 //===----------------------------------------------------------------------===//
1153 
1154 // Add the entry block to the function.
1155 Block *LLVMFuncOp::addEntryBlock() {
1156   assert(empty() && "function already has an entry block");
1157   assert(!isVarArg() && "unimplemented: non-external variadic functions");
1158 
1159   auto *entry = new Block;
1160   push_back(entry);
1161 
1162   LLVMType type = getType();
1163   for (unsigned i = 0, e = type.getFunctionNumParams(); i < e; ++i)
1164     entry->addArgument(type.getFunctionParamType(i));
1165   return entry;
1166 }
1167 
1168 void LLVMFuncOp::build(Builder *builder, OperationState &result, StringRef name,
1169                        LLVMType type, LLVM::Linkage linkage,
1170                        ArrayRef<NamedAttribute> attrs,
1171                        ArrayRef<NamedAttributeList> argAttrs) {
1172   result.addRegion();
1173   result.addAttribute(SymbolTable::getSymbolAttrName(),
1174                       builder->getStringAttr(name));
1175   result.addAttribute("type", TypeAttr::get(type));
1176   result.addAttribute(getLinkageAttrName(), builder->getI64IntegerAttr(
1177                                                 static_cast<int64_t>(linkage)));
1178   result.attributes.append(attrs.begin(), attrs.end());
1179   if (argAttrs.empty())
1180     return;
1181 
1182   unsigned numInputs = type.getUnderlyingType()->getFunctionNumParams();
1183   assert(numInputs == argAttrs.size() &&
1184          "expected as many argument attribute lists as arguments");
1185   SmallString<8> argAttrName;
1186   for (unsigned i = 0; i < numInputs; ++i)
1187     if (auto argDict = argAttrs[i].getDictionary())
1188       result.addAttribute(getArgAttrName(i, argAttrName), argDict);
1189 }
1190 
1191 // Builds an LLVM function type from the given lists of input and output types.
1192 // Returns a null type if any of the types provided are non-LLVM types, or if
1193 // there is more than one output type.
1194 static Type buildLLVMFunctionType(OpAsmParser &parser, llvm::SMLoc loc,
1195                                   ArrayRef<Type> inputs, ArrayRef<Type> outputs,
1196                                   impl::VariadicFlag variadicFlag) {
1197   Builder &b = parser.getBuilder();
1198   if (outputs.size() > 1) {
1199     parser.emitError(loc, "failed to construct function type: expected zero or "
1200                           "one function result");
1201     return {};
1202   }
1203 
1204   // Convert inputs to LLVM types, exit early on error.
1205   SmallVector<LLVMType, 4> llvmInputs;
1206   for (auto t : inputs) {
1207     auto llvmTy = t.dyn_cast<LLVMType>();
1208     if (!llvmTy) {
1209       parser.emitError(loc, "failed to construct function type: expected LLVM "
1210                             "type for function arguments");
1211       return {};
1212     }
1213     llvmInputs.push_back(llvmTy);
1214   }
1215 
1216   // Get the dialect from the input type, if any exist.  Look it up in the
1217   // context otherwise.
1218   LLVMDialect *dialect =
1219       llvmInputs.empty() ? b.getContext()->getRegisteredDialect<LLVMDialect>()
1220                          : &llvmInputs.front().getDialect();
1221 
1222   // No output is denoted as "void" in LLVM type system.
1223   LLVMType llvmOutput = outputs.empty() ? LLVMType::getVoidTy(dialect)
1224                                         : outputs.front().dyn_cast<LLVMType>();
1225   if (!llvmOutput) {
1226     parser.emitError(loc, "failed to construct function type: expected LLVM "
1227                           "type for function results");
1228     return {};
1229   }
1230   return LLVMType::getFunctionTy(llvmOutput, llvmInputs,
1231                                  variadicFlag.isVariadic());
1232 }
1233 
1234 // Parses an LLVM function.
1235 //
1236 // operation ::= `llvm.func` linkage? function-signature function-attributes?
1237 //               function-body
1238 //
1239 static ParseResult parseLLVMFuncOp(OpAsmParser &parser,
1240                                    OperationState &result) {
1241   // Default to external linkage if no keyword is provided.
1242   if (failed(parseOptionalLLVMKeyword<Linkage>(parser, result,
1243                                                getLinkageAttrName())))
1244     result.addAttribute(getLinkageAttrName(),
1245                         parser.getBuilder().getI64IntegerAttr(
1246                             static_cast<int64_t>(LLVM::Linkage::External)));
1247 
1248   StringAttr nameAttr;
1249   SmallVector<OpAsmParser::OperandType, 8> entryArgs;
1250   SmallVector<SmallVector<NamedAttribute, 2>, 1> argAttrs;
1251   SmallVector<SmallVector<NamedAttribute, 2>, 1> resultAttrs;
1252   SmallVector<Type, 8> argTypes;
1253   SmallVector<Type, 4> resultTypes;
1254   bool isVariadic;
1255 
1256   auto signatureLocation = parser.getCurrentLocation();
1257   if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1258                              result.attributes) ||
1259       impl::parseFunctionSignature(parser, /*allowVariadic=*/true, entryArgs,
1260                                    argTypes, argAttrs, isVariadic, resultTypes,
1261                                    resultAttrs))
1262     return failure();
1263 
1264   auto type =
1265       buildLLVMFunctionType(parser, signatureLocation, argTypes, resultTypes,
1266                             impl::VariadicFlag(isVariadic));
1267   if (!type)
1268     return failure();
1269   result.addAttribute(impl::getTypeAttrName(), TypeAttr::get(type));
1270 
1271   if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
1272     return failure();
1273   impl::addArgAndResultAttrs(parser.getBuilder(), result, argAttrs,
1274                              resultAttrs);
1275 
1276   auto *body = result.addRegion();
1277   return parser.parseOptionalRegion(
1278       *body, entryArgs, entryArgs.empty() ? ArrayRef<Type>() : argTypes);
1279 }
1280 
1281 // Print the LLVMFuncOp. Collects argument and result types and passes them to
1282 // helper functions. Drops "void" result since it cannot be parsed back. Skips
1283 // the external linkage since it is the default value.
1284 static void printLLVMFuncOp(OpAsmPrinter &p, LLVMFuncOp op) {
1285   p << op.getOperationName() << ' ';
1286   if (op.linkage() != LLVM::Linkage::External)
1287     p << stringifyLinkage(op.linkage()) << ' ';
1288   p.printSymbolName(op.getName());
1289 
1290   LLVMType fnType = op.getType();
1291   SmallVector<Type, 8> argTypes;
1292   SmallVector<Type, 1> resTypes;
1293   argTypes.reserve(fnType.getFunctionNumParams());
1294   for (unsigned i = 0, e = fnType.getFunctionNumParams(); i < e; ++i)
1295     argTypes.push_back(fnType.getFunctionParamType(i));
1296 
1297   LLVMType returnType = fnType.getFunctionResultType();
1298   if (!returnType.isVoidTy())
1299     resTypes.push_back(returnType);
1300 
1301   impl::printFunctionSignature(p, op, argTypes, op.isVarArg(), resTypes);
1302   impl::printFunctionAttributes(p, op, argTypes.size(), resTypes.size(),
1303                                 {getLinkageAttrName()});
1304 
1305   // Print the body if this is not an external function.
1306   Region &body = op.body();
1307   if (!body.empty())
1308     p.printRegion(body, /*printEntryBlockArgs=*/false,
1309                   /*printBlockTerminators=*/true);
1310 }
1311 
1312 // Hook for OpTrait::FunctionLike, called after verifying that the 'type'
1313 // attribute is present.  This can check for preconditions of the
1314 // getNumArguments hook not failing.
1315 LogicalResult LLVMFuncOp::verifyType() {
1316   auto llvmType = getTypeAttr().getValue().dyn_cast_or_null<LLVMType>();
1317   if (!llvmType || !llvmType.getUnderlyingType()->isFunctionTy())
1318     return emitOpError("requires '" + getTypeAttrName() +
1319                        "' attribute of wrapped LLVM function type");
1320 
1321   return success();
1322 }
1323 
1324 // Hook for OpTrait::FunctionLike, returns the number of function arguments.
1325 // Depends on the type attribute being correct as checked by verifyType
1326 unsigned LLVMFuncOp::getNumFuncArguments() {
1327   return getType().getUnderlyingType()->getFunctionNumParams();
1328 }
1329 
1330 // Hook for OpTrait::FunctionLike, returns the number of function results.
1331 // Depends on the type attribute being correct as checked by verifyType
1332 unsigned LLVMFuncOp::getNumFuncResults() {
1333   // We model LLVM functions that return void as having zero results,
1334   // and all others as having one result.
1335   // If we modeled a void return as one result, then it would be possible to
1336   // attach an MLIR result attribute to it, and it isn't clear what semantics we
1337   // would assign to that.
1338   if (getType().getFunctionResultType().isVoidTy())
1339     return 0;
1340   return 1;
1341 }
1342 
1343 // Verifies LLVM- and implementation-specific properties of the LLVM func Op:
1344 // - functions don't have 'common' linkage
1345 // - external functions have 'external' or 'extern_weak' linkage;
1346 // - vararg is (currently) only supported for external functions;
1347 // - entry block arguments are of LLVM types and match the function signature.
1348 static LogicalResult verify(LLVMFuncOp op) {
1349   if (op.linkage() == LLVM::Linkage::Common)
1350     return op.emitOpError()
1351            << "functions cannot have '"
1352            << stringifyLinkage(LLVM::Linkage::Common) << "' linkage";
1353 
1354   if (op.isExternal()) {
1355     if (op.linkage() != LLVM::Linkage::External &&
1356         op.linkage() != LLVM::Linkage::ExternWeak)
1357       return op.emitOpError()
1358              << "external functions must have '"
1359              << stringifyLinkage(LLVM::Linkage::External) << "' or '"
1360              << stringifyLinkage(LLVM::Linkage::ExternWeak) << "' linkage";
1361     return success();
1362   }
1363 
1364   if (op.isVarArg())
1365     return op.emitOpError("only external functions can be variadic");
1366 
1367   auto *funcType = cast<llvm::FunctionType>(op.getType().getUnderlyingType());
1368   unsigned numArguments = funcType->getNumParams();
1369   Block &entryBlock = op.front();
1370   for (unsigned i = 0; i < numArguments; ++i) {
1371     Type argType = entryBlock.getArgument(i).getType();
1372     auto argLLVMType = argType.dyn_cast<LLVMType>();
1373     if (!argLLVMType)
1374       return op.emitOpError("entry block argument #")
1375              << i << " is not of LLVM type";
1376     if (funcType->getParamType(i) != argLLVMType.getUnderlyingType())
1377       return op.emitOpError("the type of entry block argument #")
1378              << i << " does not match the function signature";
1379   }
1380 
1381   return success();
1382 }
1383 
1384 //===----------------------------------------------------------------------===//
1385 // Verification for LLVM::NullOp.
1386 //===----------------------------------------------------------------------===//
1387 
1388 // Only LLVM pointer types are supported.
1389 static LogicalResult verify(LLVM::NullOp op) {
1390   auto llvmType = op.getType().dyn_cast<LLVM::LLVMType>();
1391   if (!llvmType || !llvmType.isPointerTy())
1392     return op.emitOpError("expected LLVM IR pointer type");
1393   return success();
1394 }
1395 
1396 //===----------------------------------------------------------------------===//
1397 // Utility functions for parsing atomic ops
1398 //===----------------------------------------------------------------------===//
1399 
1400 // Helper function to parse a keyword into the specified attribute named by
1401 // `attrName`. The keyword must match one of the string values defined by the
1402 // AtomicBinOp enum. The resulting I64 attribute is added to the `result`
1403 // state.
1404 static ParseResult parseAtomicBinOp(OpAsmParser &parser, OperationState &result,
1405                                     StringRef attrName) {
1406   llvm::SMLoc loc;
1407   StringRef keyword;
1408   if (parser.getCurrentLocation(&loc) || parser.parseKeyword(&keyword))
1409     return failure();
1410 
1411   // Replace the keyword `keyword` with an integer attribute.
1412   auto kind = symbolizeAtomicBinOp(keyword);
1413   if (!kind) {
1414     return parser.emitError(loc)
1415            << "'" << keyword << "' is an incorrect value of the '" << attrName
1416            << "' attribute";
1417   }
1418 
1419   auto value = static_cast<int64_t>(kind.getValue());
1420   auto attr = parser.getBuilder().getI64IntegerAttr(value);
1421   result.addAttribute(attrName, attr);
1422 
1423   return success();
1424 }
1425 
1426 // Helper function to parse a keyword into the specified attribute named by
1427 // `attrName`. The keyword must match one of the string values defined by the
1428 // AtomicOrdering enum. The resulting I64 attribute is added to the `result`
1429 // state.
1430 static ParseResult parseAtomicOrdering(OpAsmParser &parser,
1431                                        OperationState &result,
1432                                        StringRef attrName) {
1433   llvm::SMLoc loc;
1434   StringRef ordering;
1435   if (parser.getCurrentLocation(&loc) || parser.parseKeyword(&ordering))
1436     return failure();
1437 
1438   // Replace the keyword `ordering` with an integer attribute.
1439   auto kind = symbolizeAtomicOrdering(ordering);
1440   if (!kind) {
1441     return parser.emitError(loc)
1442            << "'" << ordering << "' is an incorrect value of the '" << attrName
1443            << "' attribute";
1444   }
1445 
1446   auto value = static_cast<int64_t>(kind.getValue());
1447   auto attr = parser.getBuilder().getI64IntegerAttr(value);
1448   result.addAttribute(attrName, attr);
1449 
1450   return success();
1451 }
1452 
1453 //===----------------------------------------------------------------------===//
1454 // Printer, parser and verifier for LLVM::AtomicRMWOp.
1455 //===----------------------------------------------------------------------===//
1456 
1457 static void printAtomicRMWOp(OpAsmPrinter &p, AtomicRMWOp &op) {
1458   p << op.getOperationName() << ' ' << stringifyAtomicBinOp(op.bin_op()) << ' '
1459     << op.ptr() << ", " << op.val() << ' '
1460     << stringifyAtomicOrdering(op.ordering()) << ' ';
1461   p.printOptionalAttrDict(op.getAttrs(), {"bin_op", "ordering"});
1462   p << " : " << op.res().getType();
1463 }
1464 
1465 // <operation> ::= `llvm.atomicrmw` keyword ssa-use `,` ssa-use keyword
1466 //                 attribute-dict? `:` type
1467 static ParseResult parseAtomicRMWOp(OpAsmParser &parser,
1468                                     OperationState &result) {
1469   LLVMType type;
1470   OpAsmParser::OperandType ptr, val;
1471   if (parseAtomicBinOp(parser, result, "bin_op") || parser.parseOperand(ptr) ||
1472       parser.parseComma() || parser.parseOperand(val) ||
1473       parseAtomicOrdering(parser, result, "ordering") ||
1474       parser.parseOptionalAttrDict(result.attributes) ||
1475       parser.parseColonType(type) ||
1476       parser.resolveOperand(ptr, type.getPointerTo(), result.operands) ||
1477       parser.resolveOperand(val, type, result.operands))
1478     return failure();
1479 
1480   result.addTypes(type);
1481   return success();
1482 }
1483 
1484 static LogicalResult verify(AtomicRMWOp op) {
1485   auto ptrType = op.ptr().getType().cast<LLVM::LLVMType>();
1486   if (!ptrType.isPointerTy())
1487     return op.emitOpError("expected LLVM IR pointer type for operand #0");
1488   auto valType = op.val().getType().cast<LLVM::LLVMType>();
1489   if (valType != ptrType.getPointerElementTy())
1490     return op.emitOpError("expected LLVM IR element type for operand #0 to "
1491                           "match type for operand #1");
1492   auto resType = op.res().getType().cast<LLVM::LLVMType>();
1493   if (resType != valType)
1494     return op.emitOpError(
1495         "expected LLVM IR result type to match type for operand #1");
1496   if (op.bin_op() == AtomicBinOp::fadd || op.bin_op() == AtomicBinOp::fsub) {
1497     if (!valType.getUnderlyingType()->isFloatingPointTy())
1498       return op.emitOpError("expected LLVM IR floating point type");
1499   } else if (op.bin_op() == AtomicBinOp::xchg) {
1500     if (!valType.isIntegerTy(8) && !valType.isIntegerTy(16) &&
1501         !valType.isIntegerTy(32) && !valType.isIntegerTy(64) &&
1502         !valType.isHalfTy() && !valType.isFloatTy() && !valType.isDoubleTy())
1503       return op.emitOpError("unexpected LLVM IR type for 'xchg' bin_op");
1504   } else {
1505     if (!valType.isIntegerTy(8) && !valType.isIntegerTy(16) &&
1506         !valType.isIntegerTy(32) && !valType.isIntegerTy(64))
1507       return op.emitOpError("expected LLVM IR integer type");
1508   }
1509   return success();
1510 }
1511 
1512 //===----------------------------------------------------------------------===//
1513 // Printer, parser and verifier for LLVM::AtomicCmpXchgOp.
1514 //===----------------------------------------------------------------------===//
1515 
1516 static void printAtomicCmpXchgOp(OpAsmPrinter &p, AtomicCmpXchgOp &op) {
1517   p << op.getOperationName() << ' ' << op.ptr() << ", " << op.cmp() << ", "
1518     << op.val() << ' ' << stringifyAtomicOrdering(op.success_ordering()) << ' '
1519     << stringifyAtomicOrdering(op.failure_ordering());
1520   p.printOptionalAttrDict(op.getAttrs(),
1521                           {"success_ordering", "failure_ordering"});
1522   p << " : " << op.val().getType();
1523 }
1524 
1525 // <operation> ::= `llvm.cmpxchg` ssa-use `,` ssa-use `,` ssa-use
1526 //                 keyword keyword attribute-dict? `:` type
1527 static ParseResult parseAtomicCmpXchgOp(OpAsmParser &parser,
1528                                         OperationState &result) {
1529   auto &builder = parser.getBuilder();
1530   LLVMType type;
1531   OpAsmParser::OperandType ptr, cmp, val;
1532   if (parser.parseOperand(ptr) || parser.parseComma() ||
1533       parser.parseOperand(cmp) || parser.parseComma() ||
1534       parser.parseOperand(val) ||
1535       parseAtomicOrdering(parser, result, "success_ordering") ||
1536       parseAtomicOrdering(parser, result, "failure_ordering") ||
1537       parser.parseOptionalAttrDict(result.attributes) ||
1538       parser.parseColonType(type) ||
1539       parser.resolveOperand(ptr, type.getPointerTo(), result.operands) ||
1540       parser.resolveOperand(cmp, type, result.operands) ||
1541       parser.resolveOperand(val, type, result.operands))
1542     return failure();
1543 
1544   auto *dialect = builder.getContext()->getRegisteredDialect<LLVMDialect>();
1545   auto boolType = LLVMType::getInt1Ty(dialect);
1546   auto resultType = LLVMType::getStructTy(type, boolType);
1547   result.addTypes(resultType);
1548 
1549   return success();
1550 }
1551 
1552 static LogicalResult verify(AtomicCmpXchgOp op) {
1553   auto ptrType = op.ptr().getType().cast<LLVM::LLVMType>();
1554   if (!ptrType.isPointerTy())
1555     return op.emitOpError("expected LLVM IR pointer type for operand #0");
1556   auto cmpType = op.cmp().getType().cast<LLVM::LLVMType>();
1557   auto valType = op.val().getType().cast<LLVM::LLVMType>();
1558   if (cmpType != ptrType.getPointerElementTy() || cmpType != valType)
1559     return op.emitOpError("expected LLVM IR element type for operand #0 to "
1560                           "match type for all other operands");
1561   if (!valType.isPointerTy() && !valType.isIntegerTy(8) &&
1562       !valType.isIntegerTy(16) && !valType.isIntegerTy(32) &&
1563       !valType.isIntegerTy(64) && !valType.isHalfTy() && !valType.isFloatTy() &&
1564       !valType.isDoubleTy())
1565     return op.emitOpError("unexpected LLVM IR type");
1566   if (op.success_ordering() < AtomicOrdering::monotonic ||
1567       op.failure_ordering() < AtomicOrdering::monotonic)
1568     return op.emitOpError("ordering must be at least 'monotonic'");
1569   if (op.failure_ordering() == AtomicOrdering::release ||
1570       op.failure_ordering() == AtomicOrdering::acq_rel)
1571     return op.emitOpError("failure ordering cannot be 'release' or 'acq_rel'");
1572   return success();
1573 }
1574 
1575 //===----------------------------------------------------------------------===//
1576 // Printer, parser and verifier for LLVM::FenceOp.
1577 //===----------------------------------------------------------------------===//
1578 
1579 // <operation> ::= `llvm.fence` (`syncscope(`strAttr`)`)? keyword
1580 // attribute-dict?
1581 static ParseResult parseFenceOp(OpAsmParser &parser, OperationState &result) {
1582   StringAttr sScope;
1583   StringRef syncscopeKeyword = "syncscope";
1584   if (!failed(parser.parseOptionalKeyword(syncscopeKeyword))) {
1585     if (parser.parseLParen() ||
1586         parser.parseAttribute(sScope, syncscopeKeyword, result.attributes) ||
1587         parser.parseRParen())
1588       return failure();
1589   } else {
1590     result.addAttribute(syncscopeKeyword,
1591                         parser.getBuilder().getStringAttr(""));
1592   }
1593   if (parseAtomicOrdering(parser, result, "ordering") ||
1594       parser.parseOptionalAttrDict(result.attributes))
1595     return failure();
1596   return success();
1597 }
1598 
1599 static void printFenceOp(OpAsmPrinter &p, FenceOp &op) {
1600   StringRef syncscopeKeyword = "syncscope";
1601   p << op.getOperationName() << ' ';
1602   if (!op.getAttr(syncscopeKeyword).cast<StringAttr>().getValue().empty())
1603     p << "syncscope(" << op.getAttr(syncscopeKeyword) << ") ";
1604   p << stringifyAtomicOrdering(op.ordering());
1605 }
1606 
1607 static LogicalResult verify(FenceOp &op) {
1608   if (op.ordering() == AtomicOrdering::not_atomic ||
1609       op.ordering() == AtomicOrdering::unordered ||
1610       op.ordering() == AtomicOrdering::monotonic)
1611     return op.emitOpError("can be given only acquire, release, acq_rel, "
1612                           "and seq_cst orderings");
1613   return success();
1614 }
1615 
1616 //===----------------------------------------------------------------------===//
1617 // LLVMDialect initialization, type parsing, and registration.
1618 //===----------------------------------------------------------------------===//
1619 
namespace mlir {
namespace LLVM {
namespace detail {
/// Private state of LLVMDialect. It holds:
///  - the LLVM context used to create LLVM types for this dialect,
///  - a module in that context (used by LLVMDialect::parseType),
///  - a cache of common LLVM types,
///  - a mutex that serializes access to the LLVM context.
struct LLVMDialectImpl {
  // `module` is constructed from `llvmContext` in this init list. That is
  // only valid because `llvmContext` is declared first below: members are
  // constructed in declaration order and destroyed in reverse, so the context
  // outlives the module. Do not reorder these two members.
  LLVMDialectImpl() : module("LLVMDialectModule", llvmContext) {}

  /// The LLVM context that the dialect's cached and newly built types are
  /// created in.
  llvm::LLVMContext llvmContext;
  /// Module living in `llvmContext`; passed to llvm::parseType when parsing
  /// dialect types.
  llvm::Module module;

  /// A set of LLVMTypes that are cached on construction to avoid any lookups or
  /// locking.
  LLVMType int1Ty, int8Ty, int16Ty, int32Ty, int64Ty, int128Ty;
  LLVMType doubleTy, floatTy, halfTy, fp128Ty, x86_fp80Ty;
  LLVMType voidTy;

  /// A smart mutex to lock access to the llvm context. Unlike MLIR, LLVM is not
  /// multi-threaded and requires locked access to prevent race conditions.
  llvm::sys::SmartMutex<true> mutex;
};
} // end namespace detail
} // end namespace LLVM
} // end namespace mlir
1642 
1643 LLVMDialect::LLVMDialect(MLIRContext *context)
1644     : Dialect(getDialectNamespace(), context),
1645       impl(new detail::LLVMDialectImpl()) {
1646   addTypes<LLVMType>();
1647   addOperations<
1648 #define GET_OP_LIST
1649 #include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc"
1650       >();
1651 
1652   // Support unknown operations because not all LLVM operations are registered.
1653   allowUnknownOperations();
1654 
1655   // Cache some of the common LLVM types to avoid the need for lookups/locking.
1656   auto &llvmContext = impl->llvmContext;
1657   /// Integer Types.
1658   impl->int1Ty = LLVMType::get(context, llvm::Type::getInt1Ty(llvmContext));
1659   impl->int8Ty = LLVMType::get(context, llvm::Type::getInt8Ty(llvmContext));
1660   impl->int16Ty = LLVMType::get(context, llvm::Type::getInt16Ty(llvmContext));
1661   impl->int32Ty = LLVMType::get(context, llvm::Type::getInt32Ty(llvmContext));
1662   impl->int64Ty = LLVMType::get(context, llvm::Type::getInt64Ty(llvmContext));
1663   impl->int128Ty = LLVMType::get(context, llvm::Type::getInt128Ty(llvmContext));
1664   /// Float Types.
1665   impl->doubleTy = LLVMType::get(context, llvm::Type::getDoubleTy(llvmContext));
1666   impl->floatTy = LLVMType::get(context, llvm::Type::getFloatTy(llvmContext));
1667   impl->halfTy = LLVMType::get(context, llvm::Type::getHalfTy(llvmContext));
1668   impl->fp128Ty = LLVMType::get(context, llvm::Type::getFP128Ty(llvmContext));
1669   impl->x86_fp80Ty =
1670       LLVMType::get(context, llvm::Type::getX86_FP80Ty(llvmContext));
1671   /// Other Types.
1672   impl->voidTy = LLVMType::get(context, llvm::Type::getVoidTy(llvmContext));
1673 }
1674 
1675 LLVMDialect::~LLVMDialect() {}
1676 
1677 #define GET_OP_CLASSES
1678 #include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc"
1679 
1680 llvm::LLVMContext &LLVMDialect::getLLVMContext() { return impl->llvmContext; }
1681 llvm::Module &LLVMDialect::getLLVMModule() { return impl->module; }
1682 
1683 /// Parse a type registered to this dialect.
1684 Type LLVMDialect::parseType(DialectAsmParser &parser) const {
1685   StringRef tyData = parser.getFullSymbolSpec();
1686 
1687   // LLVM is not thread-safe, so lock access to it.
1688   llvm::sys::SmartScopedLock<true> lock(impl->mutex);
1689 
1690   llvm::SMDiagnostic errorMessage;
1691   llvm::Type *type = llvm::parseType(tyData, errorMessage, impl->module);
1692   if (!type)
1693     return (parser.emitError(parser.getNameLoc(), errorMessage.getMessage()),
1694             nullptr);
1695   return LLVMType::get(getContext(), type);
1696 }
1697 
1698 /// Print a type registered to this dialect.
1699 void LLVMDialect::printType(Type type, DialectAsmPrinter &os) const {
1700   auto llvmType = type.dyn_cast<LLVMType>();
1701   assert(llvmType && "printing wrong type");
1702   assert(llvmType.getUnderlyingType() && "no underlying LLVM type");
1703   llvmType.getUnderlyingType()->print(os.getStream());
1704 }
1705 
1706 /// Verify LLVMIR function argument attributes.
1707 LogicalResult LLVMDialect::verifyRegionArgAttribute(Operation *op,
1708                                                     unsigned regionIdx,
1709                                                     unsigned argIdx,
1710                                                     NamedAttribute argAttr) {
1711   // Check that llvm.noalias is a boolean attribute.
1712   if (argAttr.first == "llvm.noalias" && !argAttr.second.isa<BoolAttr>())
1713     return op->emitError()
1714            << "llvm.noalias argument attribute of non boolean type";
1715   return success();
1716 }
1717 
1718 //===----------------------------------------------------------------------===//
1719 // LLVMType.
1720 //===----------------------------------------------------------------------===//
1721 
1722 namespace mlir {
1723 namespace LLVM {
1724 namespace detail {
1725 struct LLVMTypeStorage : public ::mlir::TypeStorage {
1726   LLVMTypeStorage(llvm::Type *ty) : underlyingType(ty) {}
1727 
1728   // LLVM types are pointer-unique.
1729   using KeyTy = llvm::Type *;
1730   bool operator==(const KeyTy &key) const { return key == underlyingType; }
1731 
1732   static LLVMTypeStorage *construct(TypeStorageAllocator &allocator,
1733                                     llvm::Type *ty) {
1734     return new (allocator.allocate<LLVMTypeStorage>()) LLVMTypeStorage(ty);
1735   }
1736 
1737   llvm::Type *underlyingType;
1738 };
1739 } // end namespace detail
1740 } // end namespace LLVM
1741 } // end namespace mlir
1742 
1743 LLVMType LLVMType::get(MLIRContext *context, llvm::Type *llvmType) {
1744   return Base::get(context, FIRST_LLVM_TYPE, llvmType);
1745 }
1746 
1747 /// Get an LLVMType with an llvm type that may cause changes to the underlying
1748 /// llvm context when constructed.
1749 LLVMType LLVMType::getLocked(LLVMDialect *dialect,
1750                              function_ref<llvm::Type *()> typeBuilder) {
1751   // Lock access to the llvm context and build the type.
1752   llvm::sys::SmartScopedLock<true> lock(dialect->impl->mutex);
1753   return get(dialect->getContext(), typeBuilder());
1754 }
1755 
1756 LLVMDialect &LLVMType::getDialect() {
1757   return static_cast<LLVMDialect &>(Type::getDialect());
1758 }
1759 
1760 llvm::Type *LLVMType::getUnderlyingType() const {
1761   return getImpl()->underlyingType;
1762 }
1763 
1764 /// Array type utilities.
1765 LLVMType LLVMType::getArrayElementType() {
1766   return get(getContext(), getUnderlyingType()->getArrayElementType());
1767 }
1768 unsigned LLVMType::getArrayNumElements() {
1769   return getUnderlyingType()->getArrayNumElements();
1770 }
1771 bool LLVMType::isArrayTy() { return getUnderlyingType()->isArrayTy(); }
1772 
1773 /// Vector type utilities.
1774 LLVMType LLVMType::getVectorElementType() {
1775   return get(getContext(), getUnderlyingType()->getVectorElementType());
1776 }
1777 bool LLVMType::isVectorTy() { return getUnderlyingType()->isVectorTy(); }
1778 
1779 /// Function type utilities.
1780 LLVMType LLVMType::getFunctionParamType(unsigned argIdx) {
1781   return get(getContext(), getUnderlyingType()->getFunctionParamType(argIdx));
1782 }
1783 unsigned LLVMType::getFunctionNumParams() {
1784   return getUnderlyingType()->getFunctionNumParams();
1785 }
1786 LLVMType LLVMType::getFunctionResultType() {
1787   return get(
1788       getContext(),
1789       llvm::cast<llvm::FunctionType>(getUnderlyingType())->getReturnType());
1790 }
1791 bool LLVMType::isFunctionTy() { return getUnderlyingType()->isFunctionTy(); }
1792 
1793 /// Pointer type utilities.
1794 LLVMType LLVMType::getPointerTo(unsigned addrSpace) {
1795   // Lock access to the dialect as this may modify the LLVM context.
1796   return getLocked(&getDialect(), [=] {
1797     return getUnderlyingType()->getPointerTo(addrSpace);
1798   });
1799 }
1800 LLVMType LLVMType::getPointerElementTy() {
1801   return get(getContext(), getUnderlyingType()->getPointerElementType());
1802 }
1803 bool LLVMType::isPointerTy() { return getUnderlyingType()->isPointerTy(); }
1804 
1805 /// Struct type utilities.
1806 LLVMType LLVMType::getStructElementType(unsigned i) {
1807   return get(getContext(), getUnderlyingType()->getStructElementType(i));
1808 }
1809 unsigned LLVMType::getStructNumElements() {
1810   return getUnderlyingType()->getStructNumElements();
1811 }
1812 bool LLVMType::isStructTy() { return getUnderlyingType()->isStructTy(); }
1813 
1814 /// Utilities used to generate floating point types.
1815 LLVMType LLVMType::getDoubleTy(LLVMDialect *dialect) {
1816   return dialect->impl->doubleTy;
1817 }
1818 LLVMType LLVMType::getFloatTy(LLVMDialect *dialect) {
1819   return dialect->impl->floatTy;
1820 }
1821 LLVMType LLVMType::getHalfTy(LLVMDialect *dialect) {
1822   return dialect->impl->halfTy;
1823 }
1824 LLVMType LLVMType::getFP128Ty(LLVMDialect *dialect) {
1825   return dialect->impl->fp128Ty;
1826 }
1827 LLVMType LLVMType::getX86_FP80Ty(LLVMDialect *dialect) {
1828   return dialect->impl->x86_fp80Ty;
1829 }
1830 
1831 /// Utilities used to generate integer types.
1832 LLVMType LLVMType::getIntNTy(LLVMDialect *dialect, unsigned numBits) {
1833   switch (numBits) {
1834   case 1:
1835     return dialect->impl->int1Ty;
1836   case 8:
1837     return dialect->impl->int8Ty;
1838   case 16:
1839     return dialect->impl->int16Ty;
1840   case 32:
1841     return dialect->impl->int32Ty;
1842   case 64:
1843     return dialect->impl->int64Ty;
1844   case 128:
1845     return dialect->impl->int128Ty;
1846   default:
1847     break;
1848   }
1849 
1850   // Lock access to the dialect as this may modify the LLVM context.
1851   return getLocked(dialect, [=] {
1852     return llvm::Type::getIntNTy(dialect->getLLVMContext(), numBits);
1853   });
1854 }
1855 
1856 /// Utilities used to generate other miscellaneous types.
1857 LLVMType LLVMType::getArrayTy(LLVMType elementType, uint64_t numElements) {
1858   // Lock access to the dialect as this may modify the LLVM context.
1859   return getLocked(&elementType.getDialect(), [=] {
1860     return llvm::ArrayType::get(elementType.getUnderlyingType(), numElements);
1861   });
1862 }
1863 LLVMType LLVMType::getFunctionTy(LLVMType result, ArrayRef<LLVMType> params,
1864                                  bool isVarArg) {
1865   SmallVector<llvm::Type *, 8> llvmParams;
1866   for (auto param : params)
1867     llvmParams.push_back(param.getUnderlyingType());
1868 
1869   // Lock access to the dialect as this may modify the LLVM context.
1870   return getLocked(&result.getDialect(), [=] {
1871     return llvm::FunctionType::get(result.getUnderlyingType(), llvmParams,
1872                                    isVarArg);
1873   });
1874 }
1875 LLVMType LLVMType::getStructTy(LLVMDialect *dialect,
1876                                ArrayRef<LLVMType> elements, bool isPacked) {
1877   SmallVector<llvm::Type *, 8> llvmElements;
1878   for (auto elt : elements)
1879     llvmElements.push_back(elt.getUnderlyingType());
1880 
1881   // Lock access to the dialect as this may modify the LLVM context.
1882   return getLocked(dialect, [=] {
1883     return llvm::StructType::get(dialect->getLLVMContext(), llvmElements,
1884                                  isPacked);
1885   });
1886 }
1887 inline static SmallVector<llvm::Type *, 8>
1888 toUnderlyingTypes(ArrayRef<LLVMType> elements) {
1889   SmallVector<llvm::Type *, 8> llvmElements;
1890   for (auto elt : elements)
1891     llvmElements.push_back(elt.getUnderlyingType());
1892   return llvmElements;
1893 }
1894 LLVMType LLVMType::createStructTy(LLVMDialect *dialect,
1895                                   ArrayRef<LLVMType> elements,
1896                                   Optional<StringRef> name, bool isPacked) {
1897   StringRef sr = name.hasValue() ? *name : "";
1898   SmallVector<llvm::Type *, 8> llvmElements(toUnderlyingTypes(elements));
1899   return getLocked(dialect, [=] {
1900     auto *rv = llvm::StructType::create(dialect->getLLVMContext(), sr);
1901     if (!llvmElements.empty())
1902       rv->setBody(llvmElements, isPacked);
1903     return rv;
1904   });
1905 }
1906 LLVMType LLVMType::setStructTyBody(LLVMType structType,
1907                                    ArrayRef<LLVMType> elements, bool isPacked) {
1908   llvm::StructType *st =
1909       llvm::cast<llvm::StructType>(structType.getUnderlyingType());
1910   SmallVector<llvm::Type *, 8> llvmElements(toUnderlyingTypes(elements));
1911   return getLocked(&structType.getDialect(), [=] {
1912     st->setBody(llvmElements, isPacked);
1913     return st;
1914   });
1915 }
1916 LLVMType LLVMType::getVectorTy(LLVMType elementType, unsigned numElements) {
1917   // Lock access to the dialect as this may modify the LLVM context.
1918   return getLocked(&elementType.getDialect(), [=] {
1919     return llvm::VectorType::get(elementType.getUnderlyingType(), numElements);
1920   });
1921 }
1922 
1923 LLVMType LLVMType::getVoidTy(LLVMDialect *dialect) {
1924   return dialect->impl->voidTy;
1925 }
1926 
1927 bool LLVMType::isVoidTy() { return getUnderlyingType()->isVoidTy(); }
1928 
1929 //===----------------------------------------------------------------------===//
1930 // Utility functions.
1931 //===----------------------------------------------------------------------===//
1932 
1933 Value mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder,
1934                                      StringRef name, StringRef value,
1935                                      LLVM::Linkage linkage,
1936                                      LLVM::LLVMDialect *llvmDialect) {
1937   assert(builder.getInsertionBlock() &&
1938          builder.getInsertionBlock()->getParentOp() &&
1939          "expected builder to point to a block constrained in an op");
1940   auto module =
1941       builder.getInsertionBlock()->getParentOp()->getParentOfType<ModuleOp>();
1942   assert(module && "builder points to an op outside of a module");
1943 
1944   // Create the global at the entry of the module.
1945   OpBuilder moduleBuilder(module.getBodyRegion());
1946   auto type = LLVM::LLVMType::getArrayTy(LLVM::LLVMType::getInt8Ty(llvmDialect),
1947                                          value.size());
1948   auto global = moduleBuilder.create<LLVM::GlobalOp>(
1949       loc, type, /*isConstant=*/true, linkage, name,
1950       builder.getStringAttr(value));
1951 
1952   // Get the pointer to the first character in the global string.
1953   Value globalPtr = builder.create<LLVM::AddressOfOp>(loc, global);
1954   Value cst0 = builder.create<LLVM::ConstantOp>(
1955       loc, LLVM::LLVMType::getInt64Ty(llvmDialect),
1956       builder.getIntegerAttr(builder.getIndexType(), 0));
1957   return builder.create<LLVM::GEPOp>(loc,
1958                                      LLVM::LLVMType::getInt8PtrTy(llvmDialect),
1959                                      globalPtr, ArrayRef<Value>({cst0, cst0}));
1960 }
1961 
1962 bool mlir::LLVM::satisfiesLLVMModule(Operation *op) {
1963   return op->hasTrait<OpTrait::SymbolTable>() &&
1964          op->hasTrait<OpTrait::IsIsolatedFromAbove>();
1965 }
1966