1 //===- ShapedTypeTest.cpp - ShapedType unit tests -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/IR/AffineMap.h"
10 #include "mlir/IR/BuiltinAttributes.h"
11 #include "mlir/IR/BuiltinTypes.h"
12 #include "mlir/IR/Dialect.h"
13 #include "mlir/IR/DialectInterface.h"
14 #include "llvm/ADT/SmallVector.h"
15 #include "gtest/gtest.h"
16 #include <cstdint>
17
18 using namespace mlir;
19 using namespace mlir::detail;
20
21 namespace {
/// Verifies ShapedType::clone on memref types.
///
/// A ranked memref cloned with a new shape, a new element type, or both must
/// keep its memory space and affine-map layout. An unranked memref cloned with
/// a shape must become a ranked MemRefType in the same memory space (the
/// expected type is built without setting a layout). Cloned with only a new
/// element type, it must stay an UnrankedMemRefType.
TEST(ShapedTypeTest, CloneMemref) {
  MLIRContext context;

  Type i32 = IntegerType::get(&context, 32);
  Type f32 = FloatType::getF32(&context);
  Attribute memSpace = IntegerAttr::get(IntegerType::get(&context, 64), 7);
  Type memrefOriginalType = i32;
  llvm::SmallVector<int64_t> memrefOriginalShape({10, 20});
  AffineMap map = makeStridedLinearLayoutMap({2, 3}, 5, &context);

  // Test ranked memref cloning.
  ShapedType memrefType =
      (ShapedType)MemRefType::Builder(memrefOriginalShape, memrefOriginalType)
          .setMemorySpace(memSpace)
          .setLayout(AffineMapAttr::get(map));
  // Update shape.
  llvm::SmallVector<int64_t> memrefNewShape({30, 40});
  ASSERT_NE(memrefOriginalShape, memrefNewShape);
  ASSERT_EQ(memrefType.clone(memrefNewShape),
            (MemRefType)MemRefType::Builder(memrefNewShape, memrefOriginalType)
                .setMemorySpace(memSpace)
                .setLayout(AffineMapAttr::get(map)));
  // Update type.
  Type memrefNewType = f32;
  ASSERT_NE(memrefOriginalType, memrefNewType);
  ASSERT_EQ(memrefType.clone(memrefNewType),
            (MemRefType)MemRefType::Builder(memrefOriginalShape, memrefNewType)
                .setMemorySpace(memSpace)
                .setLayout(AffineMapAttr::get(map)));
  // Update both.
  ASSERT_EQ(memrefType.clone(memrefNewShape, memrefNewType),
            (MemRefType)MemRefType::Builder(memrefNewShape, memrefNewType)
                .setMemorySpace(memSpace)
                .setLayout(AffineMapAttr::get(map)));

  // Test unranked memref cloning.
  ShapedType unrankedMemrefType =
      UnrankedMemRefType::get(memrefOriginalType, memSpace);
  // Update shape: supplying a shape yields a ranked memref in the same memory
  // space.
  ASSERT_EQ(unrankedMemrefType.clone(memrefNewShape),
            (MemRefType)MemRefType::Builder(memrefNewShape, memrefOriginalType)
                .setMemorySpace(memSpace));
  // Update type: without a shape the result stays unranked.
  ASSERT_EQ(unrankedMemrefType.clone(memrefNewType),
            UnrankedMemRefType::get(memrefNewType, memSpace));
  // Update both.
  ASSERT_EQ(unrankedMemrefType.clone(memrefNewShape, memrefNewType),
            (MemRefType)MemRefType::Builder(memrefNewShape, memrefNewType)
                .setMemorySpace(memSpace));
}
68
/// Verifies ShapedType::clone on tensor types.
///
/// A ranked tensor cloned with a new shape, a new element type, or both must
/// equal the RankedTensorType built directly from those components. An
/// unranked tensor must become a RankedTensorType whenever a shape is supplied.
/// It must stay an UnrankedTensorType when only the element type changes.
TEST(ShapedTypeTest, CloneTensor) {
  MLIRContext context;

  Type i32 = IntegerType::get(&context, 32);
  Type f32 = FloatType::getF32(&context);

  Type tensorOriginalType = i32;
  llvm::SmallVector<int64_t> tensorOriginalShape({10, 20});

  // Test ranked tensor cloning.
  ShapedType tensorType =
      RankedTensorType::get(tensorOriginalShape, tensorOriginalType);
  // Update shape.
  llvm::SmallVector<int64_t> tensorNewShape({30, 40});
  ASSERT_NE(tensorOriginalShape, tensorNewShape);
  ASSERT_EQ(
      tensorType.clone(tensorNewShape),
      (ShapedType)RankedTensorType::get(tensorNewShape, tensorOriginalType));
  // Update type.
  Type tensorNewType = f32;
  ASSERT_NE(tensorOriginalType, tensorNewType);
  ASSERT_EQ(
      tensorType.clone(tensorNewType),
      (ShapedType)RankedTensorType::get(tensorOriginalShape, tensorNewType));
  // Update both.
  ASSERT_EQ(tensorType.clone(tensorNewShape, tensorNewType),
            (ShapedType)RankedTensorType::get(tensorNewShape, tensorNewType));

  // Test unranked tensor cloning.
  ShapedType unrankedTensorType = UnrankedTensorType::get(tensorOriginalType);
  // Update shape: supplying a shape yields a ranked tensor.
  ASSERT_EQ(
      unrankedTensorType.clone(tensorNewShape),
      (ShapedType)RankedTensorType::get(tensorNewShape, tensorOriginalType));
  // Update type: without a shape the result stays unranked.
  ASSERT_EQ(unrankedTensorType.clone(tensorNewType),
            (ShapedType)UnrankedTensorType::get(tensorNewType));
  // Update both: the two-argument overload must apply the new shape AND the
  // new element type.
  ASSERT_EQ(unrankedTensorType.clone(tensorNewShape, tensorNewType),
            (ShapedType)RankedTensorType::get(tensorNewShape, tensorNewType));
}
108
/// Checks that ShapedType::clone on a vector replaces the shape, the element
/// type, or both. Each result must equal the VectorType constructed directly
/// from the same components.
TEST(ShapedTypeTest, CloneVector) {
  MLIRContext ctx;

  Type srcElemTy = IntegerType::get(&ctx, 32);
  llvm::SmallVector<int64_t> srcShape({10, 20});
  ShapedType vec = VectorType::get(srcShape, srcElemTy);

  // Replace only the shape; the element type must carry over.
  llvm::SmallVector<int64_t> dstShape({30, 40});
  ASSERT_NE(srcShape, dstShape);
  VectorType wantShapeOnly = VectorType::get(dstShape, srcElemTy);
  ASSERT_EQ(vec.clone(dstShape), wantShapeOnly);

  // Replace only the element type; the shape must carry over.
  Type dstElemTy = FloatType::getF32(&ctx);
  ASSERT_NE(srcElemTy, dstElemTy);
  VectorType wantElemOnly = VectorType::get(srcShape, dstElemTy);
  ASSERT_EQ(vec.clone(dstElemTy), wantElemOnly);

  // Replace both components in a single call.
  VectorType wantBoth = VectorType::get(dstShape, dstElemTy);
  ASSERT_EQ(vec.clone(dstShape, dstElemTy), wantBoth);
}
133
134 } // namespace
135