1 //===-- flang/unittests/RuntimeGTest/Matmul.cpp---- -------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "flang/Runtime/matmul.h" 10 #include "gtest/gtest.h" 11 #include "tools.h" 12 #include "flang/Runtime/allocatable.h" 13 #include "flang/Runtime/cpp-type.h" 14 #include "flang/Runtime/descriptor.h" 15 #include "flang/Runtime/type-code.h" 16 17 using namespace Fortran::runtime; 18 using Fortran::common::TypeCategory; 19 20 TEST(Matmul, Basic) { 21 // X 0 2 4 Y 6 9 V -1 -2 22 // 1 3 5 7 10 23 // 8 11 24 auto x{MakeArray<TypeCategory::Integer, 4>( 25 std::vector<int>{2, 3}, std::vector<std::int32_t>{0, 1, 2, 3, 4, 5})}; 26 auto y{MakeArray<TypeCategory::Integer, 2>( 27 std::vector<int>{3, 2}, std::vector<std::int16_t>{6, 7, 8, 9, 10, 11})}; 28 auto v{MakeArray<TypeCategory::Integer, 8>( 29 std::vector<int>{2}, std::vector<std::int64_t>{-1, -2})}; 30 StaticDescriptor<2, true> statDesc; 31 Descriptor &result{statDesc.descriptor()}; 32 33 RTNAME(Matmul)(result, *x, *y, __FILE__, __LINE__); 34 ASSERT_EQ(result.rank(), 2); 35 EXPECT_EQ(result.GetDimension(0).LowerBound(), 1); 36 EXPECT_EQ(result.GetDimension(0).Extent(), 2); 37 EXPECT_EQ(result.GetDimension(1).LowerBound(), 1); 38 EXPECT_EQ(result.GetDimension(1).Extent(), 2); 39 ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 4})); 40 EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46); 41 EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67); 42 EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64); 43 EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94); 44 45 std::memset( 46 result.raw().base_addr, 0, result.Elements() * result.ElementBytes()); 47 result.GetDimension(0).SetLowerBound(0); 48 result.GetDimension(1).SetLowerBound(2); 49 RTNAME(MatmulDirect)(result, *x, *y, __FILE__, __LINE__); 50 EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(0), 46); 51 EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(1), 67); 52 EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(2), 64); 53 EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int32_t>(3), 94); 54 result.Destroy(); 55 56 RTNAME(Matmul)(result, *v, *x, __FILE__, __LINE__); 57 ASSERT_EQ(result.rank(), 1); 58 EXPECT_EQ(result.GetDimension(0).LowerBound(), 1); 59 EXPECT_EQ(result.GetDimension(0).Extent(), 3); 60 ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 8})); 61 EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(0), -2); 62 EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(1), -8); 63 EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(2), -14); 64 result.Destroy(); 65 66 RTNAME(Matmul)(result, *y, *v, __FILE__, __LINE__); 67 ASSERT_EQ(result.rank(), 1); 68 EXPECT_EQ(result.GetDimension(0).LowerBound(), 1); 69 EXPECT_EQ(result.GetDimension(0).Extent(), 3); 70 ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Integer, 8})); 71 EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(0), -24); 72 EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(1), -27); 73 EXPECT_EQ(*result.ZeroBasedIndexedElement<std::int64_t>(2), -30); 74 result.Destroy(); 75 76 // X F F T Y F T 77 // F T T F T 78 // F F 79 auto xLog{MakeArray<TypeCategory::Logical, 1>(std::vector<int>{2, 3}, 80 std::vector<std::uint8_t>{false, false, false, true, true, false})}; 81 auto yLog{MakeArray<TypeCategory::Logical, 2>(std::vector<int>{3, 2}, 82 std::vector<std::uint16_t>{false, false, false, true, true, false})}; 83 RTNAME(Matmul)(result, *xLog, *yLog, __FILE__, __LINE__); 84 ASSERT_EQ(result.rank(), 2); 85 EXPECT_EQ(result.GetDimension(0).LowerBound(), 1); 86 EXPECT_EQ(result.GetDimension(0).Extent(), 2); 87 EXPECT_EQ(result.GetDimension(1).LowerBound(), 1); 88 EXPECT_EQ(result.GetDimension(1).Extent(), 2); 89 ASSERT_EQ(result.type(), (TypeCode{TypeCategory::Logical, 2})); 90 EXPECT_FALSE( 91 static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(0))); 92 EXPECT_FALSE( 93 static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(1))); 94 EXPECT_FALSE( 95 static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(2))); 96 EXPECT_TRUE( 97 static_cast<bool>(*result.ZeroBasedIndexedElement<std::uint16_t>(3))); 98 } 99