1 //===- InferShapeTest.cpp - unit tests for shape inference ----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/MemRef/IR/MemRef.h" 10 #include "mlir/IR/AffineMap.h" 11 #include "mlir/IR/Builders.h" 12 #include "mlir/IR/BuiltinTypes.h" 13 #include "gtest/gtest.h" 14 15 using namespace mlir; 16 using namespace mlir::memref; 17 18 // Source memref has identity layout. 19 TEST(InferShapeTest, inferRankReducedShapeIdentity) { 20 MLIRContext ctx; 21 OpBuilder b(&ctx); 22 auto sourceMemref = MemRefType::get({10, 5}, b.getIndexType()); 23 auto reducedType = SubViewOp::inferRankReducedResultType( 24 /*resultShape=*/{2}, sourceMemref, {2, 3}, {1, 2}, {1, 1}); 25 AffineExpr dim0; 26 bindDims(&ctx, dim0); 27 auto expectedType = 28 MemRefType::get({2}, b.getIndexType(), AffineMap::get(1, 0, dim0 + 13)); 29 EXPECT_EQ(reducedType, expectedType); 30 } 31 32 // Source memref has non-identity layout. 33 TEST(InferShapeTest, inferRankReducedShapeNonIdentity) { 34 MLIRContext ctx; 35 OpBuilder b(&ctx); 36 AffineExpr dim0, dim1; 37 bindDims(&ctx, dim0, dim1); 38 auto sourceMemref = MemRefType::get({10, 5}, b.getIndexType(), 39 AffineMap::get(2, 0, 1000 * dim0 + dim1)); 40 auto reducedType = SubViewOp::inferRankReducedResultType( 41 /*resultShape=*/{2}, sourceMemref, {2, 3}, {1, 2}, {1, 1}); 42 auto expectedType = 43 MemRefType::get({2}, b.getIndexType(), AffineMap::get(1, 0, dim0 + 2003)); 44 EXPECT_EQ(reducedType, expectedType); 45 } 46 47 TEST(InferShapeTest, inferRankReducedShapeToScalar) { 48 MLIRContext ctx; 49 OpBuilder b(&ctx); 50 AffineExpr dim0, dim1; 51 bindDims(&ctx, dim0, dim1); 52 auto sourceMemref = MemRefType::get({10, 5}, b.getIndexType(), 53 AffineMap::get(2, 0, 1000 * dim0 + dim1)); 54 auto reducedType = SubViewOp::inferRankReducedResultType( 55 /*resultShape=*/{}, sourceMemref, {2, 3}, {1, 1}, {1, 1}); 56 auto expectedType = 57 MemRefType::get({}, b.getIndexType(), 58 AffineMap::get(0, 0, b.getAffineConstantExpr(2003))); 59 EXPECT_EQ(reducedType, expectedType); 60 } 61