1 //===- InferTypeOpInterfaceTest.cpp - Unit Test for type interface --------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Interfaces/InferTypeOpInterface.h" 10 #include "mlir/Dialect/StandardOps/IR/Ops.h" 11 #include "mlir/IR/Builders.h" 12 #include "mlir/IR/BuiltinOps.h" 13 #include "mlir/IR/Dialect.h" 14 #include "mlir/IR/DialectImplementation.h" 15 #include "mlir/IR/ImplicitLocOpBuilder.h" 16 #include "mlir/IR/OpDefinition.h" 17 #include "mlir/IR/OpImplementation.h" 18 #include "mlir/Parser.h" 19 20 #include <gtest/gtest.h> 21 22 using namespace mlir; 23 24 class ValueShapeRangeTest : public testing::Test { 25 protected: 26 void SetUp() override { 27 const char *ir = R"MLIR( 28 func @map(%arg : tensor<1xi64>) { 29 %0 = constant dense<[10]> : tensor<1xi64> 30 %1 = addi %arg, %0 : tensor<1xi64> 31 return 32 } 33 )MLIR"; 34 35 registry.insert<StandardOpsDialect>(); 36 ctx.appendDialectRegistry(registry); 37 module = parseSourceString(ir, &ctx); 38 mapFn = cast<FuncOp>(module->front()); 39 } 40 41 // Create ValueShapeRange on the addi operation. 42 ValueShapeRange addiRange() { 43 auto &fnBody = mapFn.body(); 44 return std::next(fnBody.front().begin())->getOperands(); 45 } 46 47 DialectRegistry registry; 48 MLIRContext ctx; 49 OwningModuleRef module; 50 FuncOp mapFn; 51 }; 52 53 TEST_F(ValueShapeRangeTest, ShapesFromValues) { 54 ValueShapeRange range = addiRange(); 55 56 EXPECT_FALSE(range.getValueAsShape(0)); 57 ASSERT_TRUE(range.getValueAsShape(1)); 58 EXPECT_TRUE(range.getValueAsShape(1).hasRank()); 59 EXPECT_EQ(range.getValueAsShape(1).getRank(), 1); 60 EXPECT_EQ(range.getValueAsShape(1).getDimSize(0), 10); 61 EXPECT_EQ(range.getShape(1).getRank(), 1); 62 EXPECT_EQ(range.getShape(1).getDimSize(0), 1); 63 } 64 65 TEST_F(ValueShapeRangeTest, MapValuesToShapes) { 66 ValueShapeRange range = addiRange(); 67 ShapedTypeComponents fixed(SmallVector<int64_t>{30}); 68 auto mapping = [&](Value val) -> ShapeAdaptor { 69 if (val == mapFn.getArgument(0)) 70 return &fixed; 71 return nullptr; 72 }; 73 range.setValueToShapeMapping(mapping); 74 75 ASSERT_TRUE(range.getValueAsShape(0)); 76 EXPECT_TRUE(range.getValueAsShape(0).hasRank()); 77 EXPECT_EQ(range.getValueAsShape(0).getRank(), 1); 78 EXPECT_EQ(range.getValueAsShape(0).getDimSize(0), 30); 79 ASSERT_TRUE(range.getValueAsShape(1)); 80 EXPECT_TRUE(range.getValueAsShape(1).hasRank()); 81 EXPECT_EQ(range.getValueAsShape(1).getRank(), 1); 82 EXPECT_EQ(range.getValueAsShape(1).getDimSize(0), 10); 83 } 84 85 TEST_F(ValueShapeRangeTest, SettingShapes) { 86 ShapedTypeComponents shape(SmallVector<int64_t>{10, 20}); 87 ValueShapeRange range = addiRange(); 88 auto mapping = [&](Value val) -> ShapeAdaptor { 89 if (val == mapFn.getArgument(0)) 90 return &shape; 91 return nullptr; 92 }; 93 range.setOperandShapeMapping(mapping); 94 95 ASSERT_TRUE(range.getShape(0)); 96 EXPECT_EQ(range.getShape(0).getRank(), 2); 97 EXPECT_EQ(range.getShape(0).getDimSize(0), 10); 98 EXPECT_EQ(range.getShape(0).getDimSize(1), 20); 99 ASSERT_TRUE(range.getShape(1)); 100 EXPECT_EQ(range.getShape(1).getRank(), 1); 101 EXPECT_EQ(range.getShape(1).getDimSize(0), 1); 102 EXPECT_FALSE(range.getShape(2)); 103 } 104