1 //===- LLVMDialect.cpp - LLVM IR Ops and Dialect registration -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the types and operation details for the LLVM IR dialect in
10 // MLIR, and the LLVM IR dialect. It also registers the dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14 #include "TypeDetail.h"
15 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/DialectImplementation.h"
20 #include "mlir/IR/FunctionImplementation.h"
21 #include "mlir/IR/MLIRContext.h"
22 #include "mlir/IR/Matchers.h"
23
24 #include "llvm/ADT/StringSwitch.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 #include "llvm/AsmParser/Parser.h"
27 #include "llvm/Bitcode/BitcodeReader.h"
28 #include "llvm/Bitcode/BitcodeWriter.h"
29 #include "llvm/IR/Attributes.h"
30 #include "llvm/IR/Function.h"
31 #include "llvm/IR/Type.h"
32 #include "llvm/Support/Error.h"
33 #include "llvm/Support/Mutex.h"
34 #include "llvm/Support/SourceMgr.h"
35
36 #include <numeric>
37
38 using namespace mlir;
39 using namespace mlir::LLVM;
40 using mlir::LLVM::cconv::getMaxEnumValForCConv;
41 using mlir::LLVM::linkage::getMaxEnumValForLinkage;
42
43 #include "mlir/Dialect/LLVMIR/LLVMOpsDialect.cpp.inc"
44
// Attribute names shared by the builders, printers and parsers in this file.
// Unit attribute marking a load as volatile.
static constexpr const char kVolatileAttrName[] = "volatile_";
// Unit attribute marking a load as non-temporal.
static constexpr const char kNonTemporalAttrName[] = "nontemporal";
// Type attribute carrying the element type when an opaque pointer is used.
static constexpr const char kElemTypeAttrName[] = "elem_type";
48
49 #include "mlir/Dialect/LLVMIR/LLVMOpsEnums.cpp.inc"
50 #include "mlir/Dialect/LLVMIR/LLVMOpsInterfaces.cpp.inc"
51 #define GET_ATTRDEF_CLASSES
52 #include "mlir/Dialect/LLVMIR/LLVMOpsAttrDefs.cpp.inc"
53
processFMFAttr(ArrayRef<NamedAttribute> attrs)54 static auto processFMFAttr(ArrayRef<NamedAttribute> attrs) {
55 SmallVector<NamedAttribute, 8> filteredAttrs(
56 llvm::make_filter_range(attrs, [&](NamedAttribute attr) {
57 if (attr.getName() == "fastmathFlags") {
58 auto defAttr = FMFAttr::get(attr.getValue().getContext(), {});
59 return defAttr != attr.getValue();
60 }
61 return true;
62 }));
63 return filteredAttrs;
64 }
65
parseLLVMOpAttrs(OpAsmParser & parser,NamedAttrList & result)66 static ParseResult parseLLVMOpAttrs(OpAsmParser &parser,
67 NamedAttrList &result) {
68 return parser.parseOptionalAttrDict(result);
69 }
70
printLLVMOpAttrs(OpAsmPrinter & printer,Operation * op,DictionaryAttr attrs)71 static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op,
72 DictionaryAttr attrs) {
73 printer.printOptionalAttrDict(processFMFAttr(attrs.getValue()));
74 }
75
76 /// Verifies `symbol`'s use in `op` to ensure the symbol is a valid and
77 /// fully defined llvm.func.
verifySymbolAttrUse(FlatSymbolRefAttr symbol,Operation * op,SymbolTableCollection & symbolTable)78 static LogicalResult verifySymbolAttrUse(FlatSymbolRefAttr symbol,
79 Operation *op,
80 SymbolTableCollection &symbolTable) {
81 StringRef name = symbol.getValue();
82 auto func =
83 symbolTable.lookupNearestSymbolFrom<LLVMFuncOp>(op, symbol.getAttr());
84 if (!func)
85 return op->emitOpError("'")
86 << name << "' does not reference a valid LLVM function";
87 if (func.isExternal())
88 return op->emitOpError("'") << name << "' does not have a definition";
89 return success();
90 }
91
92 //===----------------------------------------------------------------------===//
93 // Printing, parsing and builder for LLVM::CmpOp.
94 //===----------------------------------------------------------------------===//
95
build(OpBuilder & builder,OperationState & result,ICmpPredicate predicate,Value lhs,Value rhs)96 void ICmpOp::build(OpBuilder &builder, OperationState &result,
97 ICmpPredicate predicate, Value lhs, Value rhs) {
98 auto boolType = IntegerType::get(lhs.getType().getContext(), 1);
99 if (LLVM::isCompatibleVectorType(lhs.getType()) ||
100 LLVM::isCompatibleVectorType(rhs.getType())) {
101 int64_t numLHSElements = 1, numRHSElements = 1;
102 if (LLVM::isCompatibleVectorType(lhs.getType()))
103 numLHSElements =
104 LLVM::getVectorNumElements(lhs.getType()).getFixedValue();
105 if (LLVM::isCompatibleVectorType(rhs.getType()))
106 numRHSElements =
107 LLVM::getVectorNumElements(rhs.getType()).getFixedValue();
108 build(builder, result,
109 VectorType::get({std::max(numLHSElements, numRHSElements)}, boolType),
110 predicate, lhs, rhs);
111 } else {
112 build(builder, result, boolType, predicate, lhs, rhs);
113 }
114 }
115
print(OpAsmPrinter & p)116 void ICmpOp::print(OpAsmPrinter &p) {
117 p << " \"" << stringifyICmpPredicate(getPredicate()) << "\" " << getOperand(0)
118 << ", " << getOperand(1);
119 p.printOptionalAttrDict((*this)->getAttrs(), {"predicate"});
120 p << " : " << getLhs().getType();
121 }
122
print(OpAsmPrinter & p)123 void FCmpOp::print(OpAsmPrinter &p) {
124 p << " \"" << stringifyFCmpPredicate(getPredicate()) << "\" " << getOperand(0)
125 << ", " << getOperand(1);
126 p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()), {"predicate"});
127 p << " : " << getLhs().getType();
128 }
129
130 // <operation> ::= `llvm.icmp` string-literal ssa-use `,` ssa-use
131 // attribute-dict? `:` type
132 // <operation> ::= `llvm.fcmp` string-literal ssa-use `,` ssa-use
133 // attribute-dict? `:` type
134 template <typename CmpPredicateType>
parseCmpOp(OpAsmParser & parser,OperationState & result)135 static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) {
136 Builder &builder = parser.getBuilder();
137
138 StringAttr predicateAttr;
139 OpAsmParser::UnresolvedOperand lhs, rhs;
140 Type type;
141 SMLoc predicateLoc, trailingTypeLoc;
142 if (parser.getCurrentLocation(&predicateLoc) ||
143 parser.parseAttribute(predicateAttr, "predicate", result.attributes) ||
144 parser.parseOperand(lhs) || parser.parseComma() ||
145 parser.parseOperand(rhs) ||
146 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
147 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) ||
148 parser.resolveOperand(lhs, type, result.operands) ||
149 parser.resolveOperand(rhs, type, result.operands))
150 return failure();
151
152 // Replace the string attribute `predicate` with an integer attribute.
153 int64_t predicateValue = 0;
154 if (std::is_same<CmpPredicateType, ICmpPredicate>()) {
155 Optional<ICmpPredicate> predicate =
156 symbolizeICmpPredicate(predicateAttr.getValue());
157 if (!predicate)
158 return parser.emitError(predicateLoc)
159 << "'" << predicateAttr.getValue()
160 << "' is an incorrect value of the 'predicate' attribute";
161 predicateValue = static_cast<int64_t>(*predicate);
162 } else {
163 Optional<FCmpPredicate> predicate =
164 symbolizeFCmpPredicate(predicateAttr.getValue());
165 if (!predicate)
166 return parser.emitError(predicateLoc)
167 << "'" << predicateAttr.getValue()
168 << "' is an incorrect value of the 'predicate' attribute";
169 predicateValue = static_cast<int64_t>(*predicate);
170 }
171
172 result.attributes.set("predicate",
173 parser.getBuilder().getI64IntegerAttr(predicateValue));
174
175 // The result type is either i1 or a vector type <? x i1> if the inputs are
176 // vectors.
177 Type resultType = IntegerType::get(builder.getContext(), 1);
178 if (!isCompatibleType(type))
179 return parser.emitError(trailingTypeLoc,
180 "expected LLVM dialect-compatible type");
181 if (LLVM::isCompatibleVectorType(type)) {
182 if (LLVM::isScalableVectorType(type)) {
183 resultType = LLVM::getVectorType(
184 resultType, LLVM::getVectorNumElements(type).getKnownMinValue(),
185 /*isScalable=*/true);
186 } else {
187 resultType = LLVM::getVectorType(
188 resultType, LLVM::getVectorNumElements(type).getFixedValue(),
189 /*isScalable=*/false);
190 }
191 }
192
193 result.addTypes({resultType});
194 return success();
195 }
196
parse(OpAsmParser & parser,OperationState & result)197 ParseResult ICmpOp::parse(OpAsmParser &parser, OperationState &result) {
198 return parseCmpOp<ICmpPredicate>(parser, result);
199 }
200
parse(OpAsmParser & parser,OperationState & result)201 ParseResult FCmpOp::parse(OpAsmParser &parser, OperationState &result) {
202 return parseCmpOp<FCmpPredicate>(parser, result);
203 }
204
205 //===----------------------------------------------------------------------===//
206 // Printing, parsing and verification for LLVM::AllocaOp.
207 //===----------------------------------------------------------------------===//
208
print(OpAsmPrinter & p)209 void AllocaOp::print(OpAsmPrinter &p) {
210 Type elemTy = getType().cast<LLVM::LLVMPointerType>().getElementType();
211 if (!elemTy)
212 elemTy = *getElemType();
213
214 auto funcTy =
215 FunctionType::get(getContext(), {getArraySize().getType()}, {getType()});
216
217 p << ' ' << getArraySize() << " x " << elemTy;
218 if (getAlignment() && *getAlignment() != 0)
219 p.printOptionalAttrDict((*this)->getAttrs(), {kElemTypeAttrName});
220 else
221 p.printOptionalAttrDict((*this)->getAttrs(),
222 {"alignment", kElemTypeAttrName});
223 p << " : " << funcTy;
224 }
225
226 // <operation> ::= `llvm.alloca` ssa-use `x` type attribute-dict?
227 // `:` type `,` type
parse(OpAsmParser & parser,OperationState & result)228 ParseResult AllocaOp::parse(OpAsmParser &parser, OperationState &result) {
229 OpAsmParser::UnresolvedOperand arraySize;
230 Type type, elemType;
231 SMLoc trailingTypeLoc;
232 if (parser.parseOperand(arraySize) || parser.parseKeyword("x") ||
233 parser.parseType(elemType) ||
234 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
235 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type))
236 return failure();
237
238 Optional<NamedAttribute> alignmentAttr =
239 result.attributes.getNamed("alignment");
240 if (alignmentAttr.has_value()) {
241 auto alignmentInt =
242 alignmentAttr.value().getValue().dyn_cast<IntegerAttr>();
243 if (!alignmentInt)
244 return parser.emitError(parser.getNameLoc(),
245 "expected integer alignment");
246 if (alignmentInt.getValue().isNullValue())
247 result.attributes.erase("alignment");
248 }
249
250 // Extract the result type from the trailing function type.
251 auto funcType = type.dyn_cast<FunctionType>();
252 if (!funcType || funcType.getNumInputs() != 1 ||
253 funcType.getNumResults() != 1)
254 return parser.emitError(
255 trailingTypeLoc,
256 "expected trailing function type with one argument and one result");
257
258 if (parser.resolveOperand(arraySize, funcType.getInput(0), result.operands))
259 return failure();
260
261 Type resultType = funcType.getResult(0);
262 if (auto ptrResultType = resultType.dyn_cast<LLVMPointerType>()) {
263 if (ptrResultType.isOpaque())
264 result.addAttribute(kElemTypeAttrName, TypeAttr::get(elemType));
265 }
266
267 result.addTypes({funcType.getResult(0)});
268 return success();
269 }
270
271 /// Checks that the elemental type is present in either the pointer type or
272 /// the attribute, but not both.
verifyOpaquePtr(Operation * op,LLVMPointerType ptrType,Optional<Type> ptrElementType)273 static LogicalResult verifyOpaquePtr(Operation *op, LLVMPointerType ptrType,
274 Optional<Type> ptrElementType) {
275 if (ptrType.isOpaque() && !ptrElementType.has_value()) {
276 return op->emitOpError() << "expected '" << kElemTypeAttrName
277 << "' attribute if opaque pointer type is used";
278 }
279 if (!ptrType.isOpaque() && ptrElementType.has_value()) {
280 return op->emitOpError()
281 << "unexpected '" << kElemTypeAttrName
282 << "' attribute when non-opaque pointer type is used";
283 }
284 return success();
285 }
286
verify()287 LogicalResult AllocaOp::verify() {
288 return verifyOpaquePtr(getOperation(), getType().cast<LLVMPointerType>(),
289 getElemType());
290 }
291
292 //===----------------------------------------------------------------------===//
293 // LLVM::BrOp
294 //===----------------------------------------------------------------------===//
295
getSuccessorOperands(unsigned index)296 SuccessorOperands BrOp::getSuccessorOperands(unsigned index) {
297 assert(index == 0 && "invalid successor index");
298 return SuccessorOperands(getDestOperandsMutable());
299 }
300
301 //===----------------------------------------------------------------------===//
302 // LLVM::CondBrOp
303 //===----------------------------------------------------------------------===//
304
getSuccessorOperands(unsigned index)305 SuccessorOperands CondBrOp::getSuccessorOperands(unsigned index) {
306 assert(index < getNumSuccessors() && "invalid successor index");
307 return SuccessorOperands(index == 0 ? getTrueDestOperandsMutable()
308 : getFalseDestOperandsMutable());
309 }
310
311 //===----------------------------------------------------------------------===//
312 // LLVM::SwitchOp
313 //===----------------------------------------------------------------------===//
314
build(OpBuilder & builder,OperationState & result,Value value,Block * defaultDestination,ValueRange defaultOperands,ArrayRef<int32_t> caseValues,BlockRange caseDestinations,ArrayRef<ValueRange> caseOperands,ArrayRef<int32_t> branchWeights)315 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
316 Block *defaultDestination, ValueRange defaultOperands,
317 ArrayRef<int32_t> caseValues, BlockRange caseDestinations,
318 ArrayRef<ValueRange> caseOperands,
319 ArrayRef<int32_t> branchWeights) {
320 ElementsAttr caseValuesAttr;
321 if (!caseValues.empty())
322 caseValuesAttr = builder.getI32VectorAttr(caseValues);
323
324 ElementsAttr weightsAttr;
325 if (!branchWeights.empty())
326 weightsAttr = builder.getI32VectorAttr(llvm::to_vector<4>(branchWeights));
327
328 build(builder, result, value, defaultOperands, caseOperands, caseValuesAttr,
329 weightsAttr, defaultDestination, caseDestinations);
330 }
331
332 /// <cases> ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)?
333 /// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )?
parseSwitchOpCases(OpAsmParser & parser,Type flagType,ElementsAttr & caseValues,SmallVectorImpl<Block * > & caseDestinations,SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> & caseOperands,SmallVectorImpl<SmallVector<Type>> & caseOperandTypes)334 static ParseResult parseSwitchOpCases(
335 OpAsmParser &parser, Type flagType, ElementsAttr &caseValues,
336 SmallVectorImpl<Block *> &caseDestinations,
337 SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> &caseOperands,
338 SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
339 SmallVector<APInt> values;
340 unsigned bitWidth = flagType.getIntOrFloatBitWidth();
341 do {
342 int64_t value = 0;
343 OptionalParseResult integerParseResult = parser.parseOptionalInteger(value);
344 if (values.empty() && !integerParseResult.hasValue())
345 return success();
346
347 if (!integerParseResult.hasValue() || integerParseResult.getValue())
348 return failure();
349 values.push_back(APInt(bitWidth, value));
350
351 Block *destination;
352 SmallVector<OpAsmParser::UnresolvedOperand> operands;
353 SmallVector<Type> operandTypes;
354 if (parser.parseColon() || parser.parseSuccessor(destination))
355 return failure();
356 if (!parser.parseOptionalLParen()) {
357 if (parser.parseOperandList(operands, OpAsmParser::Delimiter::None,
358 /*allowResultNumber=*/false) ||
359 parser.parseColonTypeList(operandTypes) || parser.parseRParen())
360 return failure();
361 }
362 caseDestinations.push_back(destination);
363 caseOperands.emplace_back(operands);
364 caseOperandTypes.emplace_back(operandTypes);
365 } while (!parser.parseOptionalComma());
366
367 ShapedType caseValueType =
368 VectorType::get(static_cast<int64_t>(values.size()), flagType);
369 caseValues = DenseIntElementsAttr::get(caseValueType, values);
370 return success();
371 }
372
printSwitchOpCases(OpAsmPrinter & p,SwitchOp op,Type flagType,ElementsAttr caseValues,SuccessorRange caseDestinations,OperandRangeRange caseOperands,const TypeRangeRange & caseOperandTypes)373 static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type flagType,
374 ElementsAttr caseValues,
375 SuccessorRange caseDestinations,
376 OperandRangeRange caseOperands,
377 const TypeRangeRange &caseOperandTypes) {
378 if (!caseValues)
379 return;
380
381 size_t index = 0;
382 llvm::interleave(
383 llvm::zip(caseValues.cast<DenseIntElementsAttr>(), caseDestinations),
384 [&](auto i) {
385 p << " ";
386 p << std::get<0>(i).getLimitedValue();
387 p << ": ";
388 p.printSuccessorAndUseList(std::get<1>(i), caseOperands[index++]);
389 },
390 [&] {
391 p << ',';
392 p.printNewline();
393 });
394 p.printNewline();
395 }
396
verify()397 LogicalResult SwitchOp::verify() {
398 if ((!getCaseValues() && !getCaseDestinations().empty()) ||
399 (getCaseValues() &&
400 getCaseValues()->size() !=
401 static_cast<int64_t>(getCaseDestinations().size())))
402 return emitOpError("expects number of case values to match number of "
403 "case destinations");
404 if (getBranchWeights() && getBranchWeights()->size() != getNumSuccessors())
405 return emitError("expects number of branch weights to match number of "
406 "successors: ")
407 << getBranchWeights()->size() << " vs " << getNumSuccessors();
408 return success();
409 }
410
getSuccessorOperands(unsigned index)411 SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) {
412 assert(index < getNumSuccessors() && "invalid successor index");
413 return SuccessorOperands(index == 0 ? getDefaultOperandsMutable()
414 : getCaseOperandsMutable(index - 1));
415 }
416
417 //===----------------------------------------------------------------------===//
418 // Code for LLVM::GEPOp.
419 //===----------------------------------------------------------------------===//
420
// Out-of-line definition of the static constexpr member. It gives the
// constant an address for uses that bind it to a reference (e.g. passing it
// to llvm::count below). NOTE(review): redundant once the build relies on
// C++17 inline variables.
constexpr int GEPOp::kDynamicIndex;
422
namespace {
/// Base class for llvm::Error related to GEP index.
///
/// Stores the position (`indexPos`) of the offending index in the GEP index
/// list. It does not implement log() itself; the concrete errors below derive
/// from it and provide the message.
class GEPIndexError : public llvm::ErrorInfo<GEPIndexError> {
protected:
  // Position of the offending index within the GEP's index list.
  unsigned indexPos;

public:
  // Identifier used by llvm::ErrorInfo to tell error classes apart.
  static char ID;

