1 //===- ComplexExprTest.cpp -- ComplexExpr unit tests ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "flang/Optimizer/Builder/Complex.h"
10 #include "gtest/gtest.h"
11 #include "flang/Optimizer/Builder/FIRBuilder.h"
12 #include "flang/Optimizer/Support/InitFIR.h"
13 #include "flang/Optimizer/Support/KindMapping.h"
14 
15 struct ComplexTest : public testing::Test {
16 public:
17   void SetUp() override {
18     mlir::OpBuilder builder(&context);
19     auto loc = builder.getUnknownLoc();
20 
21     // Set up a Module with a dummy function operation inside.
22     // Set the insertion point in the function entry block.
23     mlir::ModuleOp mod = builder.create<mlir::ModuleOp>(loc);
24     mlir::FuncOp func = mlir::FuncOp::create(
25         loc, "func1", builder.getFunctionType(llvm::None, llvm::None));
26     auto *entryBlock = func.addEntryBlock();
27     mod.push_back(mod);
28     builder.setInsertionPointToStart(entryBlock);
29 
30     fir::support::loadDialects(context);
31     kindMap = std::make_unique<fir::KindMapping>(&context);
32     firBuilder = std::make_unique<fir::FirOpBuilder>(mod, *kindMap);
33     helper = std::make_unique<fir::factory::Complex>(*firBuilder, loc);
34 
35     // Init commonly used types
36     realTy1 = mlir::FloatType::getF32(&context);
37     complexTy1 = fir::ComplexType::get(&context, 4);
38     integerTy1 = mlir::IntegerType::get(&context, 32);
39 
40     // Create commonly used reals
41     rOne = firBuilder->createRealConstant(loc, realTy1, 1u);
42     rTwo = firBuilder->createRealConstant(loc, realTy1, 2u);
43     rThree = firBuilder->createRealConstant(loc, realTy1, 3u);
44     rFour = firBuilder->createRealConstant(loc, realTy1, 4u);
45   }
46 
47   mlir::MLIRContext context;
48   std::unique_ptr<fir::KindMapping> kindMap;
49   std::unique_ptr<fir::FirOpBuilder> firBuilder;
50   std::unique_ptr<fir::factory::Complex> helper;
51 
52   // Commonly used real/complex/integer types
53   mlir::FloatType realTy1;
54   fir::ComplexType complexTy1;
55   mlir::IntegerType integerTy1;
56 
57   // Commonly used real numbers
58   mlir::Value rOne;
59   mlir::Value rTwo;
60   mlir::Value rThree;
61   mlir::Value rFour;
62 };
63 
64 TEST_F(ComplexTest, verifyTypes) {
65   mlir::Value cVal1 = helper->createComplex(complexTy1, rOne, rTwo);
66   mlir::Value cVal2 = helper->createComplex(4, rOne, rTwo);
67   EXPECT_TRUE(fir::isa_complex(cVal1.getType()));
68   EXPECT_TRUE(fir::isa_complex(cVal2.getType()));
69   EXPECT_TRUE(fir::isa_real(helper->getComplexPartType(cVal1)));
70   EXPECT_TRUE(fir::isa_real(helper->getComplexPartType(cVal2)));
71 
72   mlir::Value real1 = helper->extractComplexPart(cVal1, /*isImagPart=*/false);
73   mlir::Value imag1 = helper->extractComplexPart(cVal1, /*isImagPart=*/true);
74   mlir::Value real2 = helper->extractComplexPart(cVal2, /*isImagPart=*/false);
75   mlir::Value imag2 = helper->extractComplexPart(cVal2, /*isImagPart=*/true);
76   EXPECT_EQ(realTy1, real1.getType());
77   EXPECT_EQ(realTy1, imag1.getType());
78   EXPECT_EQ(realTy1, real2.getType());
79   EXPECT_EQ(realTy1, imag2.getType());
80 
81   mlir::Value cVal3 =
82       helper->insertComplexPart(cVal1, rThree, /*isImagPart=*/false);
83   mlir::Value cVal4 =
84       helper->insertComplexPart(cVal3, rFour, /*isImagPart=*/true);
85   EXPECT_TRUE(fir::isa_complex(cVal4.getType()));
86   EXPECT_TRUE(fir::isa_real(helper->getComplexPartType(cVal4)));
87 }
88 
89 TEST_F(ComplexTest, verifyConvertWithSemantics) {
90   auto loc = firBuilder->getUnknownLoc();
91   rOne = firBuilder->createRealConstant(loc, realTy1, 1u);
92   // Convert real to complex
93   mlir::Value v1 = firBuilder->convertWithSemantics(loc, complexTy1, rOne);
94   EXPECT_TRUE(fir::isa_complex(v1.getType()));
95 
96   // Convert complex to integer
97   mlir::Value v2 = firBuilder->convertWithSemantics(loc, integerTy1, v1);
98   EXPECT_TRUE(v2.getType().isa<mlir::IntegerType>());
99   EXPECT_TRUE(mlir::dyn_cast<fir::ConvertOp>(v2.getDefiningOp()));
100 }
101