1 //===- InferTypeOpInterfaceTest.cpp - Unit Test for type interface --------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Interfaces/InferTypeOpInterface.h" 10 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 11 #include "mlir/Dialect/StandardOps/IR/Ops.h" 12 #include "mlir/IR/Builders.h" 13 #include "mlir/IR/BuiltinOps.h" 14 #include "mlir/IR/Dialect.h" 15 #include "mlir/IR/DialectImplementation.h" 16 #include "mlir/IR/ImplicitLocOpBuilder.h" 17 #include "mlir/IR/OpDefinition.h" 18 #include "mlir/IR/OpImplementation.h" 19 #include "mlir/Parser.h" 20 21 #include <gtest/gtest.h> 22 23 using namespace mlir; 24 25 class ValueShapeRangeTest : public testing::Test { 26 protected: 27 void SetUp() override { 28 const char *ir = R"MLIR( 29 func @map(%arg : tensor<1xi64>) { 30 %0 = arith.constant dense<[10]> : tensor<1xi64> 31 %1 = arith.addi %arg, %0 : tensor<1xi64> 32 return 33 } 34 )MLIR"; 35 36 registry.insert<StandardOpsDialect, arith::ArithmeticDialect>(); 37 ctx.appendDialectRegistry(registry); 38 module = parseSourceString(ir, &ctx); 39 mapFn = cast<FuncOp>(module->front()); 40 } 41 42 // Create ValueShapeRange on the arith.addi operation. 43 ValueShapeRange addiRange() { 44 auto &fnBody = mapFn.body(); 45 return std::next(fnBody.front().begin())->getOperands(); 46 } 47 48 DialectRegistry registry; 49 MLIRContext ctx; 50 OwningModuleRef module; 51 FuncOp mapFn; 52 }; 53 54 TEST_F(ValueShapeRangeTest, ShapesFromValues) { 55 ValueShapeRange range = addiRange(); 56 57 EXPECT_FALSE(range.getValueAsShape(0)); 58 ASSERT_TRUE(range.getValueAsShape(1)); 59 EXPECT_TRUE(range.getValueAsShape(1).hasRank()); 60 EXPECT_EQ(range.getValueAsShape(1).getRank(), 1); 61 EXPECT_EQ(range.getValueAsShape(1).getDimSize(0), 10); 62 EXPECT_EQ(range.getShape(1).getRank(), 1); 63 EXPECT_EQ(range.getShape(1).getDimSize(0), 1); 64 } 65 66 TEST_F(ValueShapeRangeTest, MapValuesToShapes) { 67 ValueShapeRange range = addiRange(); 68 ShapedTypeComponents fixed(SmallVector<int64_t>{30}); 69 auto mapping = [&](Value val) -> ShapeAdaptor { 70 if (val == mapFn.getArgument(0)) 71 return &fixed; 72 return nullptr; 73 }; 74 range.setValueToShapeMapping(mapping); 75 76 ASSERT_TRUE(range.getValueAsShape(0)); 77 EXPECT_TRUE(range.getValueAsShape(0).hasRank()); 78 EXPECT_EQ(range.getValueAsShape(0).getRank(), 1); 79 EXPECT_EQ(range.getValueAsShape(0).getDimSize(0), 30); 80 ASSERT_TRUE(range.getValueAsShape(1)); 81 EXPECT_TRUE(range.getValueAsShape(1).hasRank()); 82 EXPECT_EQ(range.getValueAsShape(1).getRank(), 1); 83 EXPECT_EQ(range.getValueAsShape(1).getDimSize(0), 10); 84 } 85 86 TEST_F(ValueShapeRangeTest, SettingShapes) { 87 ShapedTypeComponents shape(SmallVector<int64_t>{10, 20}); 88 ValueShapeRange range = addiRange(); 89 auto mapping = [&](Value val) -> ShapeAdaptor { 90 if (val == mapFn.getArgument(0)) 91 return &shape; 92 return nullptr; 93 }; 94 range.setOperandShapeMapping(mapping); 95 96 ASSERT_TRUE(range.getShape(0)); 97 EXPECT_EQ(range.getShape(0).getRank(), 2); 98 EXPECT_EQ(range.getShape(0).getDimSize(0), 10); 99 EXPECT_EQ(range.getShape(0).getDimSize(1), 20); 100 ASSERT_TRUE(range.getShape(1)); 101 EXPECT_EQ(range.getShape(1).getRank(), 1); 102 EXPECT_EQ(range.getShape(1).getDimSize(0), 1); 103 EXPECT_FALSE(range.getShape(2)); 104 } 105