1 //===- TestConvertCallOp.cpp - Test LLVM Conversion of Func CallOp --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestTypes.h"
11 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
12 #include "mlir/Conversion/LLVMCommon/Pattern.h"
13 #include "mlir/Dialect/Func/IR/FuncOps.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/Pass/Pass.h"
16 
17 using namespace mlir;
18 
19 namespace {
20 
21 class TestTypeProducerOpConverter
22     : public ConvertOpToLLVMPattern<test::TestTypeProducerOp> {
23 public:
24   using ConvertOpToLLVMPattern<
25       test::TestTypeProducerOp>::ConvertOpToLLVMPattern;
26 
27   LogicalResult
28   matchAndRewrite(test::TestTypeProducerOp op, OpAdaptor adaptor,
29                   ConversionPatternRewriter &rewriter) const override {
30     rewriter.replaceOpWithNewOp<LLVM::NullOp>(op, getVoidPtrType());
31     return success();
32   }
33 };
34 
35 class TestConvertCallOp
36     : public PassWrapper<TestConvertCallOp, OperationPass<ModuleOp>> {
37 public:
38   void getDependentDialects(DialectRegistry &registry) const final {
39     registry.insert<LLVM::LLVMDialect>();
40   }
41   StringRef getArgument() const final { return "test-convert-call-op"; }
42   StringRef getDescription() const final {
43     return "Tests conversion of `func.call` to `llvm.call` in "
44            "presence of custom types";
45   }
46 
47   void runOnOperation() override {
48     ModuleOp m = getOperation();
49 
50     // Populate type conversions.
51     LLVMTypeConverter typeConverter(m.getContext());
52     typeConverter.addConversion([&](test::TestType type) {
53       return LLVM::LLVMPointerType::get(IntegerType::get(m.getContext(), 8));
54     });
55     typeConverter.addConversion([&](test::SimpleAType type) {
56       return IntegerType::get(type.getContext(), 42);
57     });
58 
59     // Populate patterns.
60     RewritePatternSet patterns(m.getContext());
61     populateFuncToLLVMConversionPatterns(typeConverter, patterns);
62     patterns.add<TestTypeProducerOpConverter>(typeConverter);
63 
64     // Set target.
65     ConversionTarget target(getContext());
66     target.addLegalDialect<LLVM::LLVMDialect>();
67     target.addIllegalDialect<test::TestDialect>();
68     target.addIllegalDialect<func::FuncDialect>();
69 
70     if (failed(applyPartialConversion(m, target, std::move(patterns))))
71       signalPassFailure();
72   }
73 };
74 
75 } // namespace
76 
77 namespace mlir {
78 namespace test {
79 void registerConvertCallOpPass() { PassRegistration<TestConvertCallOp>(); }
80 } // namespace test
81 } // namespace mlir
82