109349303SJacques Pienaar //===- InferTypeOpInterfaceTest.cpp - Unit Test for type interface --------===//
209349303SJacques Pienaar //
309349303SJacques Pienaar // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
409349303SJacques Pienaar // See https://llvm.org/LICENSE.txt for license information.
509349303SJacques Pienaar // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
609349303SJacques Pienaar //
709349303SJacques Pienaar //===----------------------------------------------------------------------===//
809349303SJacques Pienaar 
909349303SJacques Pienaar #include "mlir/Interfaces/InferTypeOpInterface.h"
10a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
1123aa5a74SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
1209349303SJacques Pienaar #include "mlir/IR/Builders.h"
1309349303SJacques Pienaar #include "mlir/IR/BuiltinOps.h"
1409349303SJacques Pienaar #include "mlir/IR/Dialect.h"
1509349303SJacques Pienaar #include "mlir/IR/DialectImplementation.h"
1609349303SJacques Pienaar #include "mlir/IR/ImplicitLocOpBuilder.h"
1709349303SJacques Pienaar #include "mlir/IR/OpDefinition.h"
1809349303SJacques Pienaar #include "mlir/IR/OpImplementation.h"
199eaff423SRiver Riddle #include "mlir/Parser/Parser.h"
2009349303SJacques Pienaar 
2109349303SJacques Pienaar #include <gtest/gtest.h>
2209349303SJacques Pienaar 
2309349303SJacques Pienaar using namespace mlir;
2409349303SJacques Pienaar 
2509349303SJacques Pienaar class ValueShapeRangeTest : public testing::Test {
2609349303SJacques Pienaar protected:
SetUp()2709349303SJacques Pienaar   void SetUp() override {
2809349303SJacques Pienaar     const char *ir = R"MLIR(
29*63237cddSRiver Riddle       func.func @map(%arg : tensor<1xi64>) {
30a54f4eaeSMogball         %0 = arith.constant dense<[10]> : tensor<1xi64>
31a54f4eaeSMogball         %1 = arith.addi %arg, %0 : tensor<1xi64>
3209349303SJacques Pienaar         return
3309349303SJacques Pienaar       }
3409349303SJacques Pienaar     )MLIR";
3509349303SJacques Pienaar 
3623aa5a74SRiver Riddle     registry.insert<func::FuncDialect, arith::ArithmeticDialect>();
3709349303SJacques Pienaar     ctx.appendDialectRegistry(registry);
38dfaadf6bSChristian Sigg     module = parseSourceString<ModuleOp>(ir, &ctx);
3958ceae95SRiver Riddle     mapFn = cast<func::FuncOp>(module->front());
4009349303SJacques Pienaar   }
4109349303SJacques Pienaar 
42a54f4eaeSMogball   // Create ValueShapeRange on the arith.addi operation.
addiRange()4309349303SJacques Pienaar   ValueShapeRange addiRange() {
44f8d5c73cSRiver Riddle     auto &fnBody = mapFn.getBody();
4509349303SJacques Pienaar     return std::next(fnBody.front().begin())->getOperands();
4609349303SJacques Pienaar   }
4709349303SJacques Pienaar 
4809349303SJacques Pienaar   DialectRegistry registry;
4909349303SJacques Pienaar   MLIRContext ctx;
508f66ab1cSSanjoy Das   OwningOpRef<ModuleOp> module;
5158ceae95SRiver Riddle   func::FuncOp mapFn;
5209349303SJacques Pienaar };
5309349303SJacques Pienaar 
// With no mapping installed:
// - getValueAsShape(0) (%arg, a block argument) yields no shape.
// - getValueAsShape(1) (produced by `arith.constant dense<[10]>`) is read as
//   the shape [10].
// - getShape(1) reports the operand's own type, tensor<1xi64>.
TEST_F(ValueShapeRangeTest, ShapesFromValues) {
  ValueShapeRange range = addiRange();

  EXPECT_FALSE(range.getValueAsShape(0));
  ASSERT_TRUE(range.getValueAsShape(1));
  // getRank()/getDimSize() are only meaningful on a ranked adaptor. Make this
  // check fatal so an unranked result fails the test instead of crashing in
  // the accessors below.
  ASSERT_TRUE(range.getValueAsShape(1).hasRank());
  EXPECT_EQ(range.getValueAsShape(1).getRank(), 1);
  EXPECT_EQ(range.getValueAsShape(1).getDimSize(0), 10);
  EXPECT_EQ(range.getShape(1).getRank(), 1);
  EXPECT_EQ(range.getShape(1).getDimSize(0), 1);
}
6509349303SJacques Pienaar 
// Installs a value-to-shape mapping that maps the function argument to the
// fixed shape [30] and returns nullptr for everything else.
// - Operand 0 (%arg) then reads as [30].
// - Operand 1, which the mapping does not cover, still reads as [10].
TEST_F(ValueShapeRangeTest, MapValuesToShapes) {
  ValueShapeRange range = addiRange();
  ShapedTypeComponents fixed(SmallVector<int64_t>{30});
  auto mapping = [&](Value val) -> ShapeAdaptor {
    if (val == mapFn.getArgument(0))
      return &fixed;
    return nullptr;
  };
  range.setValueToShapeMapping(mapping);

  ASSERT_TRUE(range.getValueAsShape(0));
  // Fatal checks: querying the rank or dims of an unranked adaptor is invalid,
  // so stop the test here rather than crash in the accessors.
  ASSERT_TRUE(range.getValueAsShape(0).hasRank());
  EXPECT_EQ(range.getValueAsShape(0).getRank(), 1);
  EXPECT_EQ(range.getValueAsShape(0).getDimSize(0), 30);
  ASSERT_TRUE(range.getValueAsShape(1));
  ASSERT_TRUE(range.getValueAsShape(1).hasRank());
  EXPECT_EQ(range.getValueAsShape(1).getRank(), 1);
  EXPECT_EQ(range.getValueAsShape(1).getDimSize(0), 10);
}
8509349303SJacques Pienaar 
// Installs an operand-shape mapping that overrides the function argument's
// shape with [10, 20]. The test checks three things:
// - the override is reported for operand 0;
// - operand 1, which is not overridden, reports its tensor<1xi64> type;
// - an out-of-range index yields a null adaptor.
TEST_F(ValueShapeRangeTest, SettingShapes) {
  ShapedTypeComponents overrideShape(SmallVector<int64_t>{10, 20});
  ValueShapeRange operands = addiRange();
  auto overrideArg = [&](Value candidate) -> ShapeAdaptor {
    if (candidate != mapFn.getArgument(0))
      return nullptr;
    return &overrideShape;
  };
  operands.setOperandShapeMapping(overrideArg);

  // Operand 0 is %arg: its shape comes from the override.
  ShapeAdaptor lhsShape = operands.getShape(0);
  ASSERT_TRUE(lhsShape);
  EXPECT_EQ(lhsShape.getRank(), 2);
  EXPECT_EQ(lhsShape.getDimSize(0), 10);
  EXPECT_EQ(lhsShape.getDimSize(1), 20);

  // Operand 1 is %0: the mapping declines it, so its declared type is used.
  ShapeAdaptor rhsShape = operands.getShape(1);
  ASSERT_TRUE(rhsShape);
  EXPECT_EQ(rhsShape.getRank(), 1);
  EXPECT_EQ(rhsShape.getDimSize(0), 1);

  // The addi has only two operands, so index 2 is out of range.
  EXPECT_FALSE(operands.getShape(2));
}
105