1 //===- DoLoopHelper.cpp -- DoLoopHelper unit tests ------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "flang/Optimizer/Builder/DoLoopHelper.h" 10 #include "gtest/gtest.h" 11 #include "flang/Optimizer/Support/InitFIR.h" 12 #include "flang/Optimizer/Support/KindMapping.h" 13 #include <string> 14 15 using namespace mlir; 16 17 struct DoLoopHelperTest : public testing::Test { 18 public: 19 void SetUp() { 20 kindMap = std::make_unique<fir::KindMapping>(&context); 21 mlir::OpBuilder builder(&context); 22 firBuilder = new fir::FirOpBuilder(builder, *kindMap); 23 fir::support::loadDialects(context); 24 } 25 void TearDown() { delete firBuilder; } 26 27 fir::FirOpBuilder &getBuilder() { return *firBuilder; } 28 29 mlir::MLIRContext context; 30 std::unique_ptr<fir::KindMapping> kindMap; 31 fir::FirOpBuilder *firBuilder; 32 }; 33 34 void checkConstantValue(const mlir::Value &value, int64_t v) { 35 EXPECT_TRUE(mlir::isa<mlir::arith::ConstantOp>(value.getDefiningOp())); 36 auto cstOp = dyn_cast<mlir::arith::ConstantOp>(value.getDefiningOp()); 37 auto valueAttr = cstOp.getValue().dyn_cast_or_null<IntegerAttr>(); 38 EXPECT_EQ(v, valueAttr.getInt()); 39 } 40 41 TEST_F(DoLoopHelperTest, createLoopWithCountTest) { 42 auto firBuilder = getBuilder(); 43 fir::factory::DoLoopHelper helper(firBuilder, firBuilder.getUnknownLoc()); 44 45 auto c10 = firBuilder.createIntegerConstant( 46 firBuilder.getUnknownLoc(), firBuilder.getIndexType(), 10); 47 auto loop = 48 helper.createLoop(c10, [&](fir::FirOpBuilder &, mlir::Value index) {}); 49 checkConstantValue(loop.getLowerBound(), 0); 50 EXPECT_TRUE(mlir::isa<arith::SubIOp>(loop.getUpperBound().getDefiningOp())); 51 auto subOp = dyn_cast<arith::SubIOp>(loop.getUpperBound().getDefiningOp()); 52 EXPECT_EQ(c10, subOp.getLhs()); 53 checkConstantValue(subOp.getRhs(), 1); 54 checkConstantValue(loop.getStep(), 1); 55 } 56 57 TEST_F(DoLoopHelperTest, createLoopWithLowerAndUpperBound) { 58 auto firBuilder = getBuilder(); 59 fir::factory::DoLoopHelper helper(firBuilder, firBuilder.getUnknownLoc()); 60 61 auto lb = firBuilder.createIntegerConstant( 62 firBuilder.getUnknownLoc(), firBuilder.getIndexType(), 1); 63 auto ub = firBuilder.createIntegerConstant( 64 firBuilder.getUnknownLoc(), firBuilder.getIndexType(), 20); 65 auto loop = 66 helper.createLoop(lb, ub, [&](fir::FirOpBuilder &, mlir::Value index) {}); 67 checkConstantValue(loop.getLowerBound(), 1); 68 checkConstantValue(loop.getUpperBound(), 20); 69 checkConstantValue(loop.getStep(), 1); 70 } 71 72 TEST_F(DoLoopHelperTest, createLoopWithStep) { 73 auto firBuilder = getBuilder(); 74 fir::factory::DoLoopHelper helper(firBuilder, firBuilder.getUnknownLoc()); 75 76 auto lb = firBuilder.createIntegerConstant( 77 firBuilder.getUnknownLoc(), firBuilder.getIndexType(), 1); 78 auto ub = firBuilder.createIntegerConstant( 79 firBuilder.getUnknownLoc(), firBuilder.getIndexType(), 20); 80 auto step = firBuilder.createIntegerConstant( 81 firBuilder.getUnknownLoc(), firBuilder.getIndexType(), 2); 82 auto loop = helper.createLoop( 83 lb, ub, step, [&](fir::FirOpBuilder &, mlir::Value index) {}); 84 checkConstantValue(loop.getLowerBound(), 1); 85 checkConstantValue(loop.getUpperBound(), 20); 86 checkConstantValue(loop.getStep(), 2); 87 } 88