1 //===- ModuleTranslation.cpp - MLIR to LLVM conversion --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the translation between an MLIR LLVM dialect module and
10 // the corresponding LLVMIR module. It only handles core LLVM IR operations.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
15
16 #include "DebugTranslation.h"
17 #include "mlir/Dialect/DLTI/DLTI.h"
18 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
19 #include "mlir/Dialect/LLVMIR/Transforms/LegalizeForExport.h"
20 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
21 #include "mlir/IR/Attributes.h"
22 #include "mlir/IR/BuiltinOps.h"
23 #include "mlir/IR/BuiltinTypes.h"
24 #include "mlir/IR/RegionGraphTraits.h"
25 #include "mlir/Support/LLVM.h"
26 #include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
27 #include "mlir/Target/LLVMIR/TypeToLLVM.h"
28 #include "llvm/ADT/TypeSwitch.h"
29
30 #include "llvm/ADT/PostOrderIterator.h"
31 #include "llvm/ADT/SetVector.h"
32 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
33 #include "llvm/IR/BasicBlock.h"
34 #include "llvm/IR/CFG.h"
35 #include "llvm/IR/Constants.h"
36 #include "llvm/IR/DerivedTypes.h"
37 #include "llvm/IR/IRBuilder.h"
38 #include "llvm/IR/InlineAsm.h"
39 #include "llvm/IR/IntrinsicsNVPTX.h"
40 #include "llvm/IR/LLVMContext.h"
41 #include "llvm/IR/MDBuilder.h"
42 #include "llvm/IR/Module.h"
43 #include "llvm/IR/Verifier.h"
44 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
45 #include "llvm/Transforms/Utils/Cloning.h"
46 #include "llvm/Transforms/Utils/ModuleUtils.h"
47
48 using namespace mlir;
49 using namespace mlir::LLVM;
50 using namespace mlir::LLVM::detail;
51
52 #include "mlir/Dialect/LLVMIR/LLVMConversionEnumsToLLVM.inc"
53
54 /// Translates the given data layout spec attribute to the LLVM IR data layout.
55 /// Only integer, float and endianness entries are currently supported.
56 FailureOr<llvm::DataLayout>
translateDataLayout(DataLayoutSpecInterface attribute,const DataLayout & dataLayout,Optional<Location> loc=llvm::None)57 translateDataLayout(DataLayoutSpecInterface attribute,
58 const DataLayout &dataLayout,
59 Optional<Location> loc = llvm::None) {
60 if (!loc)
61 loc = UnknownLoc::get(attribute.getContext());
62
63 // Translate the endianness attribute.
64 std::string llvmDataLayout;
65 llvm::raw_string_ostream layoutStream(llvmDataLayout);
66 for (DataLayoutEntryInterface entry : attribute.getEntries()) {
67 auto key = entry.getKey().dyn_cast<StringAttr>();
68 if (!key)
69 continue;
70 if (key.getValue() == DLTIDialect::kDataLayoutEndiannessKey) {
71 auto value = entry.getValue().cast<StringAttr>();
72 bool isLittleEndian =
73 value.getValue() == DLTIDialect::kDataLayoutEndiannessLittle;
74 layoutStream << (isLittleEndian ? "e" : "E");
75 layoutStream.flush();
76 continue;
77 }
78 emitError(*loc) << "unsupported data layout key " << key;
79 return failure();
80 }
81
82 // Go through the list of entries to check which types are explicitly
83 // specified in entries. Don't use the entries directly though but query the
84 // data from the layout.
85 for (DataLayoutEntryInterface entry : attribute.getEntries()) {
86 auto type = entry.getKey().dyn_cast<Type>();
87 if (!type)
88 continue;
89 // Data layout for the index type is irrelevant at this point.
90 if (type.isa<IndexType>())
91 continue;
92 FailureOr<std::string> prefix =
93 llvm::TypeSwitch<Type, FailureOr<std::string>>(type)
94 .Case<IntegerType>(
95 [loc](IntegerType integerType) -> FailureOr<std::string> {
96 if (integerType.getSignedness() == IntegerType::Signless)
97 return std::string("i");
98 emitError(*loc)
99 << "unsupported data layout for non-signless integer "
100 << integerType;
101 return failure();
102 })
103 .Case<Float16Type, Float32Type, Float64Type, Float80Type,
104 Float128Type>([](Type) { return std::string("f"); })
105 .Default([loc](Type type) -> FailureOr<std::string> {
106 emitError(*loc) << "unsupported type in data layout: " << type;
107 return failure();
108 });
109 if (failed(prefix))
110 return failure();
111
112 unsigned size = dataLayout.getTypeSizeInBits(type);
113 unsigned abi = dataLayout.getTypeABIAlignment(type) * 8u;
114 unsigned preferred = dataLayout.getTypePreferredAlignment(type) * 8u;
115 layoutStream << "-" << *prefix << size << ":" << abi;
116 if (abi != preferred)
117 layoutStream << ":" << preferred;
118 }
119 layoutStream.flush();
120 StringRef layoutSpec(llvmDataLayout);
121 if (layoutSpec.startswith("-"))
122 layoutSpec = layoutSpec.drop_front();
123
124 return llvm::DataLayout(layoutSpec);
125 }
126
127 /// Builds a constant of a sequential LLVM type `type`, potentially containing
128 /// other sequential types recursively, from the individual constant values
129 /// provided in `constants`. `shape` contains the number of elements in nested
130 /// sequential types. Reports errors at `loc` and returns nullptr on error.
131 static llvm::Constant *
buildSequentialConstant(ArrayRef<llvm::Constant * > & constants,ArrayRef<int64_t> shape,llvm::Type * type,Location loc)132 buildSequentialConstant(ArrayRef<llvm::Constant *> &constants,
133 ArrayRef<int64_t> shape, llvm::Type *type,
134 Location loc) {
135 if (shape.empty()) {
136 llvm::Constant *result = constants.front();
137 constants = constants.drop_front();
138 return result;
139 }
140
141 llvm::Type *elementType;
142 if (auto *arrayTy = dyn_cast<llvm::ArrayType>(type)) {
143 elementType = arrayTy->getElementType();
144 } else if (auto *vectorTy = dyn_cast<llvm::VectorType>(type)) {
145 elementType = vectorTy->getElementType();
146 } else {
147 emitError(loc) << "expected sequential LLVM types wrapping a scalar";
148 return nullptr;
149 }
150
151 SmallVector<llvm::Constant *, 8> nested;
152 nested.reserve(shape.front());
153 for (int64_t i = 0; i < shape.front(); ++i) {
154 nested.push_back(buildSequentialConstant(constants, shape.drop_front(),
155 elementType, loc));
156 if (!nested.back())
157 return nullptr;
158 }
159
160 if (shape.size() == 1 && type->isVectorTy())
161 return llvm::ConstantVector::get(nested);
162 return llvm::ConstantArray::get(
163 llvm::ArrayType::get(elementType, shape.front()), nested);
164 }
165
166 /// Returns the first non-sequential type nested in sequential types.
getInnermostElementType(llvm::Type * type)167 static llvm::Type *getInnermostElementType(llvm::Type *type) {
168 do {
169 if (auto *arrayTy = dyn_cast<llvm::ArrayType>(type)) {
170 type = arrayTy->getElementType();
171 } else if (auto *vectorTy = dyn_cast<llvm::VectorType>(type)) {
172 type = vectorTy->getElementType();
173 } else {
174 return type;
175 }
176 } while (true);
177 }
178
179 /// Convert a dense elements attribute to an LLVM IR constant using its raw data
180 /// storage if possible. This supports elements attributes of tensor or vector
181 /// type and avoids constructing separate objects for individual values of the
182 /// innermost dimension. Constants for other dimensions are still constructed
183 /// recursively. Returns null if constructing from raw data is not supported for
184 /// this type, e.g., element type is not a power-of-two-sized primitive. Reports
185 /// other errors at `loc`.
186 static llvm::Constant *
convertDenseElementsAttr(Location loc,DenseElementsAttr denseElementsAttr,llvm::Type * llvmType,const ModuleTranslation & moduleTranslation)187 convertDenseElementsAttr(Location loc, DenseElementsAttr denseElementsAttr,
188 llvm::Type *llvmType,
189 const ModuleTranslation &moduleTranslation) {
190 if (!denseElementsAttr)
191 return nullptr;
192
193 llvm::Type *innermostLLVMType = getInnermostElementType(llvmType);
194 if (!llvm::ConstantDataSequential::isElementTypeCompatible(innermostLLVMType))
195 return nullptr;
196
197 ShapedType type = denseElementsAttr.getType();
198 if (type.getNumElements() == 0)
199 return nullptr;
200
201 // Compute the shape of all dimensions but the innermost. Note that the
202 // innermost dimension may be that of the vector element type.
203 bool hasVectorElementType = type.getElementType().isa<VectorType>();
204 unsigned numAggregates =
205 denseElementsAttr.getNumElements() /
206 (hasVectorElementType ? 1
207 : denseElementsAttr.getType().getShape().back());
208 ArrayRef<int64_t> outerShape = type.getShape();
209 if (!hasVectorElementType)
210 outerShape = outerShape.drop_back();
211
212 // Handle the case of vector splat, LLVM has special support for it.
213 if (denseElementsAttr.isSplat() &&
214 (type.isa<VectorType>() || hasVectorElementType)) {
215 llvm::Constant *splatValue = LLVM::detail::getLLVMConstant(
216 innermostLLVMType, denseElementsAttr.getSplatValue<Attribute>(), loc,
217 moduleTranslation);
218 llvm::Constant *splatVector =
219 llvm::ConstantDataVector::getSplat(0, splatValue);
220 SmallVector<llvm::Constant *> constants(numAggregates, splatVector);
221 ArrayRef<llvm::Constant *> constantsRef = constants;
222 return buildSequentialConstant(constantsRef, outerShape, llvmType, loc);
223 }
224 if (denseElementsAttr.isSplat())
225 return nullptr;
226
227 // In case of non-splat, create a constructor for the innermost constant from
228 // a piece of raw data.
229 std::function<llvm::Constant *(StringRef)> buildCstData;
230 if (type.isa<TensorType>()) {
231 auto vectorElementType = type.getElementType().dyn_cast<VectorType>();
232 if (vectorElementType && vectorElementType.getRank() == 1) {
233 buildCstData = [&](StringRef data) {
234 return llvm::ConstantDataVector::getRaw(
235 data, vectorElementType.getShape().back(), innermostLLVMType);
236 };
237 } else if (!vectorElementType) {
238 buildCstData = [&](StringRef data) {
239 return llvm::ConstantDataArray::getRaw(data, type.getShape().back(),
240 innermostLLVMType);
241 };
242 }
243 } else if (type.isa<VectorType>()) {
244 buildCstData = [&](StringRef data) {
245 return llvm::ConstantDataVector::getRaw(data, type.getShape().back(),
246 innermostLLVMType);
247 };
248 }
249 if (!buildCstData)
250 return nullptr;
251
252 // Create innermost constants and defer to the default constant creation
253 // mechanism for other dimensions.
254 SmallVector<llvm::Constant *> constants;
255 unsigned aggregateSize = denseElementsAttr.getType().getShape().back() *
256 (innermostLLVMType->getScalarSizeInBits() / 8);
257 constants.reserve(numAggregates);
258 for (unsigned i = 0; i < numAggregates; ++i) {
259 StringRef data(denseElementsAttr.getRawData().data() + i * aggregateSize,
260 aggregateSize);
261 constants.push_back(buildCstData(data));
262 }
263
264 ArrayRef<llvm::Constant *> constantsRef = constants;
265 return buildSequentialConstant(constantsRef, outerShape, llvmType, loc);
266 }
267
268 /// Create an LLVM IR constant of `llvmType` from the MLIR attribute `attr`.
269 /// This currently supports integer, floating point, splat and dense element
270 /// attributes and combinations thereof. Also, an array attribute with two
271 /// elements is supported to represent a complex constant. In case of error,
272 /// report it to `loc` and return nullptr.
getLLVMConstant(llvm::Type * llvmType,Attribute attr,Location loc,const ModuleTranslation & moduleTranslation)273 llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
274 llvm::Type *llvmType, Attribute attr, Location loc,
275 const ModuleTranslation &moduleTranslation) {
276 if (!attr)
277 return llvm::UndefValue::get(llvmType);
278 if (auto *structType = dyn_cast<::llvm::StructType>(llvmType)) {
279 auto arrayAttr = attr.dyn_cast<ArrayAttr>();
280 if (!arrayAttr || arrayAttr.size() != 2) {
281 emitError(loc, "expected struct type to be a complex number");
282 return nullptr;
283 }
284 llvm::Type *elementType = structType->getElementType(0);
285 llvm::Constant *real =
286 getLLVMConstant(elementType, arrayAttr[0], loc, moduleTranslation);
287 if (!real)
288 return nullptr;
289 llvm::Constant *imag =
290 getLLVMConstant(elementType, arrayAttr[1], loc, moduleTranslation);
291 if (!imag)
292 return nullptr;
293 return llvm::ConstantStruct::get(structType, {real, imag});
294 }
295 // For integer types, we allow a mismatch in sizes as the index type in
296 // MLIR might have a different size than the index type in the LLVM module.
297 if (auto intAttr = attr.dyn_cast<IntegerAttr>())
298 return llvm::ConstantInt::get(
299 llvmType,
300 intAttr.getValue().sextOrTrunc(llvmType->getIntegerBitWidth()));
301 if (auto floatAttr = attr.dyn_cast<FloatAttr>()) {
302 if (llvmType !=
303 llvm::Type::getFloatingPointTy(llvmType->getContext(),
304 floatAttr.getValue().getSemantics())) {
305 emitError(loc, "FloatAttr does not match expected type of the constant");
306 return nullptr;
307 }
308 return llvm::ConstantFP::get(llvmType, floatAttr.getValue());
309 }
310 if (auto funcAttr = attr.dyn_cast<FlatSymbolRefAttr>())
311 return llvm::ConstantExpr::getBitCast(
312 moduleTranslation.lookupFunction(funcAttr.getValue()), llvmType);
313 if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
314 llvm::Type *elementType;
315 uint64_t numElements;
316 bool isScalable = false;
317 if (auto *arrayTy = dyn_cast<llvm::ArrayType>(llvmType)) {
318 elementType = arrayTy->getElementType();
319 numElements = arrayTy->getNumElements();
320 } else if (auto *fVectorTy = dyn_cast<llvm::FixedVectorType>(llvmType)) {
321 elementType = fVectorTy->getElementType();
322 numElements = fVectorTy->getNumElements();
323 } else if (auto *sVectorTy = dyn_cast<llvm::ScalableVectorType>(llvmType)) {
324 elementType = sVectorTy->getElementType();
325 numElements = sVectorTy->getMinNumElements();
326 isScalable = true;
327 } else {
328 llvm_unreachable("unrecognized constant vector type");
329 }
330 // Splat value is a scalar. Extract it only if the element type is not
331 // another sequence type. The recursion terminates because each step removes
332 // one outer sequential type.
333 bool elementTypeSequential =
334 isa<llvm::ArrayType, llvm::VectorType>(elementType);
335 llvm::Constant *child = getLLVMConstant(
336 elementType,
337 elementTypeSequential ? splatAttr
338 : splatAttr.getSplatValue<Attribute>(),
339 loc, moduleTranslation);
340 if (!child)
341 return nullptr;
342 if (llvmType->isVectorTy())
343 return llvm::ConstantVector::getSplat(
344 llvm::ElementCount::get(numElements, /*Scalable=*/isScalable), child);
345 if (llvmType->isArrayTy()) {
346 auto *arrayType = llvm::ArrayType::get(elementType, numElements);
347 SmallVector<llvm::Constant *, 8> constants(numElements, child);
348 return llvm::ConstantArray::get(arrayType, constants);
349 }
350 }
351
352 // Try using raw elements data if possible.
353 if (llvm::Constant *result =
354 convertDenseElementsAttr(loc, attr.dyn_cast<DenseElementsAttr>(),
355 llvmType, moduleTranslation)) {
356 return result;
357 }
358
359 // Fall back to element-by-element construction otherwise.
360 if (auto elementsAttr = attr.dyn_cast<ElementsAttr>()) {
361 assert(elementsAttr.getType().hasStaticShape());
362 assert(!elementsAttr.getType().getShape().empty() &&
363 "unexpected empty elements attribute shape");
364
365 SmallVector<llvm::Constant *, 8> constants;
366 constants.reserve(elementsAttr.getNumElements());
367 llvm::Type *innermostType = getInnermostElementType(llvmType);
368 for (auto n : elementsAttr.getValues<Attribute>()) {
369 constants.push_back(
370 getLLVMConstant(innermostType, n, loc, moduleTranslation));
371 if (!constants.back())
372 return nullptr;
373 }
374 ArrayRef<llvm::Constant *> constantsRef = constants;
375 llvm::Constant *result = buildSequentialConstant(
376 constantsRef, elementsAttr.getType().getShape(), llvmType, loc);
377 assert(constantsRef.empty() && "did not consume all elemental constants");
378 return result;
379 }
380
381 if (auto stringAttr = attr.dyn_cast<StringAttr>()) {
382 return llvm::ConstantDataArray::get(
383 moduleTranslation.getLLVMContext(),
384 ArrayRef<char>{stringAttr.getValue().data(),
385 stringAttr.getValue().size()});
386 }
387 emitError(loc, "unsupported constant value");
388 return nullptr;
389 }
390
ModuleTranslation(Operation * module,std::unique_ptr<llvm::Module> llvmModule)391 ModuleTranslation::ModuleTranslation(Operation *module,
392 std::unique_ptr<llvm::Module> llvmModule)
393 : mlirModule(module), llvmModule(std::move(llvmModule)),
394 debugTranslation(
395 std::make_unique<DebugTranslation>(module, *this->llvmModule)),
396 typeTranslator(this->llvmModule->getContext()),
397 iface(module->getContext()) {
398 assert(satisfiesLLVMModule(mlirModule) &&
399 "mlirModule should honor LLVM's module semantics.");
400 }
~ModuleTranslation()401 ModuleTranslation::~ModuleTranslation() {
402 if (ompBuilder)
403 ompBuilder->finalize();
404 }
405
forgetMapping(Region & region)406 void ModuleTranslation::forgetMapping(Region ®ion) {
407 SmallVector<Region *> toProcess;
408 toProcess.push_back(®ion);
409 while (!toProcess.empty()) {
410 Region *current = toProcess.pop_back_val();
411 for (Block &block : *current) {
412 blockMapping.erase(&block);
413 for (Value arg : block.getArguments())
414 valueMapping.erase(arg);
415 for (Operation &op : block) {
416 for (Value value : op.getResults())
417 valueMapping.erase(value);
418 if (op.hasSuccessors())
419 branchMapping.erase(&op);
420 if (isa<LLVM::GlobalOp>(op))
421 globalsMapping.erase(&op);
422 accessGroupMetadataMapping.erase(&op);
423 llvm::append_range(
424 toProcess,
425 llvm::map_range(op.getRegions(), [](Region &r) { return &r; }));
426 }
427 }
428 }
429 }
430
431 /// Get the SSA value passed to the current block from the terminator operation
432 /// of its predecessor.
getPHISourceValue(Block * current,Block * pred,unsigned numArguments,unsigned index)433 static Value getPHISourceValue(Block *current, Block *pred,
434 unsigned numArguments, unsigned index) {
435 Operation &terminator = *pred->getTerminator();
436 if (isa<LLVM::BrOp>(terminator))
437 return terminator.getOperand(index);
438
439 #ifndef NDEBUG
440 llvm::SmallPtrSet<Block *, 4> seenSuccessors;
441 for (unsigned i = 0, e = terminator.getNumSuccessors(); i < e; ++i) {
442 Block *successor = terminator.getSuccessor(i);
443 auto branch = cast<BranchOpInterface>(terminator);
444 SuccessorOperands successorOperands = branch.getSuccessorOperands(i);
445 assert(
446 (!seenSuccessors.contains(successor) || successorOperands.empty()) &&
447 "successors with arguments in LLVM branches must be different blocks");
448 seenSuccessors.insert(successor);
449 }
450 #endif
451
452 // For instructions that branch based on a condition value, we need to take
453 // the operands for the branch that was taken.
454 if (auto condBranchOp = dyn_cast<LLVM::CondBrOp>(terminator)) {
455 // For conditional branches, we take the operands from either the "true" or
456 // the "false" branch.
457 return condBranchOp.getSuccessor(0) == current
458 ? condBranchOp.getTrueDestOperands()[index]
459 : condBranchOp.getFalseDestOperands()[index];
460 }
461
462 if (auto switchOp = dyn_cast<LLVM::SwitchOp>(terminator)) {
463 // For switches, we take the operands from either the default case, or from
464 // the case branch that was taken.
465 if (switchOp.getDefaultDestination() == current)
466 return switchOp.getDefaultOperands()[index];
467 for (const auto &i : llvm::enumerate(switchOp.getCaseDestinations()))
468 if (i.value() == current)
469 return switchOp.getCaseOperands(i.index())[index];
470 }
471
472 if (auto invokeOp = dyn_cast<LLVM::InvokeOp>(terminator)) {
473 return invokeOp.getNormalDest() == current
474 ? invokeOp.getNormalDestOperands()[index]
475 : invokeOp.getUnwindDestOperands()[index];
476 }
477
478 llvm_unreachable(
479 "only branch, switch or invoke operations can be terminators "
480 "of a block that has successors");
481 }
482
483 /// Connect the PHI nodes to the results of preceding blocks.
connectPHINodes(Region & region,const ModuleTranslation & state)484 void mlir::LLVM::detail::connectPHINodes(Region ®ion,
485 const ModuleTranslation &state) {
486 // Skip the first block, it cannot be branched to and its arguments correspond
487 // to the arguments of the LLVM function.
488 for (Block &bb : llvm::drop_begin(region)) {
489 llvm::BasicBlock *llvmBB = state.lookupBlock(&bb);
490 auto phis = llvmBB->phis();
491 auto numArguments = bb.getNumArguments();
492 assert(numArguments == std::distance(phis.begin(), phis.end()));
493 for (auto &numberedPhiNode : llvm::enumerate(phis)) {
494 auto &phiNode = numberedPhiNode.value();
495 unsigned index = numberedPhiNode.index();
496 for (auto *pred : bb.getPredecessors()) {
497 // Find the LLVM IR block that contains the converted terminator
498 // instruction and use it in the PHI node. Note that this block is not
499 // necessarily the same as state.lookupBlock(pred), some operations
500 // (in particular, OpenMP operations using OpenMPIRBuilder) may have
501 // split the blocks.
502 llvm::Instruction *terminator =
503 state.lookupBranch(pred->getTerminator());
504 assert(terminator && "missing the mapping for a terminator");
505 phiNode.addIncoming(state.lookupValue(getPHISourceValue(
506 &bb, pred, numArguments, index)),
507 terminator->getParent());
508 }
509 }
510 }
511 }
512
513 /// Sort function blocks topologically.
514 SetVector<Block *>
getTopologicallySortedBlocks(Region & region)515 mlir::LLVM::detail::getTopologicallySortedBlocks(Region ®ion) {
516 // For each block that has not been visited yet (i.e. that has no
517 // predecessors), add it to the list as well as its successors.
518 SetVector<Block *> blocks;
519 for (Block &b : region) {
520 if (blocks.count(&b) == 0) {
521 llvm::ReversePostOrderTraversal<Block *> traversal(&b);
522 blocks.insert(traversal.begin(), traversal.end());
523 }
524 }
525 assert(blocks.size() == region.getBlocks().size() &&
526 "some blocks are not sorted");
527
528 return blocks;
529 }
530
createIntrinsicCall(llvm::IRBuilderBase & builder,llvm::Intrinsic::ID intrinsic,ArrayRef<llvm::Value * > args,ArrayRef<llvm::Type * > tys)531 llvm::Value *mlir::LLVM::detail::createIntrinsicCall(
532 llvm::IRBuilderBase &builder, llvm::Intrinsic::ID intrinsic,
533 ArrayRef<llvm::Value *> args, ArrayRef<llvm::Type *> tys) {
534 llvm::Module *module = builder.GetInsertBlock()->getModule();
535 llvm::Function *fn = llvm::Intrinsic::getDeclaration(module, intrinsic, tys);
536 return builder.CreateCall(fn, args);
537 }
538
539 /// Given a single MLIR operation, create the corresponding LLVM IR operation
540 /// using the `builder`.
541 LogicalResult
convertOperation(Operation & op,llvm::IRBuilderBase & builder)542 ModuleTranslation::convertOperation(Operation &op,
543 llvm::IRBuilderBase &builder) {
544 const LLVMTranslationDialectInterface *opIface = iface.getInterfaceFor(&op);
545 if (!opIface)
546 return op.emitError("cannot be converted to LLVM IR: missing "
547 "`LLVMTranslationDialectInterface` registration for "
548 "dialect for op: ")
549 << op.getName();
550
551 if (failed(opIface->convertOperation(&op, builder, *this)))
552 return op.emitError("LLVM Translation failed for operation: ")
553 << op.getName();
554
555 return convertDialectAttributes(&op);
556 }
557
558 /// Convert block to LLVM IR. Unless `ignoreArguments` is set, emit PHI nodes
559 /// to define values corresponding to the MLIR block arguments. These nodes
560 /// are not connected to the source basic blocks, which may not exist yet. Uses
561 /// `builder` to construct the LLVM IR. Expects the LLVM IR basic block to have
562 /// been created for `bb` and included in the block mapping. Inserts new
563 /// instructions at the end of the block and leaves `builder` in a state
564 /// suitable for further insertion into the end of the block.
convertBlock(Block & bb,bool ignoreArguments,llvm::IRBuilderBase & builder)565 LogicalResult ModuleTranslation::convertBlock(Block &bb, bool ignoreArguments,
566 llvm::IRBuilderBase &builder) {
567 builder.SetInsertPoint(lookupBlock(&bb));
568 auto *subprogram = builder.GetInsertBlock()->getParent()->getSubprogram();
569
570 // Before traversing operations, make block arguments available through
571 // value remapping and PHI nodes, but do not add incoming edges for the PHI
572 // nodes just yet: those values may be defined by this or following blocks.
573 // This step is omitted if "ignoreArguments" is set. The arguments of the
574 // first block have been already made available through the remapping of
575 // LLVM function arguments.
576 if (!ignoreArguments) {
577 auto predecessors = bb.getPredecessors();
578 unsigned numPredecessors =
579 std::distance(predecessors.begin(), predecessors.end());
580 for (auto arg : bb.getArguments()) {
581 auto wrappedType = arg.getType();
582 if (!isCompatibleType(wrappedType))
583 return emitError(bb.front().getLoc(),
584 "block argument does not have an LLVM type");
585 llvm::Type *type = convertType(wrappedType);
586 llvm::PHINode *phi = builder.CreatePHI(type, numPredecessors);
587 mapValue(arg, phi);
588 }
589 }
590
591 // Traverse operations.
592 for (auto &op : bb) {
593 // Set the current debug location within the builder.
594 builder.SetCurrentDebugLocation(
595 debugTranslation->translateLoc(op.getLoc(), subprogram));
596
597 if (failed(convertOperation(op, builder)))
598 return failure();
599 }
600
601 return success();
602 }
603
604 /// A helper method to get the single Block in an operation honoring LLVM's
605 /// module requirements.
getModuleBody(Operation * module)606 static Block &getModuleBody(Operation *module) {
607 return module->getRegion(0).front();
608 }
609
610 /// A helper method to decide if a constant must not be set as a global variable
611 /// initializer. For an external linkage variable, the variable with an
612 /// initializer is considered externally visible and defined in this module, the
613 /// variable without an initializer is externally available and is defined
614 /// elsewhere.
shouldDropGlobalInitializer(llvm::GlobalValue::LinkageTypes linkage,llvm::Constant * cst)615 static bool shouldDropGlobalInitializer(llvm::GlobalValue::LinkageTypes linkage,
616 llvm::Constant *cst) {
617 return (linkage == llvm::GlobalVariable::ExternalLinkage && !cst) ||
618 linkage == llvm::GlobalVariable::ExternalWeakLinkage;
619 }
620
621 /// Sets the runtime preemption specifier of `gv` to dso_local if
622 /// `dsoLocalRequested` is true, otherwise it is left unchanged.
addRuntimePreemptionSpecifier(bool dsoLocalRequested,llvm::GlobalValue * gv)623 static void addRuntimePreemptionSpecifier(bool dsoLocalRequested,
624 llvm::GlobalValue *gv) {
625 if (dsoLocalRequested)
626 gv->setDSOLocal(true);
627 }
628
629 /// Create named global variables that correspond to llvm.mlir.global
630 /// definitions. Convert llvm.global_ctors and global_dtors ops.
convertGlobals()631 LogicalResult ModuleTranslation::convertGlobals() {
632 for (auto op : getModuleBody(mlirModule).getOps<LLVM::GlobalOp>()) {
633 llvm::Type *type = convertType(op.getType());
634 llvm::Constant *cst = nullptr;
635 if (op.getValueOrNull()) {
636 // String attributes are treated separately because they cannot appear as
637 // in-function constants and are thus not supported by getLLVMConstant.
638 if (auto strAttr = op.getValueOrNull().dyn_cast_or_null<StringAttr>()) {
639 cst = llvm::ConstantDataArray::getString(
640 llvmModule->getContext(), strAttr.getValue(), /*AddNull=*/false);
641 type = cst->getType();
642 } else if (!(cst = getLLVMConstant(type, op.getValueOrNull(), op.getLoc(),
643 *this))) {
644 return failure();
645 }
646 }
647
648 auto linkage = convertLinkageToLLVM(op.getLinkage());
649 auto addrSpace = op.getAddrSpace();
650
651 // LLVM IR requires constant with linkage other than external or weak
652 // external to have initializers. If MLIR does not provide an initializer,
653 // default to undef.
654 bool dropInitializer = shouldDropGlobalInitializer(linkage, cst);
655 if (!dropInitializer && !cst)
656 cst = llvm::UndefValue::get(type);
657 else if (dropInitializer && cst)
658 cst = nullptr;
659
660 auto *var = new llvm::GlobalVariable(
661 *llvmModule, type, op.getConstant(), linkage, cst, op.getSymName(),
662 /*InsertBefore=*/nullptr,
663 op.getThreadLocal_() ? llvm::GlobalValue::GeneralDynamicTLSModel
664 : llvm::GlobalValue::NotThreadLocal,
665 addrSpace);
666
667 if (op.getUnnamedAddr().has_value())
668 var->setUnnamedAddr(convertUnnamedAddrToLLVM(*op.getUnnamedAddr()));
669
670 if (op.getSection().has_value())
671 var->setSection(*op.getSection());
672
673 addRuntimePreemptionSpecifier(op.getDsoLocal(), var);
674
675 Optional<uint64_t> alignment = op.getAlignment();
676 if (alignment.has_value())
677 var->setAlignment(llvm::MaybeAlign(alignment.value()));
678
679 globalsMapping.try_emplace(op, var);
680 }
681
682 // Convert global variable bodies. This is done after all global variables
683 // have been created in LLVM IR because a global body may refer to another
684 // global or itself. So all global variables need to be mapped first.
685 for (auto op : getModuleBody(mlirModule).getOps<LLVM::GlobalOp>()) {
686 if (Block *initializer = op.getInitializerBlock()) {
687 llvm::IRBuilder<> builder(llvmModule->getContext());
688 for (auto &op : initializer->without_terminator()) {
689 if (failed(convertOperation(op, builder)) ||
690 !isa<llvm::Constant>(lookupValue(op.getResult(0))))
691 return emitError(op.getLoc(), "unemittable constant value");
692 }
693 ReturnOp ret = cast<ReturnOp>(initializer->getTerminator());
694 llvm::Constant *cst =
695 cast<llvm::Constant>(lookupValue(ret.getOperand(0)));
696 auto *global = cast<llvm::GlobalVariable>(lookupGlobal(op));
697 if (!shouldDropGlobalInitializer(global->getLinkage(), cst))
698 global->setInitializer(cst);
699 }
700 }
701
702 // Convert llvm.mlir.global_ctors and dtors.
703 for (Operation &op : getModuleBody(mlirModule)) {
704 auto ctorOp = dyn_cast<GlobalCtorsOp>(op);
705 auto dtorOp = dyn_cast<GlobalDtorsOp>(op);
706 if (!ctorOp && !dtorOp)
707 continue;
708 auto range = ctorOp ? llvm::zip(ctorOp.getCtors(), ctorOp.getPriorities())
709 : llvm::zip(dtorOp.getDtors(), dtorOp.getPriorities());
710 auto appendGlobalFn =
711 ctorOp ? llvm::appendToGlobalCtors : llvm::appendToGlobalDtors;
712 for (auto symbolAndPriority : range) {
713 llvm::Function *f = lookupFunction(
714 std::get<0>(symbolAndPriority).cast<FlatSymbolRefAttr>().getValue());
715 appendGlobalFn(
716 *llvmModule, f,
717 std::get<1>(symbolAndPriority).cast<IntegerAttr>().getInt(),
718 /*Data=*/nullptr);
719 }
720 }
721
722 return success();
723 }
724
725 /// Attempts to add an attribute identified by `key`, optionally with the given
726 /// `value` to LLVM function `llvmFunc`. Reports errors at `loc` if any. If the
727 /// attribute has a kind known to LLVM IR, create the attribute of this kind,
728 /// otherwise keep it as a string attribute. Performs additional checks for
729 /// attributes known to have or not have a value in order to avoid assertions
730 /// inside LLVM upon construction.
checkedAddLLVMFnAttribute(Location loc,llvm::Function * llvmFunc,StringRef key,StringRef value=StringRef ())731 static LogicalResult checkedAddLLVMFnAttribute(Location loc,
732 llvm::Function *llvmFunc,
733 StringRef key,
734 StringRef value = StringRef()) {
735 auto kind = llvm::Attribute::getAttrKindFromName(key);
736 if (kind == llvm::Attribute::None) {
737 llvmFunc->addFnAttr(key, value);
738 return success();
739 }
740
741 if (llvm::Attribute::isIntAttrKind(kind)) {
742 if (value.empty())
743 return emitError(loc) << "LLVM attribute '" << key << "' expects a value";
744
745 int result;
746 if (!value.getAsInteger(/*Radix=*/0, result))
747 llvmFunc->addFnAttr(
748 llvm::Attribute::get(llvmFunc->getContext(), kind, result));
749 else
750 llvmFunc->addFnAttr(key, value);
751 return success();
752 }
753
754 if (!value.empty())
755 return emitError(loc) << "LLVM attribute '" << key
756 << "' does not expect a value, found '" << value
757 << "'";
758
759 llvmFunc->addFnAttr(kind);
760 return success();
761 }
762
763 /// Attaches the attributes listed in the given array attribute to `llvmFunc`.
764 /// Reports error to `loc` if any and returns immediately. Expects `attributes`
765 /// to be an array attribute containing either string attributes, treated as
766 /// value-less LLVM attributes, or array attributes containing two string
767 /// attributes, with the first string being the name of the corresponding LLVM
768 /// attribute and the second string beings its value. Note that even integer
769 /// attributes are expected to have their values expressed as strings.
770 static LogicalResult
forwardPassthroughAttributes(Location loc,Optional<ArrayAttr> attributes,llvm::Function * llvmFunc)771 forwardPassthroughAttributes(Location loc, Optional<ArrayAttr> attributes,
772 llvm::Function *llvmFunc) {
773 if (!attributes)
774 return success();
775
776 for (Attribute attr : *attributes) {
777 if (auto stringAttr = attr.dyn_cast<StringAttr>()) {
778 if (failed(
779 checkedAddLLVMFnAttribute(loc, llvmFunc, stringAttr.getValue())))
780 return failure();
781 continue;
782 }
783
784 auto arrayAttr = attr.dyn_cast<ArrayAttr>();
785 if (!arrayAttr || arrayAttr.size() != 2)
786 return emitError(loc)
787 << "expected 'passthrough' to contain string or array attributes";
788
789 auto keyAttr = arrayAttr[0].dyn_cast<StringAttr>();
790 auto valueAttr = arrayAttr[1].dyn_cast<StringAttr>();
791 if (!keyAttr || !valueAttr)
792 return emitError(loc)
793 << "expected arrays within 'passthrough' to contain two strings";
794
795 if (failed(checkedAddLLVMFnAttribute(loc, llvmFunc, keyAttr.getValue(),
796 valueAttr.getValue())))
797 return failure();
798 }
799 return success();
800 }
801
convertOneFunction(LLVMFuncOp func)802 LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
803 // Clear the block, branch value mappings, they are only relevant within one
804 // function.
805 blockMapping.clear();
806 valueMapping.clear();
807 branchMapping.clear();
808 llvm::Function *llvmFunc = lookupFunction(func.getName());
809
810 // Translate the debug information for this function.
811 debugTranslation->translate(func, *llvmFunc);
812
813 // Add function arguments to the value remapping table.
814 // If there was noalias info then we decorate each argument accordingly.
815 unsigned int argIdx = 0;
816 for (auto kvp : llvm::zip(func.getArguments(), llvmFunc->args())) {
817 llvm::Argument &llvmArg = std::get<1>(kvp);
818 BlockArgument mlirArg = std::get<0>(kvp);
819
820 if (auto attr = func.getArgAttrOfType<UnitAttr>(
821 argIdx, LLVMDialect::getNoAliasAttrName())) {
822 // NB: Attribute already verified to be boolean, so check if we can indeed
823 // attach the attribute to this argument, based on its type.
824 auto argTy = mlirArg.getType();
825 if (!argTy.isa<LLVM::LLVMPointerType>())
826 return func.emitError(
827 "llvm.noalias attribute attached to LLVM non-pointer argument");
828 llvmArg.addAttr(llvm::Attribute::AttrKind::NoAlias);
829 }
830
831 if (auto attr = func.getArgAttrOfType<IntegerAttr>(
832 argIdx, LLVMDialect::getAlignAttrName())) {
833 // NB: Attribute already verified to be int, so check if we can indeed
834 // attach the attribute to this argument, based on its type.
835 auto argTy = mlirArg.getType();
836 if (!argTy.isa<LLVM::LLVMPointerType>())
837 return func.emitError(
838 "llvm.align attribute attached to LLVM non-pointer argument");
839 llvmArg.addAttrs(llvm::AttrBuilder(llvmArg.getContext())
840 .addAlignmentAttr(llvm::Align(attr.getInt())));
841 }
842
843 if (auto attr = func.getArgAttrOfType<UnitAttr>(argIdx, "llvm.sret")) {
844 auto argTy = mlirArg.getType().dyn_cast<LLVM::LLVMPointerType>();
845 if (!argTy)
846 return func.emitError(
847 "llvm.sret attribute attached to LLVM non-pointer argument");
848 llvmArg.addAttrs(
849 llvm::AttrBuilder(llvmArg.getContext())
850 .addStructRetAttr(convertType(argTy.getElementType())));
851 }
852
853 if (auto attr = func.getArgAttrOfType<UnitAttr>(argIdx, "llvm.byval")) {
854 auto argTy = mlirArg.getType().dyn_cast<LLVM::LLVMPointerType>();
855 if (!argTy)
856 return func.emitError(
857 "llvm.byval attribute attached to LLVM non-pointer argument");
858 llvmArg.addAttrs(llvm::AttrBuilder(llvmArg.getContext())
859 .addByValAttr(convertType(argTy.getElementType())));
860 }
861
862 if (auto attr = func.getArgAttrOfType<UnitAttr>(argIdx, "llvm.nest")) {
863 auto argTy = mlirArg.getType();
864 if (!argTy.isa<LLVM::LLVMPointerType>())
865 return func.emitError(
866 "llvm.nest attribute attached to LLVM non-pointer argument");
867 llvmArg.addAttrs(llvm::AttrBuilder(llvmArg.getContext())
868 .addAttribute(llvm::Attribute::Nest));
869 }
870
871 mapValue(mlirArg, &llvmArg);
872 argIdx++;
873 }
874
875 // Check the personality and set it.
876 if (func.getPersonality()) {
877 llvm::Type *ty = llvm::Type::getInt8PtrTy(llvmFunc->getContext());
878 if (llvm::Constant *pfunc = getLLVMConstant(ty, func.getPersonalityAttr(),
879 func.getLoc(), *this))
880 llvmFunc->setPersonalityFn(pfunc);
881 }
882
883 if (auto gc = func.getGarbageCollector())
884 llvmFunc->setGC(gc->str());
885
886 // First, create all blocks so we can jump to them.
887 llvm::LLVMContext &llvmContext = llvmFunc->getContext();
888 for (auto &bb : func) {
889 auto *llvmBB = llvm::BasicBlock::Create(llvmContext);
890 llvmBB->insertInto(llvmFunc);
891 mapBlock(&bb, llvmBB);
892 }
893
894 // Then, convert blocks one by one in topological order to ensure defs are
895 // converted before uses.
896 auto blocks = detail::getTopologicallySortedBlocks(func.getBody());
897 for (Block *bb : blocks) {
898 llvm::IRBuilder<> builder(llvmContext);
899 if (failed(convertBlock(*bb, bb->isEntryBlock(), builder)))
900 return failure();
901 }
902
903 // After all blocks have been traversed and values mapped, connect the PHI
904 // nodes to the results of preceding blocks.
905 detail::connectPHINodes(func.getBody(), *this);
906
907 // Finally, convert dialect attributes attached to the function.
908 return convertDialectAttributes(func);
909 }
910
convertDialectAttributes(Operation * op)911 LogicalResult ModuleTranslation::convertDialectAttributes(Operation *op) {
912 for (NamedAttribute attribute : op->getDialectAttrs())
913 if (failed(iface.amendOperation(op, attribute, *this)))
914 return failure();
915 return success();
916 }
917
convertFunctionSignatures()918 LogicalResult ModuleTranslation::convertFunctionSignatures() {
919 // Declare all functions first because there may be function calls that form a
920 // call graph with cycles, or global initializers that reference functions.
921 for (auto function : getModuleBody(mlirModule).getOps<LLVMFuncOp>()) {
922 llvm::FunctionCallee llvmFuncCst = llvmModule->getOrInsertFunction(
923 function.getName(),
924 cast<llvm::FunctionType>(convertType(function.getFunctionType())));
925 llvm::Function *llvmFunc = cast<llvm::Function>(llvmFuncCst.getCallee());
926 llvmFunc->setLinkage(convertLinkageToLLVM(function.getLinkage()));
927 mapFunction(function.getName(), llvmFunc);
928 addRuntimePreemptionSpecifier(function.getDsoLocal(), llvmFunc);
929
930 // Forward the pass-through attributes to LLVM.
931 if (failed(forwardPassthroughAttributes(
932 function.getLoc(), function.getPassthrough(), llvmFunc)))
933 return failure();
934 }
935
936 return success();
937 }
938
convertFunctions()939 LogicalResult ModuleTranslation::convertFunctions() {
940 // Convert functions.
941 for (auto function : getModuleBody(mlirModule).getOps<LLVMFuncOp>()) {
942 // Ignore external functions.
943 if (function.isExternal())
944 continue;
945
946 if (failed(convertOneFunction(function)))
947 return failure();
948 }
949
950 return success();
951 }
952
953 llvm::MDNode *
getAccessGroup(Operation & opInst,SymbolRefAttr accessGroupRef) const954 ModuleTranslation::getAccessGroup(Operation &opInst,
955 SymbolRefAttr accessGroupRef) const {
956 auto metadataName = accessGroupRef.getRootReference();
957 auto accessGroupName = accessGroupRef.getLeafReference();
958 auto metadataOp = SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>(
959 opInst.getParentOp(), metadataName);
960 auto *accessGroupOp =
961 SymbolTable::lookupNearestSymbolFrom(metadataOp, accessGroupName);
962 return accessGroupMetadataMapping.lookup(accessGroupOp);
963 }
964
createAccessGroupMetadata()965 LogicalResult ModuleTranslation::createAccessGroupMetadata() {
966 mlirModule->walk([&](LLVM::MetadataOp metadatas) {
967 metadatas.walk([&](LLVM::AccessGroupMetadataOp op) {
968 llvm::LLVMContext &ctx = llvmModule->getContext();
969 llvm::MDNode *accessGroup = llvm::MDNode::getDistinct(ctx, {});
970 accessGroupMetadataMapping.insert({op, accessGroup});
971 });
972 });
973 return success();
974 }
975
setAccessGroupsMetadata(Operation * op,llvm::Instruction * inst)976 void ModuleTranslation::setAccessGroupsMetadata(Operation *op,
977 llvm::Instruction *inst) {
978 auto accessGroups =
979 op->getAttrOfType<ArrayAttr>(LLVMDialect::getAccessGroupsAttrName());
980 if (accessGroups && !accessGroups.empty()) {
981 llvm::Module *module = inst->getModule();
982 SmallVector<llvm::Metadata *> metadatas;
983 for (SymbolRefAttr accessGroupRef :
984 accessGroups.getAsRange<SymbolRefAttr>())
985 metadatas.push_back(getAccessGroup(*op, accessGroupRef));
986
987 llvm::MDNode *unionMD = nullptr;
988 if (metadatas.size() == 1)
989 unionMD = llvm::cast<llvm::MDNode>(metadatas.front());
990 else if (metadatas.size() >= 2)
991 unionMD = llvm::MDNode::get(module->getContext(), metadatas);
992
993 inst->setMetadata(module->getMDKindID("llvm.access.group"), unionMD);
994 }
995 }
996
createAliasScopeMetadata()997 LogicalResult ModuleTranslation::createAliasScopeMetadata() {
998 mlirModule->walk([&](LLVM::MetadataOp metadatas) {
999 // Create the domains first, so they can be reference below in the scopes.
1000 DenseMap<Operation *, llvm::MDNode *> aliasScopeDomainMetadataMapping;
1001 metadatas.walk([&](LLVM::AliasScopeDomainMetadataOp op) {
1002 llvm::LLVMContext &ctx = llvmModule->getContext();
1003 llvm::SmallVector<llvm::Metadata *, 2> operands;
1004 operands.push_back({}); // Placeholder for self-reference
1005 if (Optional<StringRef> description = op.getDescription())
1006 operands.push_back(llvm::MDString::get(ctx, *description));
1007 llvm::MDNode *domain = llvm::MDNode::get(ctx, operands);
1008 domain->replaceOperandWith(0, domain); // Self-reference for uniqueness
1009 aliasScopeDomainMetadataMapping.insert({op, domain});
1010 });
1011
1012 // Now create the scopes, referencing the domains created above.
1013 metadatas.walk([&](LLVM::AliasScopeMetadataOp op) {
1014 llvm::LLVMContext &ctx = llvmModule->getContext();
1015 assert(isa<LLVM::MetadataOp>(op->getParentOp()));
1016 auto metadataOp = dyn_cast<LLVM::MetadataOp>(op->getParentOp());
1017 Operation *domainOp =
1018 SymbolTable::lookupNearestSymbolFrom(metadataOp, op.getDomainAttr());
1019 llvm::MDNode *domain = aliasScopeDomainMetadataMapping.lookup(domainOp);
1020 assert(domain && "Scope's domain should already be valid");
1021 llvm::SmallVector<llvm::Metadata *, 3> operands;
1022 operands.push_back({}); // Placeholder for self-reference
1023 operands.push_back(domain);
1024 if (Optional<StringRef> description = op.getDescription())
1025 operands.push_back(llvm::MDString::get(ctx, *description));
1026 llvm::MDNode *scope = llvm::MDNode::get(ctx, operands);
1027 scope->replaceOperandWith(0, scope); // Self-reference for uniqueness
1028 aliasScopeMetadataMapping.insert({op, scope});
1029 });
1030 });
1031 return success();
1032 }
1033
1034 llvm::MDNode *
getAliasScope(Operation & opInst,SymbolRefAttr aliasScopeRef) const1035 ModuleTranslation::getAliasScope(Operation &opInst,
1036 SymbolRefAttr aliasScopeRef) const {
1037 StringAttr metadataName = aliasScopeRef.getRootReference();
1038 StringAttr scopeName = aliasScopeRef.getLeafReference();
1039 auto metadataOp = SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>(
1040 opInst.getParentOp(), metadataName);
1041 Operation *aliasScopeOp =
1042 SymbolTable::lookupNearestSymbolFrom(metadataOp, scopeName);
1043 return aliasScopeMetadataMapping.lookup(aliasScopeOp);
1044 }
1045
setAliasScopeMetadata(Operation * op,llvm::Instruction * inst)1046 void ModuleTranslation::setAliasScopeMetadata(Operation *op,
1047 llvm::Instruction *inst) {
1048 auto populateScopeMetadata = [this, op, inst](StringRef attrName,
1049 StringRef llvmMetadataName) {
1050 auto scopes = op->getAttrOfType<ArrayAttr>(attrName);
1051 if (!scopes || scopes.empty())
1052 return;
1053 llvm::Module *module = inst->getModule();
1054 SmallVector<llvm::Metadata *> scopeMDs;
1055 for (SymbolRefAttr scopeRef : scopes.getAsRange<SymbolRefAttr>())
1056 scopeMDs.push_back(getAliasScope(*op, scopeRef));
1057 llvm::MDNode *unionMD = llvm::MDNode::get(module->getContext(), scopeMDs);
1058 inst->setMetadata(module->getMDKindID(llvmMetadataName), unionMD);
1059 };
1060
1061 populateScopeMetadata(LLVMDialect::getAliasScopesAttrName(), "alias.scope");
1062 populateScopeMetadata(LLVMDialect::getNoAliasScopesAttrName(), "noalias");
1063 }
1064
convertType(Type type)1065 llvm::Type *ModuleTranslation::convertType(Type type) {
1066 return typeTranslator.translateType(type);
1067 }
1068
1069 /// A helper to look up remapped operands in the value remapping table.
lookupValues(ValueRange values)1070 SmallVector<llvm::Value *> ModuleTranslation::lookupValues(ValueRange values) {
1071 SmallVector<llvm::Value *> remapped;
1072 remapped.reserve(values.size());
1073 for (Value v : values)
1074 remapped.push_back(lookupValue(v));
1075 return remapped;
1076 }
1077
1078 const llvm::DILocation *
translateLoc(Location loc,llvm::DILocalScope * scope)1079 ModuleTranslation::translateLoc(Location loc, llvm::DILocalScope *scope) {
1080 return debugTranslation->translateLoc(loc, scope);
1081 }
1082
1083 llvm::NamedMDNode *
getOrInsertNamedModuleMetadata(StringRef name)1084 ModuleTranslation::getOrInsertNamedModuleMetadata(StringRef name) {
1085 return llvmModule->getOrInsertNamedMetadata(name);
1086 }
1087
anchor()1088 void ModuleTranslation::StackFrame::anchor() {}
1089
1090 static std::unique_ptr<llvm::Module>
prepareLLVMModule(Operation * m,llvm::LLVMContext & llvmContext,StringRef name)1091 prepareLLVMModule(Operation *m, llvm::LLVMContext &llvmContext,
1092 StringRef name) {
1093 m->getContext()->getOrLoadDialect<LLVM::LLVMDialect>();
1094 auto llvmModule = std::make_unique<llvm::Module>(name, llvmContext);
1095 if (auto dataLayoutAttr =
1096 m->getAttr(LLVM::LLVMDialect::getDataLayoutAttrName())) {
1097 llvmModule->setDataLayout(dataLayoutAttr.cast<StringAttr>().getValue());
1098 } else {
1099 FailureOr<llvm::DataLayout> llvmDataLayout(llvm::DataLayout(""));
1100 if (auto iface = dyn_cast<DataLayoutOpInterface>(m)) {
1101 if (DataLayoutSpecInterface spec = iface.getDataLayoutSpec()) {
1102 llvmDataLayout =
1103 translateDataLayout(spec, DataLayout(iface), m->getLoc());
1104 }
1105 } else if (auto mod = dyn_cast<ModuleOp>(m)) {
1106 if (DataLayoutSpecInterface spec = mod.getDataLayoutSpec()) {
1107 llvmDataLayout =
1108 translateDataLayout(spec, DataLayout(mod), m->getLoc());
1109 }
1110 }
1111 if (failed(llvmDataLayout))
1112 return nullptr;
1113 llvmModule->setDataLayout(*llvmDataLayout);
1114 }
1115 if (auto targetTripleAttr =
1116 m->getAttr(LLVM::LLVMDialect::getTargetTripleAttrName()))
1117 llvmModule->setTargetTriple(targetTripleAttr.cast<StringAttr>().getValue());
1118
1119 // Inject declarations for `malloc` and `free` functions that can be used in
1120 // memref allocation/deallocation coming from standard ops lowering.
1121 llvm::IRBuilder<> builder(llvmContext);
1122 llvmModule->getOrInsertFunction("malloc", builder.getInt8PtrTy(),
1123 builder.getInt64Ty());
1124 llvmModule->getOrInsertFunction("free", builder.getVoidTy(),
1125 builder.getInt8PtrTy());
1126
1127 return llvmModule;
1128 }
1129
1130 std::unique_ptr<llvm::Module>
translateModuleToLLVMIR(Operation * module,llvm::LLVMContext & llvmContext,StringRef name)1131 mlir::translateModuleToLLVMIR(Operation *module, llvm::LLVMContext &llvmContext,
1132 StringRef name) {
1133 if (!satisfiesLLVMModule(module))
1134 return nullptr;
1135
1136 std::unique_ptr<llvm::Module> llvmModule =
1137 prepareLLVMModule(module, llvmContext, name);
1138 if (!llvmModule)
1139 return nullptr;
1140
1141 LLVM::ensureDistinctSuccessors(module);
1142
1143 ModuleTranslation translator(module, std::move(llvmModule));
1144 if (failed(translator.convertFunctionSignatures()))
1145 return nullptr;
1146 if (failed(translator.convertGlobals()))
1147 return nullptr;
1148 if (failed(translator.createAccessGroupMetadata()))
1149 return nullptr;
1150 if (failed(translator.createAliasScopeMetadata()))
1151 return nullptr;
1152 if (failed(translator.convertFunctions()))
1153 return nullptr;
1154
1155 // Convert other top-level operations if possible.
1156 llvm::IRBuilder<> llvmBuilder(llvmContext);
1157 for (Operation &o : getModuleBody(module).getOperations()) {
1158 if (!isa<LLVM::LLVMFuncOp, LLVM::GlobalOp, LLVM::GlobalCtorsOp,
1159 LLVM::GlobalDtorsOp, LLVM::MetadataOp>(&o) &&
1160 !o.hasTrait<OpTrait::IsTerminator>() &&
1161 failed(translator.convertOperation(o, llvmBuilder))) {
1162 return nullptr;
1163 }
1164 }
1165
1166 if (llvm::verifyModule(*translator.llvmModule, &llvm::errs()))
1167 return nullptr;
1168
1169 return std::move(translator.llvmModule);
1170 }
1171