1 //===- DoLoopHelper.cpp -- DoLoopHelper unit tests ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "flang/Optimizer/Builder/DoLoopHelper.h"
10 #include "gtest/gtest.h"
11 #include "flang/Optimizer/Support/InitFIR.h"
12 #include "flang/Optimizer/Support/KindMapping.h"
13 #include <string>
14 
15 struct DoLoopHelperTest : public testing::Test {
16 public:
17   void SetUp() {
18     kindMap = std::make_unique<fir::KindMapping>(&context);
19     mlir::OpBuilder builder(&context);
20     firBuilder = new fir::FirOpBuilder(builder, *kindMap);
21     fir::support::loadDialects(context);
22   }
23   void TearDown() { delete firBuilder; }
24 
25   fir::FirOpBuilder &getBuilder() { return *firBuilder; }
26 
27   mlir::MLIRContext context;
28   std::unique_ptr<fir::KindMapping> kindMap;
29   fir::FirOpBuilder *firBuilder;
30 };
31 
32 void checkConstantValue(const mlir::Value &value, int64_t v) {
33   EXPECT_TRUE(mlir::isa<mlir::arith::ConstantOp>(value.getDefiningOp()));
34   auto cstOp = dyn_cast<mlir::arith::ConstantOp>(value.getDefiningOp());
35   auto valueAttr = cstOp.getValue().dyn_cast_or_null<IntegerAttr>();
36   EXPECT_EQ(v, valueAttr.getInt());
37 }
38 
39 TEST_F(DoLoopHelperTest, createLoopWithCountTest) {
40   auto firBuilder = getBuilder();
41   fir::factory::DoLoopHelper helper(firBuilder, firBuilder.getUnknownLoc());
42 
43   auto c10 = firBuilder.createIntegerConstant(
44       firBuilder.getUnknownLoc(), firBuilder.getIndexType(), 10);
45   auto loop =
46       helper.createLoop(c10, [&](fir::FirOpBuilder &, mlir::Value index) {});
47   checkConstantValue(loop.getLowerBound(), 0);
48   EXPECT_TRUE(mlir::isa<arith::SubIOp>(loop.getUpperBound().getDefiningOp()));
49   auto subOp = dyn_cast<arith::SubIOp>(loop.getUpperBound().getDefiningOp());
50   EXPECT_EQ(c10, subOp.getLhs());
51   checkConstantValue(subOp.getRhs(), 1);
52   checkConstantValue(loop.getStep(), 1);
53 }
54 
55 TEST_F(DoLoopHelperTest, createLoopWithLowerAndUpperBound) {
56   auto firBuilder = getBuilder();
57   fir::factory::DoLoopHelper helper(firBuilder, firBuilder.getUnknownLoc());
58 
59   auto lb = firBuilder.createIntegerConstant(
60       firBuilder.getUnknownLoc(), firBuilder.getIndexType(), 1);
61   auto ub = firBuilder.createIntegerConstant(
62       firBuilder.getUnknownLoc(), firBuilder.getIndexType(), 20);
63   auto loop =
64       helper.createLoop(lb, ub, [&](fir::FirOpBuilder &, mlir::Value index) {});
65   checkConstantValue(loop.getLowerBound(), 1);
66   checkConstantValue(loop.getUpperBound(), 20);
67   checkConstantValue(loop.getStep(), 1);
68 }
69 
70 TEST_F(DoLoopHelperTest, createLoopWithStep) {
71   auto firBuilder = getBuilder();
72   fir::factory::DoLoopHelper helper(firBuilder, firBuilder.getUnknownLoc());
73 
74   auto lb = firBuilder.createIntegerConstant(
75       firBuilder.getUnknownLoc(), firBuilder.getIndexType(), 1);
76   auto ub = firBuilder.createIntegerConstant(
77       firBuilder.getUnknownLoc(), firBuilder.getIndexType(), 20);
78   auto step = firBuilder.createIntegerConstant(
79       firBuilder.getUnknownLoc(), firBuilder.getIndexType(), 2);
80   auto loop = helper.createLoop(
81       lb, ub, step, [&](fir::FirOpBuilder &, mlir::Value index) {});
82   checkConstantValue(loop.getLowerBound(), 1);
83   checkConstantValue(loop.getUpperBound(), 20);
84   checkConstantValue(loop.getStep(), 2);
85 }
86