1 //===- DoLoopHelper.cpp -- DoLoopHelper unit tests ------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "flang/Optimizer/Builder/DoLoopHelper.h" 10 #include "gtest/gtest.h" 11 #include "flang/Optimizer/Support/InitFIR.h" 12 #include "flang/Optimizer/Support/KindMapping.h" 13 #include <string> 14 15 struct DoLoopHelperTest : public testing::Test { 16 public: 17 void SetUp() { 18 kindMap = std::make_unique<fir::KindMapping>(&context); 19 mlir::OpBuilder builder(&context); 20 firBuilder = new fir::FirOpBuilder(builder, *kindMap); 21 fir::support::loadDialects(context); 22 } 23 void TearDown() { delete firBuilder; } 24 25 fir::FirOpBuilder &getBuilder() { return *firBuilder; } 26 27 mlir::MLIRContext context; 28 std::unique_ptr<fir::KindMapping> kindMap; 29 fir::FirOpBuilder *firBuilder; 30 }; 31 32 void checkConstantValue(const mlir::Value &value, int64_t v) { 33 EXPECT_TRUE(mlir::isa<mlir::arith::ConstantOp>(value.getDefiningOp())); 34 auto cstOp = dyn_cast<mlir::arith::ConstantOp>(value.getDefiningOp()); 35 auto valueAttr = cstOp.getValue().dyn_cast_or_null<IntegerAttr>(); 36 EXPECT_EQ(v, valueAttr.getInt()); 37 } 38 39 TEST_F(DoLoopHelperTest, createLoopWithCountTest) { 40 auto firBuilder = getBuilder(); 41 fir::factory::DoLoopHelper helper(firBuilder, firBuilder.getUnknownLoc()); 42 43 auto c10 = firBuilder.createIntegerConstant( 44 firBuilder.getUnknownLoc(), firBuilder.getIndexType(), 10); 45 auto loop = 46 helper.createLoop(c10, [&](fir::FirOpBuilder &, mlir::Value index) {}); 47 checkConstantValue(loop.getLowerBound(), 0); 48 EXPECT_TRUE(mlir::isa<arith::SubIOp>(loop.getUpperBound().getDefiningOp())); 49 auto subOp = dyn_cast<arith::SubIOp>(loop.getUpperBound().getDefiningOp()); 50 EXPECT_EQ(c10, subOp.getLhs()); 51 checkConstantValue(subOp.getRhs(), 1); 52 checkConstantValue(loop.getStep(), 1); 53 } 54 55 TEST_F(DoLoopHelperTest, createLoopWithLowerAndUpperBound) { 56 auto firBuilder = getBuilder(); 57 fir::factory::DoLoopHelper helper(firBuilder, firBuilder.getUnknownLoc()); 58 59 auto lb = firBuilder.createIntegerConstant( 60 firBuilder.getUnknownLoc(), firBuilder.getIndexType(), 1); 61 auto ub = firBuilder.createIntegerConstant( 62 firBuilder.getUnknownLoc(), firBuilder.getIndexType(), 20); 63 auto loop = 64 helper.createLoop(lb, ub, [&](fir::FirOpBuilder &, mlir::Value index) {}); 65 checkConstantValue(loop.getLowerBound(), 1); 66 checkConstantValue(loop.getUpperBound(), 20); 67 checkConstantValue(loop.getStep(), 1); 68 } 69 70 TEST_F(DoLoopHelperTest, createLoopWithStep) { 71 auto firBuilder = getBuilder(); 72 fir::factory::DoLoopHelper helper(firBuilder, firBuilder.getUnknownLoc()); 73 74 auto lb = firBuilder.createIntegerConstant( 75 firBuilder.getUnknownLoc(), firBuilder.getIndexType(), 1); 76 auto ub = firBuilder.createIntegerConstant( 77 firBuilder.getUnknownLoc(), firBuilder.getIndexType(), 20); 78 auto step = firBuilder.createIntegerConstant( 79 firBuilder.getUnknownLoc(), firBuilder.getIndexType(), 2); 80 auto loop = helper.createLoop( 81 lb, ub, step, [&](fir::FirOpBuilder &, mlir::Value index) {}); 82 checkConstantValue(loop.getLowerBound(), 1); 83 checkConstantValue(loop.getUpperBound(), 20); 84 checkConstantValue(loop.getStep(), 2); 85 } 86