  // These errors are only rendered as messages; they have no corresponding
  // std::error_code.
  std::error_code convertToErrorCode() const override {
    return llvm::inconvertibleErrorCode();
  }

  explicit GEPIndexError(unsigned pos) : indexPos(pos) {}
};

/// llvm::Error for out-of-bound GEP index.
struct GEPIndexOutOfBoundError
    : public llvm::ErrorInfo<GEPIndexOutOfBoundError, GEPIndexError> {
  static char ID;

  // Reuse the base constructors so the error can be created as
  // llvm::make_error<GEPIndexOutOfBoundError>(indexPos).
  using ErrorInfo::ErrorInfo;

  void log(llvm::raw_ostream &os) const override {
    os << "index " << indexPos << " indexing a struct is out of bounds";
  }
};

/// llvm::Error for non-static GEP index indexing a struct.
struct GEPStaticIndexError
    : public llvm::ErrorInfo<GEPStaticIndexError, GEPIndexError> {
  static char ID;

  // Reuse the base constructors so the error can be created as
  // llvm::make_error<GEPStaticIndexError>(indexPos).
  using ErrorInfo::ErrorInfo;

  void log(llvm::raw_ostream &os) const override {
    os << "expected index " << indexPos << " indexing a struct "
       << "to be constant";
  }
};
} // end anonymous namespace

