1 //===- InferShapeTest.cpp - unit tests for shape inference ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/MemRef/IR/MemRef.h"
10 #include "mlir/IR/AffineMap.h"
11 #include "mlir/IR/Builders.h"
12 #include "mlir/IR/BuiltinTypes.h"
13 #include "gtest/gtest.h"
14
15 using namespace mlir;
16 using namespace mlir::memref;
17
// The source memref has the identity (row-major) layout.
TEST(InferShapeTest, inferRankReducedShapeIdentity) {
  MLIRContext context;
  OpBuilder builder(&context);
  Type elemTy = builder.getIndexType();

  AffineExpr d0;
  bindDims(&context, d0);

  // 10x5 source with the default layout, i.e. row-major strides [5, 1].
  auto source = MemRefType::get({10, 5}, elemTy);

  // Take a 1x2 window starting at [2, 3]; the requested result shape {2}
  // drops the leading unit dimension.
  auto inferred = SubViewOp::inferRankReducedResultType(
      /*resultShape=*/{2}, source, /*offsets=*/{2, 3}, /*sizes=*/{1, 2},
      /*strides=*/{1, 1});

  // Linearized offset is 2 * 5 + 3 = 13; the surviving dimension keeps
  // stride 1.
  AffineMap resultLayout =
      AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, d0 + 13);
  auto expected = MemRefType::get({2}, elemTy, resultLayout);
  EXPECT_EQ(inferred, expected);
}
31
// The source memref has a non-identity (strided) layout.
TEST(InferShapeTest, inferRankReducedShapeNonIdentity) {
  MLIRContext context;
  OpBuilder builder(&context);
  Type elemTy = builder.getIndexType();

  AffineExpr d0, d1;
  bindDims(&context, d0, d1);

  // 10x5 source whose layout map places consecutive rows 1000 elements apart.
  AffineMap sourceLayout =
      AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, 1000 * d0 + d1);
  auto source = MemRefType::get({10, 5}, elemTy, sourceLayout);

  // Take a 1x2 window starting at [2, 3], dropping the leading unit dimension.
  auto inferred = SubViewOp::inferRankReducedResultType(
      /*resultShape=*/{2}, source, /*offsets=*/{2, 3}, /*sizes=*/{1, 2},
      /*strides=*/{1, 1});

  // Linearized offset is 2 * 1000 + 3 = 2003.
  AffineMap resultLayout =
      AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, d0 + 2003);
  auto expected = MemRefType::get({2}, elemTy, resultLayout);
  EXPECT_EQ(inferred, expected);
}
46
// Rank-reduce a 1x1 window all the way down to a 0-d memref, so only the
// constant offset remains in the result layout.
TEST(InferShapeTest, inferRankReducedShapeToScalar) {
  MLIRContext context;
  OpBuilder builder(&context);
  Type elemTy = builder.getIndexType();

  AffineExpr d0, d1;
  bindDims(&context, d0, d1);

  AffineMap sourceLayout =
      AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, 1000 * d0 + d1);
  auto source = MemRefType::get({10, 5}, elemTy, sourceLayout);

  // Both unit dimensions are dropped: the requested result shape is empty.
  auto inferred = SubViewOp::inferRankReducedResultType(
      /*resultShape=*/{}, source, /*offsets=*/{2, 3}, /*sizes=*/{1, 1},
      /*strides=*/{1, 1});

  // No dimensions are left; the layout is just the offset 2 * 1000 + 3.
  AffineExpr offsetExpr = builder.getAffineConstantExpr(2003);
  AffineMap resultLayout =
      AffineMap::get(/*dimCount=*/0, /*symbolCount=*/0, offsetExpr);
  auto expected = MemRefType::get({}, elemTy, resultLayout);
  EXPECT_EQ(inferred, expected);
}
61