1 //===- DoLoopHelper.cpp -- DoLoopHelper unit tests ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "flang/Optimizer/Builder/DoLoopHelper.h"
#include "gtest/gtest.h"
#include "flang/Optimizer/Support/InitFIR.h"
#include "flang/Optimizer/Support/KindMapping.h"
#include <memory>
#include <string>
14 
15 struct DoLoopHelperTest : public testing::Test {
16 public:
17   void SetUp() {
18     fir::KindMapping kindMap(&context);
19     mlir::OpBuilder builder(&context);
20     firBuilder = new fir::FirOpBuilder(builder, kindMap);
21     fir::support::loadDialects(context);
22   }
23   void TearDown() { delete firBuilder; }
24 
25   fir::FirOpBuilder &getBuilder() { return *firBuilder; }
26 
27   mlir::MLIRContext context;
28   fir::FirOpBuilder *firBuilder;
29 };
30 
31 void checkConstantValue(const mlir::Value &value, int64_t v) {
32   EXPECT_TRUE(mlir::isa<ConstantOp>(value.getDefiningOp()));
33   auto cstOp = dyn_cast<ConstantOp>(value.getDefiningOp());
34   auto valueAttr = cstOp.getValue().dyn_cast_or_null<IntegerAttr>();
35   EXPECT_EQ(v, valueAttr.getInt());
36 }
37 
38 TEST_F(DoLoopHelperTest, createLoopWithCountTest) {
39   auto firBuilder = getBuilder();
40   fir::factory::DoLoopHelper helper(firBuilder, firBuilder.getUnknownLoc());
41 
42   auto c10 = firBuilder.createIntegerConstant(
43       firBuilder.getUnknownLoc(), firBuilder.getIndexType(), 10);
44   auto loop =
45       helper.createLoop(c10, [&](fir::FirOpBuilder &, mlir::Value index) {});
46   checkConstantValue(loop.lowerBound(), 0);
47   EXPECT_TRUE(mlir::isa<arith::SubIOp>(loop.upperBound().getDefiningOp()));
48   auto subOp = dyn_cast<arith::SubIOp>(loop.upperBound().getDefiningOp());
49   EXPECT_EQ(c10, subOp.lhs());
50   checkConstantValue(subOp.rhs(), 1);
51   checkConstantValue(loop.step(), 1);
52 }
53 
54 TEST_F(DoLoopHelperTest, createLoopWithLowerAndUpperBound) {
55   auto firBuilder = getBuilder();
56   fir::factory::DoLoopHelper helper(firBuilder, firBuilder.getUnknownLoc());
57 
58   auto lb = firBuilder.createIntegerConstant(
59       firBuilder.getUnknownLoc(), firBuilder.getIndexType(), 1);
60   auto ub = firBuilder.createIntegerConstant(
61       firBuilder.getUnknownLoc(), firBuilder.getIndexType(), 20);
62   auto loop =
63       helper.createLoop(lb, ub, [&](fir::FirOpBuilder &, mlir::Value index) {});
64   checkConstantValue(loop.lowerBound(), 1);
65   checkConstantValue(loop.upperBound(), 20);
66   checkConstantValue(loop.step(), 1);
67 }
68 
69 TEST_F(DoLoopHelperTest, createLoopWithStep) {
70   auto firBuilder = getBuilder();
71   fir::factory::DoLoopHelper helper(firBuilder, firBuilder.getUnknownLoc());
72 
73   auto lb = firBuilder.createIntegerConstant(
74       firBuilder.getUnknownLoc(), firBuilder.getIndexType(), 1);
75   auto ub = firBuilder.createIntegerConstant(
76       firBuilder.getUnknownLoc(), firBuilder.getIndexType(), 20);
77   auto step = firBuilder.createIntegerConstant(
78       firBuilder.getUnknownLoc(), firBuilder.getIndexType(), 2);
79   auto loop = helper.createLoop(
80       lb, ub, step, [&](fir::FirOpBuilder &, mlir::Value index) {});
81   checkConstantValue(loop.lowerBound(), 1);
82   checkConstantValue(loop.upperBound(), 20);
83   checkConstantValue(loop.step(), 2);
84 }
85