// Out-of-line definitions of the error class identifiers declared above.
char GEPIndexError::ID = 0;
char GEPIndexOutOfBoundError::ID = 0;
char GEPStaticIndexError::ID = 0;
468
469 /// For the given `structIndices` and `indices`, check if they're complied
470 /// with `baseGEPType`, especially check against LLVMStructTypes nested within,
471 /// and refine/promote struct index from `indices` to `updatedStructIndices`
472 /// if the latter argument is not null.
473 static llvm::Error
recordStructIndices(Type baseGEPType,unsigned indexPos,ArrayRef<int32_t> structIndices,ValueRange indices,SmallVectorImpl<int32_t> * updatedStructIndices,SmallVectorImpl<Value> * remainingIndices)474 recordStructIndices(Type baseGEPType, unsigned indexPos,
475 ArrayRef<int32_t> structIndices, ValueRange indices,
476 SmallVectorImpl<int32_t> *updatedStructIndices,
477 SmallVectorImpl<Value> *remainingIndices) {
478 if (indexPos >= structIndices.size())
479 // Stop searching
480 return llvm::Error::success();
481
482 int32_t gepIndex = structIndices[indexPos];
483 bool isStaticIndex = gepIndex != GEPOp::kDynamicIndex;
484
485 unsigned dynamicIndexPos = indexPos;
486 if (!isStaticIndex)
487 dynamicIndexPos = llvm::count(structIndices.take_front(indexPos + 1),
488 LLVM::GEPOp::kDynamicIndex) -
489 1;
490
491 return llvm::TypeSwitch<Type, llvm::Error>(baseGEPType)
492 .Case<LLVMStructType>([&](LLVMStructType structType) -> llvm::Error {
493 // We don't always want to refine the index (e.g. when performing
494 // verification), so we only refine when updatedStructIndices is not
495 // null.
496 if (!isStaticIndex && updatedStructIndices) {
497 // Try to refine.
498 APInt staticIndexValue;
499 isStaticIndex = matchPattern(indices[dynamicIndexPos],
500 m_ConstantInt(&staticIndexValue));
501 if (isStaticIndex) {
502 assert(staticIndexValue.getBitWidth() <= 64 &&
503 llvm::isInt<32>(staticIndexValue.getLimitedValue()) &&
504 "struct index can't fit within int32_t");
505 gepIndex = static_cast<int32_t>(staticIndexValue.getSExtValue());
506 }
507 }
508 if (!isStaticIndex)
509 return llvm::make_error<GEPStaticIndexError>(indexPos);
510
511 ArrayRef<Type> elementTypes = structType.getBody();
512 if (gepIndex < 0 ||
513 static_cast<size_t>(gepIndex) >= elementTypes.size())
514 return llvm::make_error<GEPIndexOutOfBoundError>(indexPos);
515
516 if (updatedStructIndices)
517 (*updatedStructIndices)[indexPos] = gepIndex;
518
519 // Instead of recusively going into every children types, we only
520 // dive into the one indexed by gepIndex.
521 return recordStructIndices(elementTypes[gepIndex], indexPos + 1,
522 structIndices, indices, updatedStructIndices,
523 remainingIndices);
524 })
525 .Case<VectorType, LLVMScalableVectorType, LLVMFixedVectorType,
526 LLVMArrayType>([&](auto containerType) -> llvm::Error {
527 // Currently we don't refine non-struct index even if it's static.
528 if (remainingIndices)
529 remainingIndices->push_back(indices[dynamicIndexPos]);
530 return recordStructIndices(containerType.getElementType(), indexPos + 1,
531 structIndices, indices, updatedStructIndices,
532 remainingIndices);
533 })
534 .Default(
535 [](auto otherType) -> llvm::Error { return llvm::Error::success(); });
536 }
537
538 /// Driver function around `recordStructIndices`. Note that we always check
539 /// from the second GEP index since the first one is always dynamic.
540 static llvm::Error
findStructIndices(Type baseGEPType,ArrayRef<int32_t> structIndices,ValueRange indices,SmallVectorImpl<int32_t> * updatedStructIndices=nullptr,SmallVectorImpl<Value> * remainingIndices=nullptr)541 findStructIndices(Type baseGEPType, ArrayRef<int32_t> structIndices,
542 ValueRange indices,
543 SmallVectorImpl<int32_t> *updatedStructIndices = nullptr,
544 SmallVectorImpl<Value> *remainingIndices = nullptr) {
545 if (remainingIndices)
546 // The first GEP index is always dynamic.
547 remainingIndices->push_back(indices[0]);
548 return recordStructIndices(baseGEPType, /*indexPos=*/1, structIndices,
549 indices, updatedStructIndices, remainingIndices);
550 }
551
build(OpBuilder & builder,OperationState & result,Type resultType,Value basePtr,ValueRange operands,ArrayRef<NamedAttribute> attributes)552 void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
553 Value basePtr, ValueRange operands,
554 ArrayRef<NamedAttribute> attributes) {
555 build(builder, result, resultType, basePtr, operands,
556 SmallVector<int32_t>(operands.size(), kDynamicIndex), attributes);
557 }
558
559 /// Returns the elemental type of any LLVM-compatible vector type or self.
extractVectorElementType(Type type)560 static Type extractVectorElementType(Type type) {
561 if (auto vectorType = type.dyn_cast<VectorType>())
562 return vectorType.getElementType();
563 if (auto scalableVectorType = type.dyn_cast<LLVMScalableVectorType>())
564 return scalableVectorType.getElementType();
565 if (auto fixedVectorType = type.dyn_cast<LLVMFixedVectorType>())
566 return fixedVectorType.getElementType();
567 return type;
568 }
569
build(OpBuilder & builder,OperationState & result,Type resultType,Type elementType,Value basePtr,ValueRange indices,ArrayRef<NamedAttribute> attributes)570 void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
571 Type elementType, Value basePtr, ValueRange indices,
572 ArrayRef<NamedAttribute> attributes) {
573 build(builder, result, resultType, elementType, basePtr, indices,
574 SmallVector<int32_t>(indices.size(), kDynamicIndex), attributes);
575 }
576
build(OpBuilder & builder,OperationState & result,Type resultType,Value basePtr,ValueRange indices,ArrayRef<int32_t> structIndices,ArrayRef<NamedAttribute> attributes)577 void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
578 Value basePtr, ValueRange indices,
579 ArrayRef<int32_t> structIndices,
580 ArrayRef<NamedAttribute> attributes) {
581 auto ptrType =
582 extractVectorElementType(basePtr.getType()).cast<LLVMPointerType>();
583 assert(!ptrType.isOpaque() &&
584 "expected non-opaque pointer, provide elementType explicitly when "
585 "opaque pointers are used");
586 build(builder, result, resultType, ptrType.getElementType(), basePtr, indices,
587 structIndices, attributes);
588 }
589
build(OpBuilder & builder,OperationState & result,Type resultType,Type elementType,Value basePtr,ValueRange indices,ArrayRef<int32_t> structIndices,ArrayRef<NamedAttribute> attributes)590 void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
591 Type elementType, Value basePtr, ValueRange indices,
592 ArrayRef<int32_t> structIndices,
593 ArrayRef<NamedAttribute> attributes) {
594 SmallVector<Value> remainingIndices;
595 SmallVector<int32_t> updatedStructIndices(structIndices.begin(),
596 structIndices.end());
597 if (llvm::Error err =
598 findStructIndices(elementType, structIndices, indices,
599 &updatedStructIndices, &remainingIndices))
600 llvm::report_fatal_error(StringRef(llvm::toString(std::move(err))));
601
602 assert(remainingIndices.size() == static_cast<size_t>(llvm::count(
603 updatedStructIndices, kDynamicIndex)) &&
604 "expected as many index operands as dynamic index attr elements");
605
606 result.addTypes(resultType);
607 result.addAttributes(attributes);
608 result.addAttribute("structIndices",
609 builder.getI32TensorAttr(updatedStructIndices));
610 if (extractVectorElementType(basePtr.getType())
611 .cast<LLVMPointerType>()
612 .isOpaque())
613 result.addAttribute(kElemTypeAttrName, TypeAttr::get(elementType));
614 result.addOperands(basePtr);
615 result.addOperands(remainingIndices);
616 }
617
618 static ParseResult
parseGEPIndices(OpAsmParser & parser,SmallVectorImpl<OpAsmParser::UnresolvedOperand> & indices,DenseIntElementsAttr & structIndices)619 parseGEPIndices(OpAsmParser &parser,
620 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &indices,
621 DenseIntElementsAttr &structIndices) {
622 SmallVector<int32_t> constantIndices;
623
624 auto idxParser = [&]() -> ParseResult {
625 int32_t constantIndex;
626 OptionalParseResult parsedInteger =
627 parser.parseOptionalInteger(constantIndex);
628 if (parsedInteger.hasValue()) {
629 if (failed(parsedInteger.getValue()))
630 return failure();
631 constantIndices.push_back(constantIndex);
632 return success();
633 }
634
635 constantIndices.push_back(LLVM::GEPOp::kDynamicIndex);
636 return parser.parseOperand(indices.emplace_back());
637 };
638 if (parser.parseCommaSeparatedList(idxParser))
639 return failure();
640
641 structIndices = parser.getBuilder().getI32TensorAttr(constantIndices);
642 return success();
643 }
644
printGEPIndices(OpAsmPrinter & printer,LLVM::GEPOp gepOp,OperandRange indices,DenseIntElementsAttr structIndices)645 static void printGEPIndices(OpAsmPrinter &printer, LLVM::GEPOp gepOp,
646 OperandRange indices,
647 DenseIntElementsAttr structIndices) {
648 unsigned operandIdx = 0;
649 llvm::interleaveComma(structIndices.getValues<int32_t>(), printer,
650 [&](int32_t cst) {
651 if (cst == LLVM::GEPOp::kDynamicIndex)
652 printer.printOperand(indices[operandIdx++]);
653 else
654 printer << cst;
655 });
656 }
657
verify()658 LogicalResult LLVM::GEPOp::verify() {
659 if (failed(verifyOpaquePtr(
660 getOperation(),
661 extractVectorElementType(getType()).cast<LLVMPointerType>(),
662 getElemType())))
663 return failure();
664
665 auto structIndexRange = getStructIndices().getValues<int32_t>();
666 // structIndexRange is a kind of iterator, which cannot be converted
667 // to ArrayRef directly.
668 SmallVector<int32_t> structIndices(structIndexRange.size());
669 for (unsigned i : llvm::seq<unsigned>(0, structIndexRange.size()))
670 structIndices[i] = structIndexRange[i];
671 if (llvm::Error err = findStructIndices(getSourceElementType(), structIndices,
672 getIndices()))
673 return emitOpError() << llvm::toString(std::move(err));
674
675 return success();
676 }
677
getSourceElementType()678 Type LLVM::GEPOp::getSourceElementType() {
679 if (Optional<Type> elemType = getElemType())
680 return *elemType;
681
682 return extractVectorElementType(getBase().getType())
683 .cast<LLVMPointerType>()
684 .getElementType();
685 }
686
687 //===----------------------------------------------------------------------===//
// Metadata verification helpers, and builder, printer and parser for LLVM::LoadOp.
689 //===----------------------------------------------------------------------===//
690
verifySymbolAttribute(Operation * op,StringRef attributeName,llvm::function_ref<LogicalResult (Operation *,SymbolRefAttr)> verifySymbolType)691 LogicalResult verifySymbolAttribute(
692 Operation *op, StringRef attributeName,
693 llvm::function_ref<LogicalResult(Operation *, SymbolRefAttr)>
694 verifySymbolType) {
695 if (Attribute attribute = op->getAttr(attributeName)) {
696 // The attribute is already verified to be a symbol ref array attribute via
697 // a constraint in the operation definition.
698 for (SymbolRefAttr symbolRef :
699 attribute.cast<ArrayAttr>().getAsRange<SymbolRefAttr>()) {
700 StringAttr metadataName = symbolRef.getRootReference();
701 StringAttr symbolName = symbolRef.getLeafReference();
702 // We want @metadata::@symbol, not just @symbol
703 if (metadataName == symbolName) {
704 return op->emitOpError() << "expected '" << symbolRef
705 << "' to specify a fully qualified reference";
706 }
707 auto metadataOp = SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>(
708 op->getParentOp(), metadataName);
709 if (!metadataOp)
710 return op->emitOpError()
711 << "expected '" << symbolRef << "' to reference a metadata op";
712 Operation *symbolOp =
713 SymbolTable::lookupNearestSymbolFrom(metadataOp, symbolName);
714 if (!symbolOp)
715 return op->emitOpError()
716 << "expected '" << symbolRef << "' to be a valid reference";
717 if (failed(verifySymbolType(symbolOp, symbolRef))) {
718 return failure();
719 }
720 }
721 }
722 return success();
723 }
724
725 // Verifies that metadata ops are wired up properly.
726 template <typename OpTy>
verifyOpMetadata(Operation * op,StringRef attributeName)727 static LogicalResult verifyOpMetadata(Operation *op, StringRef attributeName) {
728 auto verifySymbolType = [op](Operation *symbolOp,
729 SymbolRefAttr symbolRef) -> LogicalResult {
730 if (!isa<OpTy>(symbolOp)) {
731 return op->emitOpError()
732 << "expected '" << symbolRef << "' to resolve to a "
733 << OpTy::getOperationName();
734 }
735 return success();
736 };
737
738 return verifySymbolAttribute(op, attributeName, verifySymbolType);
739 }
740
verifyMemoryOpMetadata(Operation * op)741 static LogicalResult verifyMemoryOpMetadata(Operation *op) {
742 // access_groups
743 if (failed(verifyOpMetadata<LLVM::AccessGroupMetadataOp>(
744 op, LLVMDialect::getAccessGroupsAttrName())))
745 return failure();
746
747 // alias_scopes
748 if (failed(verifyOpMetadata<LLVM::AliasScopeMetadataOp>(
749 op, LLVMDialect::getAliasScopesAttrName())))
750 return failure();
751
752 // noalias_scopes
753 if (failed(verifyOpMetadata<LLVM::AliasScopeMetadataOp>(
754 op, LLVMDialect::getNoAliasScopesAttrName())))
755 return failure();
756
757 return success();
758 }
759
verify()760 LogicalResult LoadOp::verify() { return verifyMemoryOpMetadata(*this); }
761
build(OpBuilder & builder,OperationState & result,Type t,Value addr,unsigned alignment,bool isVolatile,bool isNonTemporal)762 void LoadOp::build(OpBuilder &builder, OperationState &result, Type t,
763 Value addr, unsigned alignment, bool isVolatile,
764 bool isNonTemporal) {
765 result.addOperands(addr);
766 result.addTypes(t);
767 if (isVolatile)
768 result.addAttribute(kVolatileAttrName, builder.getUnitAttr());
769 if (isNonTemporal)
770 result.addAttribute(kNonTemporalAttrName, builder.getUnitAttr());
771 if (alignment != 0)
772 result.addAttribute("alignment", builder.getI64IntegerAttr(alignment));
773 }
774
print(OpAsmPrinter & p)775 void LoadOp::print(OpAsmPrinter &p) {
776 p << ' ';
777 if (getVolatile_())
778 p << "volatile ";
779 p << getAddr();
780 p.printOptionalAttrDict((*this)->getAttrs(),
781 {kVolatileAttrName, kElemTypeAttrName});
782 p << " : " << getAddr().getType();
783 if (getAddr().getType().cast<LLVMPointerType>().isOpaque())
784 p << " -> " << getType();
785 }
786
787 // Extract the pointee type from the LLVM pointer type wrapped in MLIR. Return
788 // the resulting type if any, null type if opaque pointers are used, and None
789 // if the given type is not the pointer type.
getLoadStoreElementType(OpAsmParser & parser,Type type,SMLoc trailingTypeLoc)790 static Optional<Type> getLoadStoreElementType(OpAsmParser &parser, Type type,
791 SMLoc trailingTypeLoc) {
792 auto llvmTy = type.dyn_cast<LLVM::LLVMPointerType>();
793 if (!llvmTy) {
794 parser.emitError(trailingTypeLoc, "expected LLVM pointer type");
795 return llvm::None;
796 }
797 return llvmTy.getElementType();
798 }
799
800 // <operation> ::= `llvm.load` `volatile` ssa-use attribute-dict? `:` type
801 // (`->` type)?
parse(OpAsmParser & parser,OperationState & result)802 ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) {
803 OpAsmParser::UnresolvedOperand addr;
804 Type type;
805 SMLoc trailingTypeLoc;
806
807 if (succeeded(parser.parseOptionalKeyword("volatile")))
808 result.addAttribute(kVolatileAttrName, parser.getBuilder().getUnitAttr());
809
810 if (parser.parseOperand(addr) ||
811 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
812 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) ||
813 parser.resolveOperand(addr, type, result.operands))
814 return failure();
815
816 Optional<Type> elemTy =
817 getLoadStoreElementType(parser, type, trailingTypeLoc);
818 if (!elemTy)
819 return failure();
820 if (*elemTy) {
821 result.addTypes(*elemTy);
822 return success();
823 }
824
825 Type trailingType;
826 if (parser.parseArrow() || parser.parseType(trailingType))
827 return failure();
828 result.addTypes(trailingType);
829 return success();
830 }
831
832 //===----------------------------------------------------------------------===//
833 // Builder, printer and parser for LLVM::StoreOp.
834 //===----------------------------------------------------------------------===//
835
verify()836 LogicalResult StoreOp::verify() { return verifyMemoryOpMetadata(*this); }
837
build(OpBuilder & builder,OperationState & result,Value value,Value addr,unsigned alignment,bool isVolatile,bool isNonTemporal)838 void StoreOp::build(OpBuilder &builder, OperationState &result, Value value,
839 Value addr, unsigned alignment, bool isVolatile,
840 bool isNonTemporal) {
841 result.addOperands({value, addr});
842 result.addTypes({});
843 if (isVolatile)
844 result.addAttribute(kVolatileAttrName, builder.getUnitAttr());
845 if (isNonTemporal)
846 result.addAttribute(kNonTemporalAttrName, builder.getUnitAttr());
847 if (alignment != 0)
848 result.addAttribute("alignment", builder.getI64IntegerAttr(alignment));
849 }
850
print(OpAsmPrinter & p)851 void StoreOp::print(OpAsmPrinter &p) {
852 p << ' ';
853 if (getVolatile_())
854 p << "volatile ";
855 p << getValue() << ", " << getAddr();
856 p.printOptionalAttrDict((*this)->getAttrs(), {kVolatileAttrName});
857 p << " : ";
858 if (getAddr().getType().cast<LLVMPointerType>().isOpaque())
859 p << getValue().getType() << ", ";
860 p << getAddr().getType();
861 }
862
863 // <operation> ::= `llvm.store` `volatile` ssa-use `,` ssa-use
864 // attribute-dict? `:` type (`,` type)?
parse(OpAsmParser & parser,OperationState & result)865 ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) {
866 OpAsmParser::UnresolvedOperand addr, value;
867 Type type;
868 SMLoc trailingTypeLoc;
869
870 if (succeeded(parser.parseOptionalKeyword("volatile")))
871 result.addAttribute(kVolatileAttrName, parser.getBuilder().getUnitAttr());
872
873 if (parser.parseOperand(value) || parser.parseComma() ||
874 parser.parseOperand(addr) ||
875 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
876 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type))
877 return failure();
878
879 Type operandType;
880 if (succeeded(parser.parseOptionalComma())) {
881 operandType = type;
882 if (parser.parseType(type))
883 return failure();
884 } else {
885 Optional<Type> maybeOperandType =
886 getLoadStoreElementType(parser, type, trailingTypeLoc);
887 if (!maybeOperandType)
888 return failure();
889 operandType = *maybeOperandType;
890 }
891
892 if (parser.resolveOperand(value, operandType, result.operands) ||
893 parser.resolveOperand(addr, type, result.operands))
894 return failure();
895
896 return success();
897 }
898
899 ///===---------------------------------------------------------------------===//
900 /// LLVM::InvokeOp
901 ///===---------------------------------------------------------------------===//
902
getSuccessorOperands(unsigned index)903 SuccessorOperands InvokeOp::getSuccessorOperands(unsigned index) {
904 assert(index < getNumSuccessors() && "invalid successor index");
905 return SuccessorOperands(index == 0 ? getNormalDestOperandsMutable()
906 : getUnwindDestOperandsMutable());
907 }
908
getCallableForCallee()909 CallInterfaceCallable InvokeOp::getCallableForCallee() {
910 // Direct call.
911 if (FlatSymbolRefAttr calleeAttr = getCalleeAttr())
912 return calleeAttr;
913 // Indirect call, callee Value is the first operand.
914 return getOperand(0);
915 }
916
getArgOperands()917 Operation::operand_range InvokeOp::getArgOperands() {
918 return getOperands().drop_front(getCallee().has_value() ? 0 : 1);
919 }
920
verify()921 LogicalResult InvokeOp::verify() {
922 if (getNumResults() > 1)
923 return emitOpError("must have 0 or 1 result");
924
925 Block *unwindDest = getUnwindDest();
926 if (unwindDest->empty())
927 return emitError("must have at least one operation in unwind destination");
928
929 // In unwind destination, first operation must be LandingpadOp
930 if (!isa<LandingpadOp>(unwindDest->front()))
931 return emitError("first operation in unwind destination should be a "
932 "llvm.landingpad operation");
933
934 return success();
935 }
936
print(OpAsmPrinter & p)937 void InvokeOp::print(OpAsmPrinter &p) {
938 auto callee = getCallee();
939 bool isDirect = callee.has_value();
940
941 p << ' ';
942
943 // Either function name or pointer
944 if (isDirect)
945 p.printSymbolName(callee.value());
946 else
947 p << getOperand(0);
948
949 p << '(' << getOperands().drop_front(isDirect ? 0 : 1) << ')';
950 p << " to ";
951 p.printSuccessorAndUseList(getNormalDest(), getNormalDestOperands());
952 p << " unwind ";
953 p.printSuccessorAndUseList(getUnwindDest(), getUnwindDestOperands());
954
955 p.printOptionalAttrDict((*this)->getAttrs(),
956 {InvokeOp::getOperandSegmentSizeAttr(), "callee"});
957 p << " : ";
958 p.printFunctionalType(llvm::drop_begin(getOperandTypes(), isDirect ? 0 : 1),
959 getResultTypes());
960 }
961
962 /// <operation> ::= `llvm.invoke` (function-id | ssa-use) `(` ssa-use-list `)`
963 /// `to` bb-id (`[` ssa-use-and-type-list `]`)?
964 /// `unwind` bb-id (`[` ssa-use-and-type-list `]`)?
965 /// attribute-dict? `:` function-type
parse(OpAsmParser & parser,OperationState & result)966 ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
967 SmallVector<OpAsmParser::UnresolvedOperand, 8> operands;
968 FunctionType funcType;
969 SymbolRefAttr funcAttr;
970 SMLoc trailingTypeLoc;
971 Block *normalDest, *unwindDest;
972 SmallVector<Value, 4> normalOperands, unwindOperands;
973 Builder &builder = parser.getBuilder();
974
975 // Parse an operand list that will, in practice, contain 0 or 1 operand. In
976 // case of an indirect call, there will be 1 operand before `(`. In case of a
977 // direct call, there will be no operands and the parser will stop at the
978 // function identifier without complaining.
979 if (parser.parseOperandList(operands))
980 return failure();
981 bool isDirect = operands.empty();
982
983 // Optionally parse a function identifier.
984 if (isDirect && parser.parseAttribute(funcAttr, "callee", result.attributes))
985 return failure();
986
987 if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
988 parser.parseKeyword("to") ||
989 parser.parseSuccessorAndUseList(normalDest, normalOperands) ||
990 parser.parseKeyword("unwind") ||
991 parser.parseSuccessorAndUseList(unwindDest, unwindOperands) ||
992 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
993 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(funcType))
994 return failure();
995
996 if (isDirect) {
997 // Make sure types match.
998 if (parser.resolveOperands(operands, funcType.getInputs(),
999 parser.getNameLoc(), result.operands))
1000 return failure();
1001 result.addTypes(funcType.getResults());
1002 } else {
1003 // Construct the LLVM IR Dialect function type that the first operand
1004 // should match.
1005 if (funcType.getNumResults() > 1)
1006 return parser.emitError(trailingTypeLoc,
1007 "expected function with 0 or 1 result");
1008
1009 Type llvmResultType;
1010 if (funcType.getNumResults() == 0) {
1011 llvmResultType = LLVM::LLVMVoidType::get(builder.getContext());
1012 } else {
1013 llvmResultType = funcType.getResult(0);
1014 if (!isCompatibleType(llvmResultType))
1015 return parser.emitError(trailingTypeLoc,
1016 "expected result to have LLVM type");
1017 }
1018
1019 SmallVector<Type, 8> argTypes;
1020 argTypes.reserve(funcType.getNumInputs());
1021 for (Type ty : funcType.getInputs()) {
1022 if (isCompatibleType(ty))
1023 argTypes.push_back(ty);
1024 else
1025 return parser.emitError(trailingTypeLoc,
1026 "expected LLVM types as inputs");
1027 }
1028
1029 auto llvmFuncType = LLVM::LLVMFunctionType::get(llvmResultType, argTypes);
1030 auto wrappedFuncType = LLVM::LLVMPointerType::get(llvmFuncType);
1031
1032 auto funcArguments = llvm::makeArrayRef(operands).drop_front();
1033
1034 // Make sure that the first operand (indirect callee) matches the wrapped
1035 // LLVM IR function type, and that the types of the other call operands
1036 // match the types of the function arguments.
1037 if (parser.resolveOperand(operands[0], wrappedFuncType, result.operands) ||
1038 parser.resolveOperands(funcArguments, funcType.getInputs(),
1039 parser.getNameLoc(), result.operands))
1040 return failure();
1041
1042 result.addTypes(llvmResultType);
1043 }
1044 result.addSuccessors({normalDest, unwindDest});
1045 result.addOperands(normalOperands);
1046 result.addOperands(unwindOperands);
1047
1048 result.addAttribute(
1049 InvokeOp::getOperandSegmentSizeAttr(),
1050 builder.getI32VectorAttr({static_cast<int32_t>(operands.size()),
1051 static_cast<int32_t>(normalOperands.size()),
1052 static_cast<int32_t>(unwindOperands.size())}));
1053 return success();
1054 }
1055
1056 ///===----------------------------------------------------------------------===//
1057 /// Verifying/Printing/Parsing for LLVM::LandingpadOp.
1058 ///===----------------------------------------------------------------------===//
1059
verify()1060 LogicalResult LandingpadOp::verify() {
1061 Value value;
1062 if (LLVMFuncOp func = (*this)->getParentOfType<LLVMFuncOp>()) {
1063 if (!func.getPersonality())
1064 return emitError(
1065 "llvm.landingpad needs to be in a function with a personality");
1066 }
1067
1068 if (!getCleanup() && getOperands().empty())
1069 return emitError("landingpad instruction expects at least one clause or "
1070 "cleanup attribute");
1071
1072 for (unsigned idx = 0, ie = getNumOperands(); idx < ie; idx++) {
1073 value = getOperand(idx);
1074 bool isFilter = value.getType().isa<LLVMArrayType>();
1075 if (isFilter) {
1076 // FIXME: Verify filter clauses when arrays are appropriately handled
1077 } else {
1078 // catch - global addresses only.
1079 // Bitcast ops should have global addresses as their args.
1080 if (auto bcOp = value.getDefiningOp<BitcastOp>()) {
1081 if (auto addrOp = bcOp.getArg().getDefiningOp<AddressOfOp>())
1082 continue;
1083 return emitError("constant clauses expected").attachNote(bcOp.getLoc())
1084 << "global addresses expected as operand to "
1085 "bitcast used in clauses for landingpad";
1086 }
1087 // NullOp and AddressOfOp allowed
1088 if (value.getDefiningOp<NullOp>())
1089 continue;
1090 if (value.getDefiningOp<AddressOfOp>())
1091 continue;
1092 return emitError("clause #")
1093 << idx << " is not a known constant - null, addressof, bitcast";
1094 }
1095 }
1096 return success();
1097 }
1098
print(OpAsmPrinter & p)1099 void LandingpadOp::print(OpAsmPrinter &p) {
1100 p << (getCleanup() ? " cleanup " : " ");
1101
1102 // Clauses
1103 for (auto value : getOperands()) {
1104 // Similar to llvm - if clause is an array type then it is filter
1105 // clause else catch clause
1106 bool isArrayTy = value.getType().isa<LLVMArrayType>();
1107 p << '(' << (isArrayTy ? "filter " : "catch ") << value << " : "
1108 << value.getType() << ") ";
1109 }
1110
1111 p.printOptionalAttrDict((*this)->getAttrs(), {"cleanup"});
1112
1113 p << ": " << getType();
1114 }
1115
1116 /// <operation> ::= `llvm.landingpad` `cleanup`?
1117 /// ((`catch` | `filter`) operand-type ssa-use)* attribute-dict?
parse(OpAsmParser & parser,OperationState & result)1118 ParseResult LandingpadOp::parse(OpAsmParser &parser, OperationState &result) {
1119 // Check for cleanup
1120 if (succeeded(parser.parseOptionalKeyword("cleanup")))
1121 result.addAttribute("cleanup", parser.getBuilder().getUnitAttr());
1122
1123 // Parse clauses with types
1124 while (succeeded(parser.parseOptionalLParen()) &&
1125 (succeeded(parser.parseOptionalKeyword("filter")) ||
1126 succeeded(parser.parseOptionalKeyword("catch")))) {
1127 OpAsmParser::UnresolvedOperand operand;
1128 Type ty;
1129 if (parser.parseOperand(operand) || parser.parseColon() ||
1130 parser.parseType(ty) ||
1131 parser.resolveOperand(operand, ty, result.operands) ||
1132 parser.parseRParen())
1133 return failure();
1134 }
1135
1136 Type type;
1137 if (parser.parseColon() || parser.parseType(type))
1138 return failure();
1139
1140 result.addTypes(type);
1141 return success();
1142 }
1143
1144 //===----------------------------------------------------------------------===//
1145 // Verifying/Printing/parsing for LLVM::CallOp.
1146 //===----------------------------------------------------------------------===//
1147
getCallableForCallee()1148 CallInterfaceCallable CallOp::getCallableForCallee() {
1149 // Direct call.
1150 if (FlatSymbolRefAttr calleeAttr = getCalleeAttr())
1151 return calleeAttr;
1152 // Indirect call, callee Value is the first operand.
1153 return getOperand(0);
1154 }
1155
getArgOperands()1156 Operation::operand_range CallOp::getArgOperands() {
1157 return getOperands().drop_front(getCallee().has_value() ? 0 : 1);
1158 }
1159
verify()1160 LogicalResult CallOp::verify() {
1161 if (getNumResults() > 1)
1162 return emitOpError("must have 0 or 1 result");
1163
1164 // Type for the callee, we'll get it differently depending if it is a direct
1165 // or indirect call.
1166 Type fnType;
1167
1168 bool isIndirect = false;
1169
1170 // If this is an indirect call, the callee attribute is missing.
1171 FlatSymbolRefAttr calleeName = getCalleeAttr();
1172 if (!calleeName) {
1173 isIndirect = true;
1174 if (!getNumOperands())
1175 return emitOpError(
1176 "must have either a `callee` attribute or at least an operand");
1177 auto ptrType = getOperand(0).getType().dyn_cast<LLVMPointerType>();
1178 if (!ptrType)
1179 return emitOpError("indirect call expects a pointer as callee: ")
1180 << ptrType;
1181 fnType = ptrType.getElementType();
1182 } else {
1183 Operation *callee =
1184 SymbolTable::lookupNearestSymbolFrom(*this, calleeName.getAttr());
1185 if (!callee)
1186 return emitOpError()
1187 << "'" << calleeName.getValue()
1188 << "' does not reference a symbol in the current scope";
1189 auto fn = dyn_cast<LLVMFuncOp>(callee);
1190 if (!fn)
1191 return emitOpError() << "'" << calleeName.getValue()
1192 << "' does not reference a valid LLVM function";
1193
1194 fnType = fn.getFunctionType();
1195 }
1196
1197 LLVMFunctionType funcType = fnType.dyn_cast<LLVMFunctionType>();
1198 if (!funcType)
1199 return emitOpError("callee does not have a functional type: ") << fnType;
1200
1201 // Verify that the operand and result types match the callee.
1202
1203 if (!funcType.isVarArg() &&
1204 funcType.getNumParams() != (getNumOperands() - isIndirect))
1205 return emitOpError() << "incorrect number of operands ("
1206 << (getNumOperands() - isIndirect)
1207 << ") for callee (expecting: "
1208 << funcType.getNumParams() << ")";
1209
1210 if (funcType.getNumParams() > (getNumOperands() - isIndirect))
1211 return emitOpError() << "incorrect number of operands ("
1212 << (getNumOperands() - isIndirect)
1213 << ") for varargs callee (expecting at least: "
1214 << funcType.getNumParams() << ")";
1215
1216 for (unsigned i = 0, e = funcType.getNumParams(); i != e; ++i)
1217 if (getOperand(i + isIndirect).getType() != funcType.getParamType(i))
1218 return emitOpError() << "operand type mismatch for operand " << i << ": "
1219 << getOperand(i + isIndirect).getType()
1220 << " != " << funcType.getParamType(i);
1221
1222 if (getNumResults() == 0 &&
1223 !funcType.getReturnType().isa<LLVM::LLVMVoidType>())
1224 return emitOpError() << "expected function call to produce a value";
1225
1226 if (getNumResults() != 0 &&
1227 funcType.getReturnType().isa<LLVM::LLVMVoidType>())
1228 return emitOpError()
1229 << "calling function with void result must not produce values";
1230
1231 if (getNumResults() > 1)
1232 return emitOpError()
1233 << "expected LLVM function call to produce 0 or 1 result";
1234
1235 if (getNumResults() && getResult(0).getType() != funcType.getReturnType())
1236 return emitOpError() << "result type mismatch: " << getResult(0).getType()
1237 << " != " << funcType.getReturnType();
1238
1239 return success();
1240 }
1241
print(OpAsmPrinter & p)1242 void CallOp::print(OpAsmPrinter &p) {
1243 auto callee = getCallee();
1244 bool isDirect = callee.has_value();
1245
1246 // Print the direct callee if present as a function attribute, or an indirect
1247 // callee (first operand) otherwise.
1248 p << ' ';
1249 if (isDirect)
1250 p.printSymbolName(callee.value());
1251 else
1252 p << getOperand(0);
1253
1254 auto args = getOperands().drop_front(isDirect ? 0 : 1);
1255 p << '(' << args << ')';
1256 p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()), {"callee"});
1257
1258 // Reconstruct the function MLIR function type from operand and result types.
1259 p << " : ";
1260 p.printFunctionalType(args.getTypes(), getResultTypes());
1261 }
1262
1263 // <operation> ::= `llvm.call` (function-id | ssa-use) `(` ssa-use-list `)`
1264 // attribute-dict? `:` function-type
parse(OpAsmParser & parser,OperationState & result)1265 ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
1266 SmallVector<OpAsmParser::UnresolvedOperand, 8> operands;
1267 Type type;
1268 SymbolRefAttr funcAttr;
1269 SMLoc trailingTypeLoc;
1270
1271 // Parse an operand list that will, in practice, contain 0 or 1 operand. In
1272 // case of an indirect call, there will be 1 operand before `(`. In case of a
1273 // direct call, there will be no operands and the parser will stop at the
1274 // function identifier without complaining.
1275 if (parser.parseOperandList(operands))
1276 return failure();
1277 bool isDirect = operands.empty();
1278
1279 // Optionally parse a function identifier.
1280 if (isDirect)
1281 if (parser.parseAttribute(funcAttr, "callee", result.attributes))
1282 return failure();
1283
1284 if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
1285 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
1286 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type))
1287 return failure();
1288
1289 auto funcType = type.dyn_cast<FunctionType>();
1290 if (!funcType)
1291 return parser.emitError(trailingTypeLoc, "expected function type");
1292 if (funcType.getNumResults() > 1)
1293 return parser.emitError(trailingTypeLoc,
1294 "expected function with 0 or 1 result");
1295 if (isDirect) {
1296 // Make sure types match.
1297 if (parser.resolveOperands(operands, funcType.getInputs(),
1298 parser.getNameLoc(), result.operands))
1299 return failure();
1300 if (funcType.getNumResults() != 0 &&
1301 !funcType.getResult(0).isa<LLVM::LLVMVoidType>())
1302 result.addTypes(funcType.getResults());
1303 } else {
1304 Builder &builder = parser.getBuilder();
1305 Type llvmResultType;
1306 if (funcType.getNumResults() == 0) {
1307 llvmResultType = LLVM::LLVMVoidType::get(builder.getContext());
1308 } else {
1309 llvmResultType = funcType.getResult(0);
1310 if (!isCompatibleType(llvmResultType))
1311 return parser.emitError(trailingTypeLoc,
1312 "expected result to have LLVM type");
1313 }
1314
1315 SmallVector<Type, 8> argTypes;
1316 argTypes.reserve(funcType.getNumInputs());
1317 for (int i = 0, e = funcType.getNumInputs(); i < e; ++i) {
1318 auto argType = funcType.getInput(i);
1319 if (!isCompatibleType(argType))
1320 return parser.emitError(trailingTypeLoc,
1321 "expected LLVM types as inputs");
1322 argTypes.push_back(argType);
1323 }
1324 auto llvmFuncType = LLVM::LLVMFunctionType::get(llvmResultType, argTypes);
1325 auto wrappedFuncType = LLVM::LLVMPointerType::get(llvmFuncType);
1326
1327 auto funcArguments =
1328 ArrayRef<OpAsmParser::UnresolvedOperand>(operands).drop_front();
1329
1330 // Make sure that the first operand (indirect callee) matches the wrapped
1331 // LLVM IR function type, and that the types of the other call operands
1332 // match the types of the function arguments.
1333 if (parser.resolveOperand(operands[0], wrappedFuncType, result.operands) ||
1334 parser.resolveOperands(funcArguments, funcType.getInputs(),
1335 parser.getNameLoc(), result.operands))
1336 return failure();
1337
1338 if (!llvmResultType.isa<LLVM::LLVMVoidType>())
1339 result.addTypes(llvmResultType);
1340 }
1341
1342 return success();
1343 }
1344
1345 //===----------------------------------------------------------------------===//
1346 // Printing/parsing for LLVM::ExtractElementOp.
1347 //===----------------------------------------------------------------------===//
1348 // Expects vector to be of wrapped LLVM vector type and position to be of
1349 // wrapped LLVM i32 type.
build(OpBuilder & b,OperationState & result,Value vector,Value position,ArrayRef<NamedAttribute> attrs)1350 void LLVM::ExtractElementOp::build(OpBuilder &b, OperationState &result,
1351 Value vector, Value position,
1352 ArrayRef<NamedAttribute> attrs) {
1353 auto vectorType = vector.getType();
1354 auto llvmType = LLVM::getVectorElementType(vectorType);
1355 build(b, result, llvmType, vector, position);
1356 result.addAttributes(attrs);
1357 }
1358
print(OpAsmPrinter & p)1359 void ExtractElementOp::print(OpAsmPrinter &p) {
1360 p << ' ' << getVector() << "[" << getPosition() << " : "
1361 << getPosition().getType() << "]";
1362 p.printOptionalAttrDict((*this)->getAttrs());
1363 p << " : " << getVector().getType();
1364 }
1365
1366 // <operation> ::= `llvm.extractelement` ssa-use `, ` ssa-use
1367 // attribute-dict? `:` type
parse(OpAsmParser & parser,OperationState & result)1368 ParseResult ExtractElementOp::parse(OpAsmParser &parser,
1369 OperationState &result) {
1370 SMLoc loc;
1371 OpAsmParser::UnresolvedOperand vector, position;
1372 Type type, positionType;
1373 if (parser.getCurrentLocation(&loc) || parser.parseOperand(vector) ||
1374 parser.parseLSquare() || parser.parseOperand(position) ||
1375 parser.parseColonType(positionType) || parser.parseRSquare() ||
1376 parser.parseOptionalAttrDict(result.attributes) ||
1377 parser.parseColonType(type) ||
1378 parser.resolveOperand(vector, type, result.operands) ||
1379 parser.resolveOperand(position, positionType, result.operands))
1380 return failure();
1381 if (!LLVM::isCompatibleVectorType(type))
1382 return parser.emitError(
1383 loc, "expected LLVM dialect-compatible vector type for operand #1");
1384 result.addTypes(LLVM::getVectorElementType(type));
1385 return success();
1386 }
1387
verify()1388 LogicalResult ExtractElementOp::verify() {
1389 Type vectorType = getVector().getType();
1390 if (!LLVM::isCompatibleVectorType(vectorType))
1391 return emitOpError("expected LLVM dialect-compatible vector type for "
1392 "operand #1, got")
1393 << vectorType;
1394 Type valueType = LLVM::getVectorElementType(vectorType);
1395 if (valueType != getRes().getType())
1396 return emitOpError() << "Type mismatch: extracting from " << vectorType
1397 << " should produce " << valueType
1398 << " but this op returns " << getRes().getType();
1399 return success();
1400 }
1401
1402 //===----------------------------------------------------------------------===//
1403 // Printing/parsing for LLVM::ExtractValueOp.
1404 //===----------------------------------------------------------------------===//
1405
print(OpAsmPrinter & p)1406 void ExtractValueOp::print(OpAsmPrinter &p) {
1407 p << ' ' << getContainer() << getPosition();
1408 p.printOptionalAttrDict((*this)->getAttrs(), {"position"});
1409 p << " : " << getContainer().getType();
1410 }
1411
1412 // Extract the type at `position` in the wrapped LLVM IR aggregate type
1413 // `containerType`. Position is an integer array attribute where each value
1414 // is a zero-based position of the element in the aggregate type. Return the
1415 // resulting type wrapped in MLIR, or nullptr on error.
getInsertExtractValueElementType(OpAsmParser & parser,Type containerType,ArrayAttr positionAttr,SMLoc attributeLoc,SMLoc typeLoc)1416 static Type getInsertExtractValueElementType(OpAsmParser &parser,
1417 Type containerType,
1418 ArrayAttr positionAttr,
1419 SMLoc attributeLoc,
1420 SMLoc typeLoc) {
1421 Type llvmType = containerType;
1422 if (!isCompatibleType(containerType))
1423 return parser.emitError(typeLoc, "expected LLVM IR Dialect type"), nullptr;
1424
1425 // Infer the element type from the structure type: iteratively step inside the
1426 // type by taking the element type, indexed by the position attribute for
1427 // structures. Check the position index before accessing, it is supposed to
1428 // be in bounds.
1429 for (Attribute subAttr : positionAttr) {
1430 auto positionElementAttr = subAttr.dyn_cast<IntegerAttr>();
1431 if (!positionElementAttr)
1432 return parser.emitError(attributeLoc,
1433 "expected an array of integer literals"),
1434 nullptr;
1435 int position = positionElementAttr.getInt();
1436 if (auto arrayType = llvmType.dyn_cast<LLVMArrayType>()) {
1437 if (position < 0 ||
1438 static_cast<unsigned>(position) >= arrayType.getNumElements())
1439 return parser.emitError(attributeLoc, "position out of bounds"),
1440 nullptr;
1441 llvmType = arrayType.getElementType();
1442 } else if (auto structType = llvmType.dyn_cast<LLVMStructType>()) {
1443 if (position < 0 ||
1444 static_cast<unsigned>(position) >= structType.getBody().size())
1445 return parser.emitError(attributeLoc, "position out of bounds"),
1446 nullptr;
1447 llvmType = structType.getBody()[position];
1448 } else {
1449 return parser.emitError(typeLoc, "expected LLVM IR structure/array type"),
1450 nullptr;
1451 }
1452 }
1453 return llvmType;
1454 }
1455
1456 // Extract the type at `position` in the wrapped LLVM IR aggregate type
1457 // `containerType`. Returns null on failure.
getInsertExtractValueElementType(Type containerType,ArrayAttr positionAttr,Operation * op)1458 static Type getInsertExtractValueElementType(Type containerType,
1459 ArrayAttr positionAttr,
1460 Operation *op) {
1461 Type llvmType = containerType;
1462 if (!isCompatibleType(containerType)) {
1463 op->emitError("expected LLVM IR Dialect type, got ") << containerType;
1464 return {};
1465 }
1466
1467 // Infer the element type from the structure type: iteratively step inside the
1468 // type by taking the element type, indexed by the position attribute for
1469 // structures. Check the position index before accessing, it is supposed to
1470 // be in bounds.
1471 for (Attribute subAttr : positionAttr) {
1472 auto positionElementAttr = subAttr.dyn_cast<IntegerAttr>();
1473 if (!positionElementAttr) {
1474 op->emitOpError("expected an array of integer literals, got: ")
1475 << subAttr;
1476 return {};
1477 }
1478 int position = positionElementAttr.getInt();
1479 if (auto arrayType = llvmType.dyn_cast<LLVMArrayType>()) {
1480 if (position < 0 ||
1481 static_cast<unsigned>(position) >= arrayType.getNumElements()) {
1482 op->emitOpError("position out of bounds: ") << position;
1483 return {};
1484 }
1485 llvmType = arrayType.getElementType();
1486 } else if (auto structType = llvmType.dyn_cast<LLVMStructType>()) {
1487 if (position < 0 ||
1488 static_cast<unsigned>(position) >= structType.getBody().size()) {
1489 op->emitOpError("position out of bounds") << position;
1490 return {};
1491 }
1492 llvmType = structType.getBody()[position];
1493 } else {
1494 op->emitOpError("expected LLVM IR structure/array type, got: ")
1495 << llvmType;
1496 return {};
1497 }
1498 }
1499 return llvmType;
1500 }
1501
1502 // <operation> ::= `llvm.extractvalue` ssa-use
1503 // `[` integer-literal (`,` integer-literal)* `]`
1504 // attribute-dict? `:` type
parse(OpAsmParser & parser,OperationState & result)1505 ParseResult ExtractValueOp::parse(OpAsmParser &parser, OperationState &result) {
1506 OpAsmParser::UnresolvedOperand container;
1507 Type containerType;
1508 ArrayAttr positionAttr;
1509 SMLoc attributeLoc, trailingTypeLoc;
1510
1511 if (parser.parseOperand(container) ||
1512 parser.getCurrentLocation(&attributeLoc) ||
1513 parser.parseAttribute(positionAttr, "position", result.attributes) ||
1514 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
1515 parser.getCurrentLocation(&trailingTypeLoc) ||
1516 parser.parseType(containerType) ||
1517 parser.resolveOperand(container, containerType, result.operands))
1518 return failure();
1519
1520 auto elementType = getInsertExtractValueElementType(
1521 parser, containerType, positionAttr, attributeLoc, trailingTypeLoc);
1522 if (!elementType)
1523 return failure();
1524
1525 result.addTypes(elementType);
1526 return success();
1527 }
1528
fold(ArrayRef<Attribute> operands)1529 OpFoldResult LLVM::ExtractValueOp::fold(ArrayRef<Attribute> operands) {
1530 auto insertValueOp = getContainer().getDefiningOp<InsertValueOp>();
1531 OpFoldResult result = {};
1532 while (insertValueOp) {
1533 if (getPosition() == insertValueOp.getPosition())
1534 return insertValueOp.getValue();
1535 unsigned min =
1536 std::min(getPosition().size(), insertValueOp.getPosition().size());
1537 // If one is fully prefix of the other, stop propagating back as it will
1538 // miss dependencies. For instance, %3 should not fold to %f0 in the
1539 // following example:
1540 // ```
1541 // %1 = llvm.insertvalue %f0, %0[0, 0] :
1542 // !llvm.array<4 x !llvm.array<4xf32>>
1543 // %2 = llvm.insertvalue %arr, %1[0] :
1544 // !llvm.array<4 x !llvm.array<4xf32>>
1545 // %3 = llvm.extractvalue %2[0, 0] : !llvm.array<4 x !llvm.array<4xf32>>
1546 // ```
1547 if (getPosition().getValue().take_front(min) ==
1548 insertValueOp.getPosition().getValue().take_front(min))
1549 return result;
1550
1551 // If neither a prefix, nor the exact position, we can extract out of the
1552 // value being inserted into. Moreover, we can try again if that operand
1553 // is itself an insertvalue expression.
1554 getContainerMutable().assign(insertValueOp.getContainer());
1555 result = getResult();
1556 insertValueOp = insertValueOp.getContainer().getDefiningOp<InsertValueOp>();
1557 }
1558 return result;
1559 }
1560
verify()1561 LogicalResult ExtractValueOp::verify() {
1562 Type valueType = getInsertExtractValueElementType(getContainer().getType(),
1563 getPositionAttr(), *this);
1564 if (!valueType)
1565 return failure();
1566
1567 if (getRes().getType() != valueType)
1568 return emitOpError() << "Type mismatch: extracting from "
1569 << getContainer().getType() << " should produce "
1570 << valueType << " but this op returns "
1571 << getRes().getType();
1572 return success();
1573 }
1574
1575 //===----------------------------------------------------------------------===//
1576 // Printing/parsing for LLVM::InsertElementOp.
1577 //===----------------------------------------------------------------------===//
1578
print(OpAsmPrinter & p)1579 void InsertElementOp::print(OpAsmPrinter &p) {
1580 p << ' ' << getValue() << ", " << getVector() << "[" << getPosition() << " : "
1581 << getPosition().getType() << "]";
1582 p.printOptionalAttrDict((*this)->getAttrs());
1583 p << " : " << getVector().getType();
1584 }
1585
1586 // <operation> ::= `llvm.insertelement` ssa-use `,` ssa-use `,` ssa-use
1587 // attribute-dict? `:` type
parse(OpAsmParser & parser,OperationState & result)1588 ParseResult InsertElementOp::parse(OpAsmParser &parser,
1589 OperationState &result) {
1590 SMLoc loc;
1591 OpAsmParser::UnresolvedOperand vector, value, position;
1592 Type vectorType, positionType;
1593 if (parser.getCurrentLocation(&loc) || parser.parseOperand(value) ||
1594 parser.parseComma() || parser.parseOperand(vector) ||
1595 parser.parseLSquare() || parser.parseOperand(position) ||
1596 parser.parseColonType(positionType) || parser.parseRSquare() ||
1597 parser.parseOptionalAttrDict(result.attributes) ||
1598 parser.parseColonType(vectorType))
1599 return failure();
1600
1601 if (!LLVM::isCompatibleVectorType(vectorType))
1602 return parser.emitError(
1603 loc, "expected LLVM dialect-compatible vector type for operand #1");
1604 Type valueType = LLVM::getVectorElementType(vectorType);
1605 if (!valueType)
1606 return failure();
1607
1608 if (parser.resolveOperand(vector, vectorType, result.operands) ||
1609 parser.resolveOperand(value, valueType, result.operands) ||
1610 parser.resolveOperand(position, positionType, result.operands))
1611 return failure();
1612
1613 result.addTypes(vectorType);
1614 return success();
1615 }
1616
verify()1617 LogicalResult InsertElementOp::verify() {
1618 Type valueType = LLVM::getVectorElementType(getVector().getType());
1619 if (valueType != getValue().getType())
1620 return emitOpError() << "Type mismatch: cannot insert "
1621 << getValue().getType() << " into "
1622 << getVector().getType();
1623 return success();
1624 }
1625
1626 //===----------------------------------------------------------------------===//
1627 // Printing/parsing for LLVM::InsertValueOp.
1628 //===----------------------------------------------------------------------===//
1629
print(OpAsmPrinter & p)1630 void InsertValueOp::print(OpAsmPrinter &p) {
1631 p << ' ' << getValue() << ", " << getContainer() << getPosition();
1632 p.printOptionalAttrDict((*this)->getAttrs(), {"position"});
1633 p << " : " << getContainer().getType();
1634 }
1635
1636 // <operation> ::= `llvm.insertvaluevalue` ssa-use `,` ssa-use
1637 // `[` integer-literal (`,` integer-literal)* `]`
1638 // attribute-dict? `:` type
parse(OpAsmParser & parser,OperationState & result)1639 ParseResult InsertValueOp::parse(OpAsmParser &parser, OperationState &result) {
1640 OpAsmParser::UnresolvedOperand container, value;
1641 Type containerType;
1642 ArrayAttr positionAttr;
1643 SMLoc attributeLoc, trailingTypeLoc;
1644
1645 if (parser.parseOperand(value) || parser.parseComma() ||
1646 parser.parseOperand(container) ||
1647 parser.getCurrentLocation(&attributeLoc) ||
1648 parser.parseAttribute(positionAttr, "position", result.attributes) ||
1649 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
1650 parser.getCurrentLocation(&trailingTypeLoc) ||
1651 parser.parseType(containerType))
1652 return failure();
1653
1654 auto valueType = getInsertExtractValueElementType(
1655 parser, containerType, positionAttr, attributeLoc, trailingTypeLoc);
1656 if (!valueType)
1657 return failure();
1658
1659 if (parser.resolveOperand(container, containerType, result.operands) ||
1660 parser.resolveOperand(value, valueType, result.operands))
1661 return failure();
1662
1663 result.addTypes(containerType);
1664 return success();
1665 }
1666
verify()1667 LogicalResult InsertValueOp::verify() {
1668 Type valueType = getInsertExtractValueElementType(getContainer().getType(),
1669 getPositionAttr(), *this);
1670 if (!valueType)
1671 return failure();
1672
1673 if (getValue().getType() != valueType)
1674 return emitOpError() << "Type mismatch: cannot insert "
1675 << getValue().getType() << " into "
1676 << getContainer().getType();
1677
1678 return success();
1679 }
1680
1681 //===----------------------------------------------------------------------===//
1682 // Printing, parsing and verification for LLVM::ReturnOp.
1683 //===----------------------------------------------------------------------===//
1684
verify()1685 LogicalResult ReturnOp::verify() {
1686 if (getNumOperands() > 1)
1687 return emitOpError("expected at most 1 operand");
1688
1689 if (auto parent = (*this)->getParentOfType<LLVMFuncOp>()) {
1690 Type expectedType = parent.getFunctionType().getReturnType();
1691 if (expectedType.isa<LLVMVoidType>()) {
1692 if (getNumOperands() == 0)
1693 return success();
1694 InFlightDiagnostic diag = emitOpError("expected no operands");
1695 diag.attachNote(parent->getLoc()) << "when returning from function";
1696 return diag;
1697 }
1698 if (getNumOperands() == 0) {
1699 if (expectedType.isa<LLVMVoidType>())
1700 return success();
1701 InFlightDiagnostic diag = emitOpError("expected 1 operand");
1702 diag.attachNote(parent->getLoc()) << "when returning from function";
1703 return diag;
1704 }
1705 if (expectedType != getOperand(0).getType()) {
1706 InFlightDiagnostic diag = emitOpError("mismatching result types");
1707 diag.attachNote(parent->getLoc()) << "when returning from function";
1708 return diag;
1709 }
1710 }
1711 return success();
1712 }
1713
1714 //===----------------------------------------------------------------------===//
1715 // ResumeOp
1716 //===----------------------------------------------------------------------===//
1717
verify()1718 LogicalResult ResumeOp::verify() {
1719 if (!getValue().getDefiningOp<LandingpadOp>())
1720 return emitOpError("expects landingpad value as operand");
1721 // No check for personality of function - landingpad op verifies it.
1722 return success();
1723 }
1724
1725 //===----------------------------------------------------------------------===//
1726 // Verifier for LLVM::AddressOfOp.
1727 //===----------------------------------------------------------------------===//
1728
1729 template <typename OpTy>
lookupSymbolInModule(Operation * parent,StringRef name)1730 static OpTy lookupSymbolInModule(Operation *parent, StringRef name) {
1731 Operation *module = parent;
1732 while (module && !satisfiesLLVMModule(module))
1733 module = module->getParentOp();
1734 assert(module && "unexpected operation outside of a module");
1735 return dyn_cast_or_null<OpTy>(
1736 mlir::SymbolTable::lookupSymbolIn(module, name));
1737 }
1738
getGlobal()1739 GlobalOp AddressOfOp::getGlobal() {
1740 return lookupSymbolInModule<LLVM::GlobalOp>((*this)->getParentOp(),
1741 getGlobalName());
1742 }
1743
getFunction()1744 LLVMFuncOp AddressOfOp::getFunction() {
1745 return lookupSymbolInModule<LLVM::LLVMFuncOp>((*this)->getParentOp(),
1746 getGlobalName());
1747 }
1748
verify()1749 LogicalResult AddressOfOp::verify() {
1750 auto global = getGlobal();
1751 auto function = getFunction();
1752 if (!global && !function)
1753 return emitOpError(
1754 "must reference a global defined by 'llvm.mlir.global' or 'llvm.func'");
1755
1756 LLVMPointerType type = getType();
1757 if (global && global.getAddrSpace() != type.getAddressSpace())
1758 return emitOpError("pointer address space must match address space of the "
1759 "referenced global");
1760
1761 if (type.isOpaque())
1762 return success();
1763
1764 if (global && type.getElementType() != global.getType())
1765 return emitOpError(
1766 "the type must be a pointer to the type of the referenced global");
1767
1768 if (function && type.getElementType() != function.getFunctionType())
1769 return emitOpError(
1770 "the type must be a pointer to the type of the referenced function");
1771
1772 return success();
1773 }
1774
1775 //===----------------------------------------------------------------------===//
1776 // Builder, printer and verifier for LLVM::GlobalOp.
1777 //===----------------------------------------------------------------------===//
1778
build(OpBuilder & builder,OperationState & result,Type type,bool isConstant,Linkage linkage,StringRef name,Attribute value,uint64_t alignment,unsigned addrSpace,bool dsoLocal,bool threadLocal,ArrayRef<NamedAttribute> attrs)1779 void GlobalOp::build(OpBuilder &builder, OperationState &result, Type type,
1780 bool isConstant, Linkage linkage, StringRef name,
1781 Attribute value, uint64_t alignment, unsigned addrSpace,
1782 bool dsoLocal, bool threadLocal,
1783 ArrayRef<NamedAttribute> attrs) {
1784 result.addAttribute(getSymNameAttrName(result.name),
1785 builder.getStringAttr(name));
1786 result.addAttribute(getGlobalTypeAttrName(result.name), TypeAttr::get(type));
1787 if (isConstant)
1788 result.addAttribute(getConstantAttrName(result.name),
1789 builder.getUnitAttr());
1790 if (value)
1791 result.addAttribute(getValueAttrName(result.name), value);
1792 if (dsoLocal)
1793 result.addAttribute(getDsoLocalAttrName(result.name),
1794 builder.getUnitAttr());
1795 if (threadLocal)
1796 result.addAttribute(getThreadLocal_AttrName(result.name),
1797 builder.getUnitAttr());
1798
1799 // Only add an alignment attribute if the "alignment" input
1800 // is different from 0. The value must also be a power of two, but
1801 // this is tested in GlobalOp::verify, not here.
1802 if (alignment != 0)
1803 result.addAttribute(getAlignmentAttrName(result.name),
1804 builder.getI64IntegerAttr(alignment));
1805
1806 result.addAttribute(getLinkageAttrName(result.name),
1807 LinkageAttr::get(builder.getContext(), linkage));
1808 if (addrSpace != 0)
1809 result.addAttribute(getAddrSpaceAttrName(result.name),
1810 builder.getI32IntegerAttr(addrSpace));
1811 result.attributes.append(attrs.begin(), attrs.end());
1812 result.addRegion();
1813 }
1814
print(OpAsmPrinter & p)1815 void GlobalOp::print(OpAsmPrinter &p) {
1816 p << ' ' << stringifyLinkage(getLinkage()) << ' ';
1817 if (auto unnamedAddr = getUnnamedAddr()) {
1818 StringRef str = stringifyUnnamedAddr(*unnamedAddr);
1819 if (!str.empty())
1820 p << str << ' ';
1821 }
1822 if (getThreadLocal_())
1823 p << "thread_local ";
1824 if (getConstant())
1825 p << "constant ";
1826 p.printSymbolName(getSymName());
1827 p << '(';
1828 if (auto value = getValueOrNull())
1829 p.printAttribute(value);
1830 p << ')';
1831 // Note that the alignment attribute is printed using the
1832 // default syntax here, even though it is an inherent attribute
1833 // (as defined in https://mlir.llvm.org/docs/LangRef/#attributes)
1834 p.printOptionalAttrDict(
1835 (*this)->getAttrs(),
1836 {SymbolTable::getSymbolAttrName(), getGlobalTypeAttrName(),
1837 getConstantAttrName(), getValueAttrName(), getLinkageAttrName(),
1838 getUnnamedAddrAttrName(), getThreadLocal_AttrName()});
1839
1840 // Print the trailing type unless it's a string global.
1841 if (getValueOrNull().dyn_cast_or_null<StringAttr>())
1842 return;
1843 p << " : " << getType();
1844
1845 Region &initializer = getInitializerRegion();
1846 if (!initializer.empty()) {
1847 p << ' ';
1848 p.printRegion(initializer, /*printEntryBlockArgs=*/false);
1849 }
1850 }
1851
1852 // Parses one of the keywords provided in the list `keywords` and returns the
1853 // position of the parsed keyword in the list. If none of the keywords from the
1854 // list is parsed, returns -1.
parseOptionalKeywordAlternative(OpAsmParser & parser,ArrayRef<StringRef> keywords)1855 static int parseOptionalKeywordAlternative(OpAsmParser &parser,
1856 ArrayRef<StringRef> keywords) {
1857 for (const auto &en : llvm::enumerate(keywords)) {
1858 if (succeeded(parser.parseOptionalKeyword(en.value())))
1859 return en.index();
1860 }
1861 return -1;
1862 }
1863
namespace {
/// Per-enum hooks used by parseOptionalLLVMKeyword: a stringifier and the
/// maximum enumerator value. The primary template is empty; only the
/// specializations produced by REGISTER_ENUM_TYPE below provide the hooks.
template <typename Ty>
struct EnumTraits {};

// Specializes EnumTraits<Ty> by forwarding to the functions
// `stringify<Ty>(value)` and `getMaxEnumValFor<Ty>()`, which must be in scope.
#define REGISTER_ENUM_TYPE(Ty)                                                 \
  template <>                                                                  \
  struct EnumTraits<Ty> {                                                      \
    static StringRef stringify(Ty value) { return stringify##Ty(value); }      \
    static unsigned getMaxEnumVal() { return getMaxEnumValFor##Ty(); }         \
  }

// Enums whose keywords are parsed with parseOptionalLLVMKeyword.
REGISTER_ENUM_TYPE(Linkage);
REGISTER_ENUM_TYPE(UnnamedAddr);
REGISTER_ENUM_TYPE(CConv);
} // namespace
1879
1880 /// Parse an enum from the keyword, or default to the provided default value.
1881 /// The return type is the enum type by default, unless overriden with the
1882 /// second template argument.
1883 template <typename EnumTy, typename RetTy = EnumTy>
parseOptionalLLVMKeyword(OpAsmParser & parser,OperationState & result,EnumTy defaultValue)1884 static RetTy parseOptionalLLVMKeyword(OpAsmParser &parser,
1885 OperationState &result,
1886 EnumTy defaultValue) {
1887 SmallVector<StringRef, 10> names;
1888 for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i)
1889 names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i)));
1890
1891 int index = parseOptionalKeywordAlternative(parser, names);
1892 if (index == -1)
1893 return static_cast<RetTy>(defaultValue);
1894 return static_cast<RetTy>(index);
1895 }
1896
1897 // operation ::= `llvm.mlir.global` linkage? `constant`? `@` identifier
1898 // `(` attribute? `)` align? attribute-list? (`:` type)? region?
1899 // align ::= `align` `=` UINT64
1900 //
1901 // The type can be omitted for string attributes, in which case it will be
1902 // inferred from the value of the string as [strlen(value) x i8].
parse(OpAsmParser & parser,OperationState & result)1903 ParseResult GlobalOp::parse(OpAsmParser &parser, OperationState &result) {
1904 MLIRContext *ctx = parser.getContext();
1905 // Parse optional linkage, default to External.
1906 result.addAttribute(getLinkageAttrName(result.name),
1907 LLVM::LinkageAttr::get(
1908 ctx, parseOptionalLLVMKeyword<Linkage>(
1909 parser, result, LLVM::Linkage::External)));
1910
1911 if (succeeded(parser.parseOptionalKeyword("thread_local")))
1912 result.addAttribute(getThreadLocal_AttrName(result.name),
1913 parser.getBuilder().getUnitAttr());
1914
1915 // Parse optional UnnamedAddr, default to None.
1916 result.addAttribute(getUnnamedAddrAttrName(result.name),
1917 parser.getBuilder().getI64IntegerAttr(
1918 parseOptionalLLVMKeyword<UnnamedAddr, int64_t>(
1919 parser, result, LLVM::UnnamedAddr::None)));
1920
1921 if (succeeded(parser.parseOptionalKeyword("constant")))
1922 result.addAttribute(getConstantAttrName(result.name),
1923 parser.getBuilder().getUnitAttr());
1924
1925 StringAttr name;
1926 if (parser.parseSymbolName(name, getSymNameAttrName(result.name),
1927 result.attributes) ||
1928 parser.parseLParen())
1929 return failure();
1930
1931 Attribute value;
1932 if (parser.parseOptionalRParen()) {
1933 if (parser.parseAttribute(value, getValueAttrName(result.name),
1934 result.attributes) ||
1935 parser.parseRParen())
1936 return failure();
1937 }
1938
1939 SmallVector<Type, 1> types;
1940 if (parser.parseOptionalAttrDict(result.attributes) ||
1941 parser.parseOptionalColonTypeList(types))
1942 return failure();
1943
1944 if (types.size() > 1)
1945 return parser.emitError(parser.getNameLoc(), "expected zero or one type");
1946
1947 Region &initRegion = *result.addRegion();
1948 if (types.empty()) {
1949 if (auto strAttr = value.dyn_cast_or_null<StringAttr>()) {
1950 MLIRContext *context = parser.getContext();
1951 auto arrayType = LLVM::LLVMArrayType::get(IntegerType::get(context, 8),
1952 strAttr.getValue().size());
1953 types.push_back(arrayType);
1954 } else {
1955 return parser.emitError(parser.getNameLoc(),
1956 "type can only be omitted for string globals");
1957 }
1958 } else {
1959 OptionalParseResult parseResult =
1960 parser.parseOptionalRegion(initRegion, /*arguments=*/{},
1961 /*argTypes=*/{});
1962 if (parseResult.hasValue() && failed(*parseResult))
1963 return failure();
1964 }
1965
1966 result.addAttribute(getGlobalTypeAttrName(result.name),
1967 TypeAttr::get(types[0]));
1968 return success();
1969 }
1970
isZeroAttribute(Attribute value)1971 static bool isZeroAttribute(Attribute value) {
1972 if (auto intValue = value.dyn_cast<IntegerAttr>())
1973 return intValue.getValue().isNullValue();
1974 if (auto fpValue = value.dyn_cast<FloatAttr>())
1975 return fpValue.getValue().isZero();
1976 if (auto splatValue = value.dyn_cast<SplatElementsAttr>())
1977 return isZeroAttribute(splatValue.getSplatValue<Attribute>());
1978 if (auto elementsValue = value.dyn_cast<ElementsAttr>())
1979 return llvm::all_of(elementsValue.getValues<Attribute>(), isZeroAttribute);
1980 if (auto arrayValue = value.dyn_cast<ArrayAttr>())
1981 return llvm::all_of(arrayValue.getValue(), isZeroAttribute);
1982 return false;
1983 }
1984
verify()1985 LogicalResult GlobalOp::verify() {
1986 if (!LLVMPointerType::isValidElementType(getType()))
1987 return emitOpError(
1988 "expects type to be a valid element type for an LLVM pointer");
1989 if ((*this)->getParentOp() && !satisfiesLLVMModule((*this)->getParentOp()))
1990 return emitOpError("must appear at the module level");
1991
1992 if (auto strAttr = getValueOrNull().dyn_cast_or_null<StringAttr>()) {
1993 auto type = getType().dyn_cast<LLVMArrayType>();
1994 IntegerType elementType =
1995 type ? type.getElementType().dyn_cast<IntegerType>() : nullptr;
1996 if (!elementType || elementType.getWidth() != 8 ||
1997 type.getNumElements() != strAttr.getValue().size())
1998 return emitOpError(
1999 "requires an i8 array type of the length equal to that of the string "
2000 "attribute");
2001 }
2002
2003 if (getLinkage() == Linkage::Common) {
2004 if (Attribute value = getValueOrNull()) {
2005 if (!isZeroAttribute(value)) {
2006 return emitOpError()
2007 << "expected zero value for '"
2008 << stringifyLinkage(Linkage::Common) << "' linkage";
2009 }
2010 }
2011 }
2012
2013 if (getLinkage() == Linkage::Appending) {
2014 if (!getType().isa<LLVMArrayType>()) {
2015 return emitOpError() << "expected array type for '"
2016 << stringifyLinkage(Linkage::Appending)
2017 << "' linkage";
2018 }
2019 }
2020
2021 Optional<uint64_t> alignAttr = getAlignment();
2022 if (alignAttr.has_value()) {
2023 uint64_t value = alignAttr.value();
2024 if (!llvm::isPowerOf2_64(value))
2025 return emitError() << "alignment attribute is not a power of 2";
2026 }
2027
2028 return success();
2029 }
2030
verifyRegions()2031 LogicalResult GlobalOp::verifyRegions() {
2032 if (Block *b = getInitializerBlock()) {
2033 ReturnOp ret = cast<ReturnOp>(b->getTerminator());
2034 if (ret.operand_type_begin() == ret.operand_type_end())
2035 return emitOpError("initializer region cannot return void");
2036 if (*ret.operand_type_begin() != getType())
2037 return emitOpError("initializer region type ")
2038 << *ret.operand_type_begin() << " does not match global type "
2039 << getType();
2040
2041 for (Operation &op : *b) {
2042 auto iface = dyn_cast<MemoryEffectOpInterface>(op);
2043 if (!iface || !iface.hasNoEffect())
2044 return op.emitError()
2045 << "ops with side effects not allowed in global initializers";
2046 }
2047
2048 if (getValueOrNull())
2049 return emitOpError("cannot have both initializer value and region");
2050 }
2051
2052 return success();
2053 }
2054
2055 //===----------------------------------------------------------------------===//
2056 // LLVM::GlobalCtorsOp
2057 //===----------------------------------------------------------------------===//
2058
2059 LogicalResult
verifySymbolUses(SymbolTableCollection & symbolTable)2060 GlobalCtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2061 for (Attribute ctor : getCtors()) {
2062 if (failed(verifySymbolAttrUse(ctor.cast<FlatSymbolRefAttr>(), *this,
2063 symbolTable)))
2064 return failure();
2065 }
2066 return success();
2067 }
2068
verify()2069 LogicalResult GlobalCtorsOp::verify() {
2070 if (getCtors().size() != getPriorities().size())
2071 return emitError(
2072 "mismatch between the number of ctors and the number of priorities");
2073 return success();
2074 }
2075
2076 //===----------------------------------------------------------------------===//
2077 // LLVM::GlobalDtorsOp
2078 //===----------------------------------------------------------------------===//
2079
2080 LogicalResult
verifySymbolUses(SymbolTableCollection & symbolTable)2081 GlobalDtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2082 for (Attribute dtor : getDtors()) {
2083 if (failed(verifySymbolAttrUse(dtor.cast<FlatSymbolRefAttr>(), *this,
2084 symbolTable)))
2085 return failure();
2086 }
2087 return success();
2088 }
2089
verify()2090 LogicalResult GlobalDtorsOp::verify() {
2091 if (getDtors().size() != getPriorities().size())
2092 return emitError(
2093 "mismatch between the number of dtors and the number of priorities");
2094 return success();
2095 }
2096
2097 //===----------------------------------------------------------------------===//
2098 // Printing/parsing for LLVM::ShuffleVectorOp.
2099 //===----------------------------------------------------------------------===//
2100 // Expects vector to be of wrapped LLVM vector type and position to be of
2101 // wrapped LLVM i32 type.
build(OpBuilder & b,OperationState & result,Value v1,Value v2,ArrayAttr mask,ArrayRef<NamedAttribute> attrs)2102 void LLVM::ShuffleVectorOp::build(OpBuilder &b, OperationState &result,
2103 Value v1, Value v2, ArrayAttr mask,
2104 ArrayRef<NamedAttribute> attrs) {
2105 auto containerType = v1.getType();
2106 auto vType = LLVM::getVectorType(LLVM::getVectorElementType(containerType),
2107 mask.size(),
2108 LLVM::isScalableVectorType(containerType));
2109 build(b, result, vType, v1, v2, mask);
2110 result.addAttributes(attrs);
2111 }
2112
print(OpAsmPrinter & p)2113 void ShuffleVectorOp::print(OpAsmPrinter &p) {
2114 p << ' ' << getV1() << ", " << getV2() << " " << getMask();
2115 p.printOptionalAttrDict((*this)->getAttrs(), {"mask"});
2116 p << " : " << getV1().getType() << ", " << getV2().getType();
2117 }
2118
2119 // <operation> ::= `llvm.shufflevector` ssa-use `, ` ssa-use
2120 // `[` integer-literal (`,` integer-literal)* `]`
2121 // attribute-dict? `:` type
parse(OpAsmParser & parser,OperationState & result)2122 ParseResult ShuffleVectorOp::parse(OpAsmParser &parser,
2123 OperationState &result) {
2124 SMLoc loc;
2125 OpAsmParser::UnresolvedOperand v1, v2;
2126 ArrayAttr maskAttr;
2127 Type typeV1, typeV2;
2128 if (parser.getCurrentLocation(&loc) || parser.parseOperand(v1) ||
2129 parser.parseComma() || parser.parseOperand(v2) ||
2130 parser.parseAttribute(maskAttr, "mask", result.attributes) ||
2131 parser.parseOptionalAttrDict(result.attributes) ||
2132 parser.parseColonType(typeV1) || parser.parseComma() ||
2133 parser.parseType(typeV2) ||
2134 parser.resolveOperand(v1, typeV1, result.operands) ||
2135 parser.resolveOperand(v2, typeV2, result.operands))
2136 return failure();
2137 if (!LLVM::isCompatibleVectorType(typeV1))
2138 return parser.emitError(
2139 loc, "expected LLVM IR dialect vector type for operand #1");
2140 auto vType =
2141 LLVM::getVectorType(LLVM::getVectorElementType(typeV1), maskAttr.size(),
2142 LLVM::isScalableVectorType(typeV1));
2143 result.addTypes(vType);
2144 return success();
2145 }
2146
verify()2147 LogicalResult ShuffleVectorOp::verify() {
2148 Type type1 = getV1().getType();
2149 Type type2 = getV2().getType();
2150 if (LLVM::getVectorElementType(type1) != LLVM::getVectorElementType(type2))
2151 return emitOpError("expected matching LLVM IR Dialect element types");
2152 if (LLVM::isScalableVectorType(type1))
2153 if (llvm::any_of(getMask(), [](Attribute attr) {
2154 return attr.cast<IntegerAttr>().getInt() != 0;
2155 }))
2156 return emitOpError("expected a splat operation for scalable vectors");
2157 return success();
2158 }
2159
2160 //===----------------------------------------------------------------------===//
2161 // Implementations for LLVM::LLVMFuncOp.
2162 //===----------------------------------------------------------------------===//
2163
2164 // Add the entry block to the function.
addEntryBlock()2165 Block *LLVMFuncOp::addEntryBlock() {
2166 assert(empty() && "function already has an entry block");
2167
2168 auto *entry = new Block;
2169 push_back(entry);
2170
2171 // FIXME: Allow passing in proper locations for the entry arguments.
2172 LLVMFunctionType type = getFunctionType();
2173 for (unsigned i = 0, e = type.getNumParams(); i < e; ++i)
2174 entry->addArgument(type.getParamType(i), getLoc());
2175 return entry;
2176 }
2177
build(OpBuilder & builder,OperationState & result,StringRef name,Type type,LLVM::Linkage linkage,bool dsoLocal,CConv cconv,ArrayRef<NamedAttribute> attrs,ArrayRef<DictionaryAttr> argAttrs)2178 void LLVMFuncOp::build(OpBuilder &builder, OperationState &result,
2179 StringRef name, Type type, LLVM::Linkage linkage,
2180 bool dsoLocal, CConv cconv,
2181 ArrayRef<NamedAttribute> attrs,
2182 ArrayRef<DictionaryAttr> argAttrs) {
2183 result.addRegion();
2184 result.addAttribute(SymbolTable::getSymbolAttrName(),
2185 builder.getStringAttr(name));
2186 result.addAttribute(getFunctionTypeAttrName(result.name),
2187 TypeAttr::get(type));
2188 result.addAttribute(getLinkageAttrName(result.name),
2189 LinkageAttr::get(builder.getContext(), linkage));
2190 result.addAttribute(getCConvAttrName(result.name),
2191 CConvAttr::get(builder.getContext(), cconv));
2192 result.attributes.append(attrs.begin(), attrs.end());
2193 if (dsoLocal)
2194 result.addAttribute("dso_local", builder.getUnitAttr());
2195 if (argAttrs.empty())
2196 return;
2197
2198 assert(type.cast<LLVMFunctionType>().getNumParams() == argAttrs.size() &&
2199 "expected as many argument attribute lists as arguments");
2200 function_interface_impl::addArgAndResultAttrs(builder, result, argAttrs,
2201 /*resultAttrs=*/llvm::None);
2202 }
2203
2204 // Builds an LLVM function type from the given lists of input and output types.
2205 // Returns a null type if any of the types provided are non-LLVM types, or if
2206 // there is more than one output type.
2207 static Type
buildLLVMFunctionType(OpAsmParser & parser,SMLoc loc,ArrayRef<Type> inputs,ArrayRef<Type> outputs,function_interface_impl::VariadicFlag variadicFlag)2208 buildLLVMFunctionType(OpAsmParser &parser, SMLoc loc, ArrayRef<Type> inputs,
2209 ArrayRef<Type> outputs,
2210 function_interface_impl::VariadicFlag variadicFlag) {
2211 Builder &b = parser.getBuilder();
2212 if (outputs.size() > 1) {
2213 parser.emitError(loc, "failed to construct function type: expected zero or "
2214 "one function result");
2215 return {};
2216 }
2217
2218 // Convert inputs to LLVM types, exit early on error.
2219 SmallVector<Type, 4> llvmInputs;
2220 for (auto t : inputs) {
2221 if (!isCompatibleType(t)) {
2222 parser.emitError(loc, "failed to construct function type: expected LLVM "
2223 "type for function arguments");
2224 return {};
2225 }
2226 llvmInputs.push_back(t);
2227 }
2228
2229 // No output is denoted as "void" in LLVM type system.
2230 Type llvmOutput =
2231 outputs.empty() ? LLVMVoidType::get(b.getContext()) : outputs.front();
2232 if (!isCompatibleType(llvmOutput)) {
2233 parser.emitError(loc, "failed to construct function type: expected LLVM "
2234 "type for function results")
2235 << llvmOutput;
2236 return {};
2237 }
2238 return LLVMFunctionType::get(llvmOutput, llvmInputs,
2239 variadicFlag.isVariadic());
2240 }
2241
2242 // Parses an LLVM function.
2243 //
2244 // operation ::= `llvm.func` linkage? cconv? function-signature
2245 // function-attributes?
2246 // function-body
2247 //
parse(OpAsmParser & parser,OperationState & result)2248 ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) {
2249 // Default to external linkage if no keyword is provided.
2250 result.addAttribute(
2251 getLinkageAttrName(result.name),
2252 LinkageAttr::get(parser.getContext(),
2253 parseOptionalLLVMKeyword<Linkage>(
2254 parser, result, LLVM::Linkage::External)));
2255
2256 // Default to C Calling Convention if no keyword is provided.
2257 result.addAttribute(
2258 getCConvAttrName(result.name),
2259 CConvAttr::get(parser.getContext(), parseOptionalLLVMKeyword<CConv>(
2260 parser, result, LLVM::CConv::C)));
2261
2262 StringAttr nameAttr;
2263 SmallVector<OpAsmParser::Argument> entryArgs;
2264 SmallVector<DictionaryAttr> resultAttrs;
2265 SmallVector<Type> resultTypes;
2266 bool isVariadic;
2267
2268 auto signatureLocation = parser.getCurrentLocation();
2269 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
2270 result.attributes) ||
2271 function_interface_impl::parseFunctionSignature(
2272 parser, /*allowVariadic=*/true, entryArgs, isVariadic, resultTypes,
2273 resultAttrs))
2274 return failure();
2275
2276 SmallVector<Type> argTypes;
2277 for (auto &arg : entryArgs)
2278 argTypes.push_back(arg.type);
2279 auto type =
2280 buildLLVMFunctionType(parser, signatureLocation, argTypes, resultTypes,
2281 function_interface_impl::VariadicFlag(isVariadic));
2282 if (!type)
2283 return failure();
2284 result.addAttribute(FunctionOpInterface::getTypeAttrName(),
2285 TypeAttr::get(type));
2286
2287 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
2288 return failure();
2289 function_interface_impl::addArgAndResultAttrs(parser.getBuilder(), result,
2290 entryArgs, resultAttrs);
2291
2292 auto *body = result.addRegion();
2293 OptionalParseResult parseResult =
2294 parser.parseOptionalRegion(*body, entryArgs);
2295 return failure(parseResult.hasValue() && failed(*parseResult));
2296 }
2297
2298 // Print the LLVMFuncOp. Collects argument and result types and passes them to
2299 // helper functions. Drops "void" result since it cannot be parsed back. Skips
2300 // the external linkage since it is the default value.
print(OpAsmPrinter & p)2301 void LLVMFuncOp::print(OpAsmPrinter &p) {
2302 p << ' ';
2303 if (getLinkage() != LLVM::Linkage::External)
2304 p << stringifyLinkage(getLinkage()) << ' ';
2305 if (getCConv() != LLVM::CConv::C)
2306 p << stringifyCConv(getCConv()) << ' ';
2307
2308 p.printSymbolName(getName());
2309
2310 LLVMFunctionType fnType = getFunctionType();
2311 SmallVector<Type, 8> argTypes;
2312 SmallVector<Type, 1> resTypes;
2313 argTypes.reserve(fnType.getNumParams());
2314 for (unsigned i = 0, e = fnType.getNumParams(); i < e; ++i)
2315 argTypes.push_back(fnType.getParamType(i));
2316
2317 Type returnType = fnType.getReturnType();
2318 if (!returnType.isa<LLVMVoidType>())
2319 resTypes.push_back(returnType);
2320
2321 function_interface_impl::printFunctionSignature(p, *this, argTypes,
2322 isVarArg(), resTypes);
2323 function_interface_impl::printFunctionAttributes(
2324 p, *this, argTypes.size(), resTypes.size(),
2325 {getLinkageAttrName(), getCConvAttrName()});
2326
2327 // Print the body if this is not an external function.
2328 Region &body = getBody();
2329 if (!body.empty()) {
2330 p << ' ';
2331 p.printRegion(body, /*printEntryBlockArgs=*/false,
2332 /*printBlockTerminators=*/true);
2333 }
2334 }
2335
2336 // Verifies LLVM- and implementation-specific properties of the LLVM func Op:
2337 // - functions don't have 'common' linkage
2338 // - external functions have 'external' or 'extern_weak' linkage;
2339 // - vararg is (currently) only supported for external functions;
verify()2340 LogicalResult LLVMFuncOp::verify() {
2341 if (getLinkage() == LLVM::Linkage::Common)
2342 return emitOpError() << "functions cannot have '"
2343 << stringifyLinkage(LLVM::Linkage::Common)
2344 << "' linkage";
2345
2346 // Check to see if this function has a void return with a result attribute to
2347 // it. It isn't clear what semantics we would assign to that.
2348 if (getFunctionType().getReturnType().isa<LLVMVoidType>() &&
2349 !getResultAttrs(0).empty()) {
2350 return emitOpError()
2351 << "cannot attach result attributes to functions with a void return";
2352 }
2353
2354 if (isExternal()) {
2355 if (getLinkage() != LLVM::Linkage::External &&
2356 getLinkage() != LLVM::Linkage::ExternWeak)
2357 return emitOpError() << "external functions must have '"
2358 << stringifyLinkage(LLVM::Linkage::External)
2359 << "' or '"
2360 << stringifyLinkage(LLVM::Linkage::ExternWeak)
2361 << "' linkage";
2362 return success();
2363 }
2364
2365 return success();
2366 }
2367
2368 /// Verifies LLVM- and implementation-specific properties of the LLVM func Op:
2369 /// - entry block arguments are of LLVM types.
verifyRegions()2370 LogicalResult LLVMFuncOp::verifyRegions() {
2371 if (isExternal())
2372 return success();
2373
2374 unsigned numArguments = getFunctionType().getNumParams();
2375 Block &entryBlock = front();
2376 for (unsigned i = 0; i < numArguments; ++i) {
2377 Type argType = entryBlock.getArgument(i).getType();
2378 if (!isCompatibleType(argType))
2379 return emitOpError("entry block argument #")
2380 << i << " is not of LLVM type";
2381 }
2382
2383 return success();
2384 }
2385
2386 //===----------------------------------------------------------------------===//
2387 // Verification for LLVM::ConstantOp.
2388 //===----------------------------------------------------------------------===//
2389
verify()2390 LogicalResult LLVM::ConstantOp::verify() {
2391 if (StringAttr sAttr = getValue().dyn_cast<StringAttr>()) {
2392 auto arrayType = getType().dyn_cast<LLVMArrayType>();
2393 if (!arrayType || arrayType.getNumElements() != sAttr.getValue().size() ||
2394 !arrayType.getElementType().isInteger(8)) {
2395 return emitOpError() << "expected array type of "
2396 << sAttr.getValue().size()
2397 << " i8 elements for the string constant";
2398 }
2399 return success();
2400 }
2401 if (auto structType = getType().dyn_cast<LLVMStructType>()) {
2402 if (structType.getBody().size() != 2 ||
2403 structType.getBody()[0] != structType.getBody()[1]) {
2404 return emitError() << "expected struct type with two elements of the "
2405 "same type, the type of a complex constant";
2406 }
2407
2408 auto arrayAttr = getValue().dyn_cast<ArrayAttr>();
2409 if (!arrayAttr || arrayAttr.size() != 2 ||
2410 arrayAttr[0].getType() != arrayAttr[1].getType()) {
2411 return emitOpError() << "expected array attribute with two elements, "
2412 "representing a complex constant";
2413 }
2414
2415 Type elementType = structType.getBody()[0];
2416 if (!elementType
2417 .isa<IntegerType, Float16Type, Float32Type, Float64Type>()) {
2418 return emitError()
2419 << "expected struct element types to be floating point type or "
2420 "integer type";
2421 }
2422 return success();
2423 }
2424 if (!getValue().isa<IntegerAttr, ArrayAttr, FloatAttr, ElementsAttr>())
2425 return emitOpError()
2426 << "only supports integer, float, string or elements attributes";
2427 return success();
2428 }
2429
2430 // Constant op constant-folds to its value.
fold(ArrayRef<Attribute>)2431 OpFoldResult LLVM::ConstantOp::fold(ArrayRef<Attribute>) { return getValue(); }
2432
2433 //===----------------------------------------------------------------------===//
2434 // Utility functions for parsing atomic ops
2435 //===----------------------------------------------------------------------===//
2436
2437 // Helper function to parse a keyword into the specified attribute named by
2438 // `attrName`. The keyword must match one of the string values defined by the
2439 // AtomicBinOp enum. The resulting I64 attribute is added to the `result`
2440 // state.
parseAtomicBinOp(OpAsmParser & parser,OperationState & result,StringRef attrName)2441 static ParseResult parseAtomicBinOp(OpAsmParser &parser, OperationState &result,
2442 StringRef attrName) {
2443 SMLoc loc;
2444 StringRef keyword;
2445 if (parser.getCurrentLocation(&loc) || parser.parseKeyword(&keyword))
2446 return failure();
2447
2448 // Replace the keyword `keyword` with an integer attribute.
2449 auto kind = symbolizeAtomicBinOp(keyword);
2450 if (!kind) {
2451 return parser.emitError(loc)
2452 << "'" << keyword << "' is an incorrect value of the '" << attrName
2453 << "' attribute";
2454 }
2455
2456 auto value = static_cast<int64_t>(*kind);
2457 auto attr = parser.getBuilder().getI64IntegerAttr(value);
2458 result.addAttribute(attrName, attr);
2459
2460 return success();
2461 }
2462
2463 // Helper function to parse a keyword into the specified attribute named by
2464 // `attrName`. The keyword must match one of the string values defined by the
2465 // AtomicOrdering enum. The resulting I64 attribute is added to the `result`
2466 // state.
parseAtomicOrdering(OpAsmParser & parser,OperationState & result,StringRef attrName)2467 static ParseResult parseAtomicOrdering(OpAsmParser &parser,
2468 OperationState &result,
2469 StringRef attrName) {
2470 SMLoc loc;
2471 StringRef ordering;
2472 if (parser.getCurrentLocation(&loc) || parser.parseKeyword(&ordering))
2473 return failure();
2474
2475 // Replace the keyword `ordering` with an integer attribute.
2476 auto kind = symbolizeAtomicOrdering(ordering);
2477 if (!kind) {
2478 return parser.emitError(loc)
2479 << "'" << ordering << "' is an incorrect value of the '" << attrName
2480 << "' attribute";
2481 }
2482
2483 auto value = static_cast<int64_t>(*kind);
2484 auto attr = parser.getBuilder().getI64IntegerAttr(value);
2485 result.addAttribute(attrName, attr);
2486
2487 return success();
2488 }
2489
2490 //===----------------------------------------------------------------------===//
2491 // Printer, parser and verifier for LLVM::AtomicRMWOp.
2492 //===----------------------------------------------------------------------===//
2493
print(OpAsmPrinter & p)2494 void AtomicRMWOp::print(OpAsmPrinter &p) {
2495 p << ' ' << stringifyAtomicBinOp(getBinOp()) << ' ' << getPtr() << ", "
2496 << getVal() << ' ' << stringifyAtomicOrdering(getOrdering()) << ' ';
2497 p.printOptionalAttrDict((*this)->getAttrs(), {"bin_op", "ordering"});
2498 p << " : " << getRes().getType();
2499 }
2500
2501 // <operation> ::= `llvm.atomicrmw` keyword ssa-use `,` ssa-use keyword
2502 // attribute-dict? `:` type
parse(OpAsmParser & parser,OperationState & result)2503 ParseResult AtomicRMWOp::parse(OpAsmParser &parser, OperationState &result) {
2504 Type type;
2505 OpAsmParser::UnresolvedOperand ptr, val;
2506 if (parseAtomicBinOp(parser, result, "bin_op") || parser.parseOperand(ptr) ||
2507 parser.parseComma() || parser.parseOperand(val) ||
2508 parseAtomicOrdering(parser, result, "ordering") ||
2509 parser.parseOptionalAttrDict(result.attributes) ||
2510 parser.parseColonType(type) ||
2511 parser.resolveOperand(ptr, LLVM::LLVMPointerType::get(type),
2512 result.operands) ||
2513 parser.resolveOperand(val, type, result.operands))
2514 return failure();
2515
2516 result.addTypes(type);
2517 return success();
2518 }
2519
verify()2520 LogicalResult AtomicRMWOp::verify() {
2521 auto ptrType = getPtr().getType().cast<LLVM::LLVMPointerType>();
2522 auto valType = getVal().getType();
2523 if (valType != ptrType.getElementType())
2524 return emitOpError("expected LLVM IR element type for operand #0 to "
2525 "match type for operand #1");
2526 auto resType = getRes().getType();
2527 if (resType != valType)
2528 return emitOpError(
2529 "expected LLVM IR result type to match type for operand #1");
2530 if (getBinOp() == AtomicBinOp::fadd || getBinOp() == AtomicBinOp::fsub) {
2531 if (!mlir::LLVM::isCompatibleFloatingPointType(valType))
2532 return emitOpError("expected LLVM IR floating point type");
2533 } else if (getBinOp() == AtomicBinOp::xchg) {
2534 auto intType = valType.dyn_cast<IntegerType>();
2535 unsigned intBitWidth = intType ? intType.getWidth() : 0;
2536 if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 &&
2537 intBitWidth != 64 && !valType.isa<BFloat16Type>() &&
2538 !valType.isa<Float16Type>() && !valType.isa<Float32Type>() &&
2539 !valType.isa<Float64Type>())
2540 return emitOpError("unexpected LLVM IR type for 'xchg' bin_op");
2541 } else {
2542 auto intType = valType.dyn_cast<IntegerType>();
2543 unsigned intBitWidth = intType ? intType.getWidth() : 0;
2544 if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 &&
2545 intBitWidth != 64)
2546 return emitOpError("expected LLVM IR integer type");
2547 }
2548
2549 if (static_cast<unsigned>(getOrdering()) <
2550 static_cast<unsigned>(AtomicOrdering::monotonic))
2551 return emitOpError() << "expected at least '"
2552 << stringifyAtomicOrdering(AtomicOrdering::monotonic)
2553 << "' ordering";
2554
2555 return success();
2556 }
2557
2558 //===----------------------------------------------------------------------===//
2559 // Printer, parser and verifier for LLVM::AtomicCmpXchgOp.
2560 //===----------------------------------------------------------------------===//
2561
print(OpAsmPrinter & p)2562 void AtomicCmpXchgOp::print(OpAsmPrinter &p) {
2563 p << ' ' << getPtr() << ", " << getCmp() << ", " << getVal() << ' '
2564 << stringifyAtomicOrdering(getSuccessOrdering()) << ' '
2565 << stringifyAtomicOrdering(getFailureOrdering());
2566 p.printOptionalAttrDict((*this)->getAttrs(),
2567 {"success_ordering", "failure_ordering"});
2568 p << " : " << getVal().getType();
2569 }
2570
2571 // <operation> ::= `llvm.cmpxchg` ssa-use `,` ssa-use `,` ssa-use
2572 // keyword keyword attribute-dict? `:` type
parse(OpAsmParser & parser,OperationState & result)2573 ParseResult AtomicCmpXchgOp::parse(OpAsmParser &parser,
2574 OperationState &result) {
2575 auto &builder = parser.getBuilder();
2576 Type type;
2577 OpAsmParser::UnresolvedOperand ptr, cmp, val;
2578 if (parser.parseOperand(ptr) || parser.parseComma() ||
2579 parser.parseOperand(cmp) || parser.parseComma() ||
2580 parser.parseOperand(val) ||
2581 parseAtomicOrdering(parser, result, "success_ordering") ||
2582 parseAtomicOrdering(parser, result, "failure_ordering") ||
2583 parser.parseOptionalAttrDict(result.attributes) ||
2584 parser.parseColonType(type) ||
2585 parser.resolveOperand(ptr, LLVM::LLVMPointerType::get(type),
2586 result.operands) ||
2587 parser.resolveOperand(cmp, type, result.operands) ||
2588 parser.resolveOperand(val, type, result.operands))
2589 return failure();
2590
2591 auto boolType = IntegerType::get(builder.getContext(), 1);
2592 auto resultType =
2593 LLVMStructType::getLiteral(builder.getContext(), {type, boolType});
2594 result.addTypes(resultType);
2595
2596 return success();
2597 }
2598
verify()2599 LogicalResult AtomicCmpXchgOp::verify() {
2600 auto ptrType = getPtr().getType().cast<LLVM::LLVMPointerType>();
2601 if (!ptrType)
2602 return emitOpError("expected LLVM IR pointer type for operand #0");
2603 auto cmpType = getCmp().getType();
2604 auto valType = getVal().getType();
2605 if (cmpType != ptrType.getElementType() || cmpType != valType)
2606 return emitOpError("expected LLVM IR element type for operand #0 to "
2607 "match type for all other operands");
2608 auto intType = valType.dyn_cast<IntegerType>();
2609 unsigned intBitWidth = intType ? intType.getWidth() : 0;
2610 if (!valType.isa<LLVMPointerType>() && intBitWidth != 8 &&
2611 intBitWidth != 16 && intBitWidth != 32 && intBitWidth != 64 &&
2612 !valType.isa<BFloat16Type>() && !valType.isa<Float16Type>() &&
2613 !valType.isa<Float32Type>() && !valType.isa<Float64Type>())
2614 return emitOpError("unexpected LLVM IR type");
2615 if (getSuccessOrdering() < AtomicOrdering::monotonic ||
2616 getFailureOrdering() < AtomicOrdering::monotonic)
2617 return emitOpError("ordering must be at least 'monotonic'");
2618 if (getFailureOrdering() == AtomicOrdering::release ||
2619 getFailureOrdering() == AtomicOrdering::acq_rel)
2620 return emitOpError("failure ordering cannot be 'release' or 'acq_rel'");
2621 return success();
2622 }
2623
2624 //===----------------------------------------------------------------------===//
2625 // Printer, parser and verifier for LLVM::FenceOp.
2626 //===----------------------------------------------------------------------===//
2627
2628 // <operation> ::= `llvm.fence` (`syncscope(`strAttr`)`)? keyword
2629 // attribute-dict?
parse(OpAsmParser & parser,OperationState & result)2630 ParseResult FenceOp::parse(OpAsmParser &parser, OperationState &result) {
2631 StringAttr sScope;
2632 StringRef syncscopeKeyword = "syncscope";
2633 if (!failed(parser.parseOptionalKeyword(syncscopeKeyword))) {
2634 if (parser.parseLParen() ||
2635 parser.parseAttribute(sScope, syncscopeKeyword, result.attributes) ||
2636 parser.parseRParen())
2637 return failure();
2638 } else {
2639 result.addAttribute(syncscopeKeyword,
2640 parser.getBuilder().getStringAttr(""));
2641 }
2642 if (parseAtomicOrdering(parser, result, "ordering") ||
2643 parser.parseOptionalAttrDict(result.attributes))
2644 return failure();
2645 return success();
2646 }
2647
print(OpAsmPrinter & p)2648 void FenceOp::print(OpAsmPrinter &p) {
2649 StringRef syncscopeKeyword = "syncscope";
2650 p << ' ';
2651 if (!(*this)->getAttr(syncscopeKeyword).cast<StringAttr>().getValue().empty())
2652 p << "syncscope(" << (*this)->getAttr(syncscopeKeyword) << ") ";
2653 p << stringifyAtomicOrdering(getOrdering());
2654 }
2655
verify()2656 LogicalResult FenceOp::verify() {
2657 if (getOrdering() == AtomicOrdering::not_atomic ||
2658 getOrdering() == AtomicOrdering::unordered ||
2659 getOrdering() == AtomicOrdering::monotonic)
2660 return emitOpError("can be given only acquire, release, acq_rel, "
2661 "and seq_cst orderings");
2662 return success();
2663 }
2664
2665 //===----------------------------------------------------------------------===//
2666 // Folder for LLVM::BitcastOp
2667 //===----------------------------------------------------------------------===//
2668
fold(ArrayRef<Attribute> operands)2669 OpFoldResult LLVM::BitcastOp::fold(ArrayRef<Attribute> operands) {
2670 // bitcast(x : T0, T0) -> x
2671 if (getArg().getType() == getType())
2672 return getArg();
2673 // bitcast(bitcast(x : T0, T1), T0) -> x
2674 if (auto prev = getArg().getDefiningOp<BitcastOp>())
2675 if (prev.getArg().getType() == getType())
2676 return prev.getArg();
2677 return {};
2678 }
2679
2680 //===----------------------------------------------------------------------===//
2681 // Folder for LLVM::AddrSpaceCastOp
2682 //===----------------------------------------------------------------------===//
2683
fold(ArrayRef<Attribute> operands)2684 OpFoldResult LLVM::AddrSpaceCastOp::fold(ArrayRef<Attribute> operands) {
2685 // addrcast(x : T0, T0) -> x
2686 if (getArg().getType() == getType())
2687 return getArg();
2688 // addrcast(addrcast(x : T0, T1), T0) -> x
2689 if (auto prev = getArg().getDefiningOp<AddrSpaceCastOp>())
2690 if (prev.getArg().getType() == getType())
2691 return prev.getArg();
2692 return {};
2693 }
2694
2695 //===----------------------------------------------------------------------===//
2696 // Folder for LLVM::GEPOp
2697 //===----------------------------------------------------------------------===//
2698
fold(ArrayRef<Attribute> operands)2699 OpFoldResult LLVM::GEPOp::fold(ArrayRef<Attribute> operands) {
2700 // gep %x:T, 0 -> %x
2701 if (getBase().getType() == getType() && getIndices().size() == 1 &&
2702 matchPattern(getIndices()[0], m_Zero()))
2703 return getBase();
2704 return {};
2705 }
2706
2707 //===----------------------------------------------------------------------===//
2708 // LLVMDialect initialization, type parsing, and registration.
2709 //===----------------------------------------------------------------------===//
2710
initialize()2711 void LLVMDialect::initialize() {
2712 addAttributes<FMFAttr, LinkageAttr, CConvAttr, LoopOptionsAttr>();
2713
2714 // clang-format off
2715 addTypes<LLVMVoidType,
2716 LLVMPPCFP128Type,
2717 LLVMX86MMXType,
2718 LLVMTokenType,
2719 LLVMLabelType,
2720 LLVMMetadataType,
2721 LLVMFunctionType,
2722 LLVMPointerType,
2723 LLVMFixedVectorType,
2724 LLVMScalableVectorType,
2725 LLVMArrayType,
2726 LLVMStructType>();
2727 // clang-format on
2728 addOperations<
2729 #define GET_OP_LIST
2730 #include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc"
2731 ,
2732 #define GET_OP_LIST
2733 #include "mlir/Dialect/LLVMIR/LLVMIntrinsicOps.cpp.inc"
2734 >();
2735
2736 // Support unknown operations because not all LLVM operations are registered.
2737 allowUnknownOperations();
2738 }
2739
2740 #define GET_OP_CLASSES
2741 #include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc"
2742
2743 /// Parse a type registered to this dialect.
parseType(DialectAsmParser & parser) const2744 Type LLVMDialect::parseType(DialectAsmParser &parser) const {
2745 return detail::parseType(parser);
2746 }
2747
2748 /// Print a type registered to this dialect.
printType(Type type,DialectAsmPrinter & os) const2749 void LLVMDialect::printType(Type type, DialectAsmPrinter &os) const {
2750 return detail::printType(type, os);
2751 }
2752
verifyDataLayoutString(StringRef descr,llvm::function_ref<void (const Twine &)> reportError)2753 LogicalResult LLVMDialect::verifyDataLayoutString(
2754 StringRef descr, llvm::function_ref<void(const Twine &)> reportError) {
2755 llvm::Expected<llvm::DataLayout> maybeDataLayout =
2756 llvm::DataLayout::parse(descr);
2757 if (maybeDataLayout)
2758 return success();
2759
2760 std::string message;
2761 llvm::raw_string_ostream messageStream(message);
2762 llvm::logAllUnhandledErrors(maybeDataLayout.takeError(), messageStream);
2763 reportError("invalid data layout descriptor: " + messageStream.str());
2764 return failure();
2765 }
2766
2767 /// Verify LLVM dialect attributes.
verifyOperationAttribute(Operation * op,NamedAttribute attr)2768 LogicalResult LLVMDialect::verifyOperationAttribute(Operation *op,
2769 NamedAttribute attr) {
2770 // If the `llvm.loop` attribute is present, enforce the following structure,
2771 // which the module translation can assume.
2772 if (attr.getName() == LLVMDialect::getLoopAttrName()) {
2773 auto loopAttr = attr.getValue().dyn_cast<DictionaryAttr>();
2774 if (!loopAttr)
2775 return op->emitOpError() << "expected '" << LLVMDialect::getLoopAttrName()
2776 << "' to be a dictionary attribute";
2777 Optional<NamedAttribute> parallelAccessGroup =
2778 loopAttr.getNamed(LLVMDialect::getParallelAccessAttrName());
2779 if (parallelAccessGroup) {
2780 auto accessGroups = parallelAccessGroup->getValue().dyn_cast<ArrayAttr>();
2781 if (!accessGroups)
2782 return op->emitOpError()
2783 << "expected '" << LLVMDialect::getParallelAccessAttrName()
2784 << "' to be an array attribute";
2785 for (Attribute attr : accessGroups) {
2786 auto accessGroupRef = attr.dyn_cast<SymbolRefAttr>();
2787 if (!accessGroupRef)
2788 return op->emitOpError()
2789 << "expected '" << attr << "' to be a symbol reference";
2790 StringAttr metadataName = accessGroupRef.getRootReference();
2791 auto metadataOp =
2792 SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>(
2793 op->getParentOp(), metadataName);
2794 if (!metadataOp)
2795 return op->emitOpError()
2796 << "expected '" << attr << "' to reference a metadata op";
2797 StringAttr accessGroupName = accessGroupRef.getLeafReference();
2798 Operation *accessGroupOp =
2799 SymbolTable::lookupNearestSymbolFrom(metadataOp, accessGroupName);
2800 if (!accessGroupOp)
2801 return op->emitOpError()
2802 << "expected '" << attr << "' to reference an access_group op";
2803 }
2804 }
2805
2806 Optional<NamedAttribute> loopOptions =
2807 loopAttr.getNamed(LLVMDialect::getLoopOptionsAttrName());
2808 if (loopOptions && !loopOptions->getValue().isa<LoopOptionsAttr>())
2809 return op->emitOpError()
2810 << "expected '" << LLVMDialect::getLoopOptionsAttrName()
2811 << "' to be a `loopopts` attribute";
2812 }
2813
2814 if (attr.getName() == LLVMDialect::getStructAttrsAttrName()) {
2815 return op->emitOpError()
2816 << "'" << LLVM::LLVMDialect::getStructAttrsAttrName()
2817 << "' is permitted only in argument or result attributes";
2818 }
2819
2820 // If the data layout attribute is present, it must use the LLVM data layout
2821 // syntax. Try parsing it and report errors in case of failure. Users of this
2822 // attribute may assume it is well-formed and can pass it to the (asserting)
2823 // llvm::DataLayout constructor.
2824 if (attr.getName() != LLVM::LLVMDialect::getDataLayoutAttrName())
2825 return success();
2826 if (auto stringAttr = attr.getValue().dyn_cast<StringAttr>())
2827 return verifyDataLayoutString(
2828 stringAttr.getValue(),
2829 [op](const Twine &message) { op->emitOpError() << message.str(); });
2830
2831 return op->emitOpError() << "expected '"
2832 << LLVM::LLVMDialect::getDataLayoutAttrName()
2833 << "' to be a string attributes";
2834 }
2835
verifyStructAttr(Operation * op,Attribute attr,Type annotatedType)2836 LogicalResult LLVMDialect::verifyStructAttr(Operation *op, Attribute attr,
2837 Type annotatedType) {
2838 auto structType = annotatedType.dyn_cast<LLVMStructType>();
2839 if (!structType) {
2840 const auto emitIncorrectAnnotatedType = [&op]() {
2841 return op->emitError()
2842 << "expected '" << LLVMDialect::getStructAttrsAttrName()
2843 << "' to annotate '!llvm.struct' or '!llvm.ptr<struct<...>>'";
2844 };
2845 const auto ptrType = annotatedType.dyn_cast<LLVMPointerType>();
2846 if (!ptrType)
2847 return emitIncorrectAnnotatedType();
2848 structType = ptrType.getElementType().dyn_cast<LLVMStructType>();
2849 if (!structType)
2850 return emitIncorrectAnnotatedType();
2851 }
2852
2853 const auto arrAttrs = attr.dyn_cast<ArrayAttr>();
2854 if (!arrAttrs)
2855 return op->emitError() << "expected '"
2856 << LLVMDialect::getStructAttrsAttrName()
2857 << "' to be an array attribute";
2858
2859 if (structType.getBody().size() != arrAttrs.size())
2860 return op->emitError()
2861 << "size of '" << LLVMDialect::getStructAttrsAttrName()
2862 << "' must match the size of the annotated '!llvm.struct'";
2863 return success();
2864 }
2865
verifyFuncOpInterfaceStructAttr(Operation * op,Attribute attr,const std::function<Type (FunctionOpInterface)> & getAnnotatedType)2866 static LogicalResult verifyFuncOpInterfaceStructAttr(
2867 Operation *op, Attribute attr,
2868 const std::function<Type(FunctionOpInterface)> &getAnnotatedType) {
2869 if (auto funcOp = dyn_cast<FunctionOpInterface>(op))
2870 return LLVMDialect::verifyStructAttr(op, attr, getAnnotatedType(funcOp));
2871 return op->emitError() << "expected '"
2872 << LLVMDialect::getStructAttrsAttrName()
2873 << "' to be used on function-like operations";
2874 }
2875
2876 /// Verify LLVMIR function argument attributes.
verifyRegionArgAttribute(Operation * op,unsigned regionIdx,unsigned argIdx,NamedAttribute argAttr)2877 LogicalResult LLVMDialect::verifyRegionArgAttribute(Operation *op,
2878 unsigned regionIdx,
2879 unsigned argIdx,
2880 NamedAttribute argAttr) {
2881 // Check that llvm.noalias is a unit attribute.
2882 if (argAttr.getName() == LLVMDialect::getNoAliasAttrName() &&
2883 !argAttr.getValue().isa<UnitAttr>())
2884 return op->emitError()
2885 << "expected llvm.noalias argument attribute to be a unit attribute";
2886 // Check that llvm.align is an integer attribute.
2887 if (argAttr.getName() == LLVMDialect::getAlignAttrName() &&
2888 !argAttr.getValue().isa<IntegerAttr>())
2889 return op->emitError()
2890 << "llvm.align argument attribute of non integer type";
2891 if (argAttr.getName() == LLVMDialect::getStructAttrsAttrName()) {
2892 return verifyFuncOpInterfaceStructAttr(
2893 op, argAttr.getValue(), [argIdx](FunctionOpInterface funcOp) {
2894 return funcOp.getArgumentTypes()[argIdx];
2895 });
2896 }
2897 return success();
2898 }
2899
verifyRegionResultAttribute(Operation * op,unsigned regionIdx,unsigned resIdx,NamedAttribute resAttr)2900 LogicalResult LLVMDialect::verifyRegionResultAttribute(Operation *op,
2901 unsigned regionIdx,
2902 unsigned resIdx,
2903 NamedAttribute resAttr) {
2904 if (resAttr.getName() == LLVMDialect::getStructAttrsAttrName()) {
2905 return verifyFuncOpInterfaceStructAttr(
2906 op, resAttr.getValue(), [resIdx](FunctionOpInterface funcOp) {
2907 return funcOp.getResultTypes()[resIdx];
2908 });
2909 }
2910 return success();
2911 }
2912
2913 //===----------------------------------------------------------------------===//
2914 // Utility functions.
2915 //===----------------------------------------------------------------------===//
2916
createGlobalString(Location loc,OpBuilder & builder,StringRef name,StringRef value,LLVM::Linkage linkage)2917 Value mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder,
2918 StringRef name, StringRef value,
2919 LLVM::Linkage linkage) {
2920 assert(builder.getInsertionBlock() &&
2921 builder.getInsertionBlock()->getParentOp() &&
2922 "expected builder to point to a block constrained in an op");
2923 auto module =
2924 builder.getInsertionBlock()->getParentOp()->getParentOfType<ModuleOp>();
2925 assert(module && "builder points to an op outside of a module");
2926
2927 // Create the global at the entry of the module.
2928 OpBuilder moduleBuilder(module.getBodyRegion(), builder.getListener());
2929 MLIRContext *ctx = builder.getContext();
2930 auto type = LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), value.size());
2931 auto global = moduleBuilder.create<LLVM::GlobalOp>(
2932 loc, type, /*isConstant=*/true, linkage, name,
2933 builder.getStringAttr(value), /*alignment=*/0);
2934
2935 // Get the pointer to the first character in the global string.
2936 Value globalPtr = builder.create<LLVM::AddressOfOp>(loc, global);
2937 Value cst0 = builder.create<LLVM::ConstantOp>(
2938 loc, IntegerType::get(ctx, 64),
2939 builder.getIntegerAttr(builder.getIndexType(), 0));
2940 return builder.create<LLVM::GEPOp>(
2941 loc, LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)), globalPtr,
2942 ValueRange{cst0, cst0});
2943 }
2944
satisfiesLLVMModule(Operation * op)2945 bool mlir::LLVM::satisfiesLLVMModule(Operation *op) {
2946 return op->hasTrait<OpTrait::SymbolTable>() &&
2947 op->hasTrait<OpTrait::IsIsolatedFromAbove>();
2948 }
2949
print(AsmPrinter & printer) const2950 void FMFAttr::print(AsmPrinter &printer) const {
2951 printer << "<";
2952 printer << stringifyFastmathFlags(this->getFlags());
2953 printer << ">";
2954 }
2955
parse(AsmParser & parser,Type type)2956 Attribute FMFAttr::parse(AsmParser &parser, Type type) {
2957 if (failed(parser.parseLess()))
2958 return {};
2959
2960 FastmathFlags flags = {};
2961 if (failed(parser.parseOptionalGreater())) {
2962 auto parseFlags = [&]() -> ParseResult {
2963 StringRef elemName;
2964 if (failed(parser.parseKeyword(&elemName)))
2965 return failure();
2966
2967 auto elem = symbolizeFastmathFlags(elemName);
2968 if (!elem)
2969 return parser.emitError(parser.getNameLoc(), "Unknown fastmath flag: ")
2970 << elemName;
2971
2972 flags = flags | *elem;
2973 return success();
2974 };
2975 if (failed(parser.parseCommaSeparatedList(parseFlags)) ||
2976 parser.parseGreater())
2977 return {};
2978 }
2979
2980 return FMFAttr::get(parser.getContext(), flags);
2981 }
2982
print(AsmPrinter & printer) const2983 void LinkageAttr::print(AsmPrinter &printer) const {
2984 printer << "<";
2985 if (static_cast<uint64_t>(getLinkage()) <= getMaxEnumValForLinkage())
2986 printer << stringifyEnum(getLinkage());
2987 else
2988 printer << static_cast<uint64_t>(getLinkage());
2989 printer << ">";
2990 }
2991
parse(AsmParser & parser,Type type)2992 Attribute LinkageAttr::parse(AsmParser &parser, Type type) {
2993 StringRef elemName;
2994 if (parser.parseLess() || parser.parseKeyword(&elemName) ||
2995 parser.parseGreater())
2996 return {};
2997 auto elem = linkage::symbolizeLinkage(elemName);
2998 if (!elem) {
2999 parser.emitError(parser.getNameLoc(), "Unknown linkage: ") << elemName;
3000 return {};
3001 }
3002 Linkage linkage = *elem;
3003 return LinkageAttr::get(parser.getContext(), linkage);
3004 }
3005
print(AsmPrinter & printer) const3006 void CConvAttr::print(AsmPrinter &printer) const {
3007 printer << "<";
3008 if (static_cast<uint64_t>(getCallingConv()) <= cconv::getMaxEnumValForCConv())
3009 printer << stringifyEnum(getCallingConv());
3010 else
3011 printer << "INVALID_cc_" << static_cast<uint64_t>(getCallingConv());
3012 printer << ">";
3013 }
3014
parse(AsmParser & parser,Type type)3015 Attribute CConvAttr::parse(AsmParser &parser, Type type) {
3016 StringRef convName;
3017
3018 if (parser.parseLess() || parser.parseKeyword(&convName) ||
3019 parser.parseGreater())
3020 return {};
3021 auto cconv = cconv::symbolizeCConv(convName);
3022 if (!cconv) {
3023 parser.emitError(parser.getNameLoc(), "unknown calling convention: ")
3024 << convName;
3025 return {};
3026 }
3027 CConv cconvVal = *cconv;
3028 return CConvAttr::get(parser.getContext(), cconvVal);
3029 }
3030
LoopOptionsAttrBuilder(LoopOptionsAttr attr)3031 LoopOptionsAttrBuilder::LoopOptionsAttrBuilder(LoopOptionsAttr attr)
3032 : options(attr.getOptions().begin(), attr.getOptions().end()) {}
3033
3034 template <typename T>
setOption(LoopOptionCase tag,Optional<T> value)3035 LoopOptionsAttrBuilder &LoopOptionsAttrBuilder::setOption(LoopOptionCase tag,
3036 Optional<T> value) {
3037 auto option = llvm::find_if(
3038 options, [tag](auto option) { return option.first == tag; });
3039 if (option != options.end()) {
3040 if (value)
3041 option->second = *value;
3042 else
3043 options.erase(option);
3044 } else {
3045 options.push_back(LoopOptionsAttr::OptionValuePair(tag, *value));
3046 }
3047 return *this;
3048 }
3049
3050 LoopOptionsAttrBuilder &
setDisableLICM(Optional<bool> value)3051 LoopOptionsAttrBuilder::setDisableLICM(Optional<bool> value) {
3052 return setOption(LoopOptionCase::disable_licm, value);
3053 }
3054
3055 /// Set the `interleave_count` option to the provided value. If no value
3056 /// is provided the option is deleted.
3057 LoopOptionsAttrBuilder &
setInterleaveCount(Optional<uint64_t> count)3058 LoopOptionsAttrBuilder::setInterleaveCount(Optional<uint64_t> count) {
3059 return setOption(LoopOptionCase::interleave_count, count);
3060 }
3061
3062 /// Set the `disable_unroll` option to the provided value. If no value
3063 /// is provided the option is deleted.
3064 LoopOptionsAttrBuilder &
setDisableUnroll(Optional<bool> value)3065 LoopOptionsAttrBuilder::setDisableUnroll(Optional<bool> value) {
3066 return setOption(LoopOptionCase::disable_unroll, value);
3067 }
3068
3069 /// Set the `disable_pipeline` option to the provided value. If no value
3070 /// is provided the option is deleted.
3071 LoopOptionsAttrBuilder &
setDisablePipeline(Optional<bool> value)3072 LoopOptionsAttrBuilder::setDisablePipeline(Optional<bool> value) {
3073 return setOption(LoopOptionCase::disable_pipeline, value);
3074 }
3075
3076 /// Set the `pipeline_initiation_interval` option to the provided value.
3077 /// If no value is provided the option is deleted.
setPipelineInitiationInterval(Optional<uint64_t> count)3078 LoopOptionsAttrBuilder &LoopOptionsAttrBuilder::setPipelineInitiationInterval(
3079 Optional<uint64_t> count) {
3080 return setOption(LoopOptionCase::pipeline_initiation_interval, count);
3081 }
3082
3083 template <typename T>
3084 static Optional<T>
getOption(ArrayRef<std::pair<LoopOptionCase,int64_t>> options,LoopOptionCase option)3085 getOption(ArrayRef<std::pair<LoopOptionCase, int64_t>> options,
3086 LoopOptionCase option) {
3087 auto it =
3088 lower_bound(options, option, [](auto optionPair, LoopOptionCase option) {
3089 return optionPair.first < option;
3090 });
3091 if (it == options.end())
3092 return {};
3093 return static_cast<T>(it->second);
3094 }
3095
disableUnroll()3096 Optional<bool> LoopOptionsAttr::disableUnroll() {
3097 return getOption<bool>(getOptions(), LoopOptionCase::disable_unroll);
3098 }
3099
disableLICM()3100 Optional<bool> LoopOptionsAttr::disableLICM() {
3101 return getOption<bool>(getOptions(), LoopOptionCase::disable_licm);
3102 }
3103
interleaveCount()3104 Optional<int64_t> LoopOptionsAttr::interleaveCount() {
3105 return getOption<int64_t>(getOptions(), LoopOptionCase::interleave_count);
3106 }
3107
3108 /// Build the LoopOptions Attribute from a sorted array of individual options.
get(MLIRContext * context,ArrayRef<std::pair<LoopOptionCase,int64_t>> sortedOptions)3109 LoopOptionsAttr LoopOptionsAttr::get(
3110 MLIRContext *context,
3111 ArrayRef<std::pair<LoopOptionCase, int64_t>> sortedOptions) {
3112 assert(llvm::is_sorted(sortedOptions, llvm::less_first()) &&
3113 "LoopOptionsAttr ctor expects a sorted options array");
3114 return Base::get(context, sortedOptions);
3115 }
3116
3117 /// Build the LoopOptions Attribute from a sorted array of individual options.
get(MLIRContext * context,LoopOptionsAttrBuilder & optionBuilders)3118 LoopOptionsAttr LoopOptionsAttr::get(MLIRContext *context,
3119 LoopOptionsAttrBuilder &optionBuilders) {
3120 llvm::sort(optionBuilders.options, llvm::less_first());
3121 return Base::get(context, optionBuilders.options);
3122 }
3123
print(AsmPrinter & printer) const3124 void LoopOptionsAttr::print(AsmPrinter &printer) const {
3125 printer << "<";
3126 llvm::interleaveComma(getOptions(), printer, [&](auto option) {
3127 printer << stringifyEnum(option.first) << " = ";
3128 switch (option.first) {
3129 case LoopOptionCase::disable_licm:
3130 case LoopOptionCase::disable_unroll:
3131 case LoopOptionCase::disable_pipeline:
3132 printer << (option.second ? "true" : "false");
3133 break;
3134 case LoopOptionCase::interleave_count:
3135 case LoopOptionCase::pipeline_initiation_interval:
3136 printer << option.second;
3137 break;
3138 }
3139 });
3140 printer << ">";
3141 }
3142
parse(AsmParser & parser,Type type)3143 Attribute LoopOptionsAttr::parse(AsmParser &parser, Type type) {
3144 if (failed(parser.parseLess()))
3145 return {};
3146
3147 SmallVector<std::pair<LoopOptionCase, int64_t>> options;
3148 llvm::SmallDenseSet<LoopOptionCase> seenOptions;
3149 auto parseLoopOptions = [&]() -> ParseResult {
3150 StringRef optionName;
3151 if (parser.parseKeyword(&optionName))
3152 return failure();
3153
3154 auto option = symbolizeLoopOptionCase(optionName);
3155 if (!option)
3156 return parser.emitError(parser.getNameLoc(), "unknown loop option: ")
3157 << optionName;
3158 if (!seenOptions.insert(*option).second)
3159 return parser.emitError(parser.getNameLoc(), "loop option present twice");
3160 if (failed(parser.parseEqual()))
3161 return failure();
3162
3163 int64_t value;
3164 switch (*option) {
3165 case LoopOptionCase::disable_licm:
3166 case LoopOptionCase::disable_unroll:
3167 case LoopOptionCase::disable_pipeline:
3168 if (succeeded(parser.parseOptionalKeyword("true")))
3169 value = 1;
3170 else if (succeeded(parser.parseOptionalKeyword("false")))
3171 value = 0;
3172 else {
3173 return parser.emitError(parser.getNameLoc(),
3174 "expected boolean value 'true' or 'false'");
3175 }
3176 break;
3177 case LoopOptionCase::interleave_count:
3178 case LoopOptionCase::pipeline_initiation_interval:
3179 if (failed(parser.parseInteger(value)))
3180 return parser.emitError(parser.getNameLoc(), "expected integer value");
3181 break;
3182 }
3183 options.push_back(std::make_pair(*option, value));
3184 return success();
3185 };
3186 if (parser.parseCommaSeparatedList(parseLoopOptions) || parser.parseGreater())
3187 return {};
3188
3189 llvm::sort(options, llvm::less_first());
3190 return get(parser.getContext(), options);
3191 }
3192