//===- TensorSpecTest.cpp - test for TensorSpec ---------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

9b1fa5ac3SMircea Trofin #include "llvm/Analysis/TensorSpec.h"
10b1fa5ac3SMircea Trofin #include "llvm/Support/Path.h"
11b1fa5ac3SMircea Trofin #include "llvm/Support/SourceMgr.h"
12b1fa5ac3SMircea Trofin #include "llvm/Testing/Support/SupportHelpers.h"
13b1fa5ac3SMircea Trofin #include "gtest/gtest.h"
14b1fa5ac3SMircea Trofin 
15b1fa5ac3SMircea Trofin using namespace llvm;
16b1fa5ac3SMircea Trofin 
17b1fa5ac3SMircea Trofin extern const char *TestMainArgv0;
// Checks that a well-formed JSON tensor description is converted by
// getTensorSpecFromJSON into a TensorSpec equal to the one built directly with
// TensorSpec::createSpec<int32_t>("tensor_name", {1, 4}, 2).
TEST(TensorSpecTest, JSONParsing) {
  auto Value = json::parse(
      R"({"name": "tensor_name",
        "port": 2,
        "type": "int32_t",
        "shape":[1,4]
        })");
  // Fatal on failure: *Value is dereferenced below and must not be evaluated
  // if the JSON text failed to parse.
  ASSERT_TRUE(!!Value);
  LLVMContext Ctx;
  Optional<TensorSpec> Spec = getTensorSpecFromJSON(Ctx, *Value);
  // Fatal on failure: *Spec is dereferenced below and must not be evaluated
  // on an empty Optional.
  ASSERT_TRUE(Spec);
  EXPECT_EQ(*Spec, TensorSpec::createSpec<int32_t>("tensor_name", {1, 4}, 2));
}
// Checks that getTensorSpecFromJSON yields no TensorSpec when the JSON is
// syntactically valid but its "type" field names an unrecognized type.
TEST(TensorSpecTest, JSONParsingInvalidTensorType) {
  auto Value = json::parse(
      R"(
        {"name": "tensor_name",
        "port": 2,
        "type": "no such type",
        "shape":[1,4]
        }
      )");
  // Fatal on failure: *Value is dereferenced below and must not be evaluated
  // if the JSON text itself failed to parse.
  ASSERT_TRUE(!!Value);
  LLVMContext Ctx;
  auto Spec = getTensorSpecFromJSON(Ctx, *Value);
  // The JSON parsed, but the spec conversion is expected to be rejected.
  EXPECT_FALSE(Spec);
}
// Checks the element-type query, element count and per-element byte size
// reported by TensorSpec instances created with different element types and
// shapes.
TEST(TensorSpecTest, TensorSpecSizesAndTypes) {
  auto Spec1D = TensorSpec::createSpec<int16_t>("Hi1", {1});
  auto Spec2D = TensorSpec::createSpec<int16_t>("Hi2", {1, 1});
  auto Spec1DLarge = TensorSpec::createSpec<float>("Hi3", {10});
  auto Spec3DLarge = TensorSpec::createSpec<float>("Hi3", {2, 4, 10});

  // Element type must match the template argument used at creation.
  EXPECT_TRUE(Spec1D.isElementType<int16_t>());
  EXPECT_FALSE(Spec3DLarge.isElementType<double>());

  // The expected counts equal the product of each shape's dimensions
  // (e.g. 2 * 4 * 10 == 80).
  EXPECT_EQ(Spec1D.getElementCount(), 1U);
  EXPECT_EQ(Spec2D.getElementCount(), 1U);
  EXPECT_EQ(Spec1DLarge.getElementCount(), 10U);
  EXPECT_EQ(Spec3DLarge.getElementCount(), 80U);

  // Byte size is per element, so it is compared with sizeof the element type.
  EXPECT_EQ(Spec3DLarge.getElementByteSize(), sizeof(float));
  EXPECT_EQ(Spec1D.getElementByteSize(), sizeof(int16_t));
}