//===- InferTypeOpInterfaceTest.cpp - Unit Test for type interface --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
809349303SJacques Pienaar
909349303SJacques Pienaar #include "mlir/Interfaces/InferTypeOpInterface.h"
10a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
1123aa5a74SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
1209349303SJacques Pienaar #include "mlir/IR/Builders.h"
1309349303SJacques Pienaar #include "mlir/IR/BuiltinOps.h"
1409349303SJacques Pienaar #include "mlir/IR/Dialect.h"
1509349303SJacques Pienaar #include "mlir/IR/DialectImplementation.h"
1609349303SJacques Pienaar #include "mlir/IR/ImplicitLocOpBuilder.h"
1709349303SJacques Pienaar #include "mlir/IR/OpDefinition.h"
1809349303SJacques Pienaar #include "mlir/IR/OpImplementation.h"
199eaff423SRiver Riddle #include "mlir/Parser/Parser.h"
2009349303SJacques Pienaar
2109349303SJacques Pienaar #include <gtest/gtest.h>
2209349303SJacques Pienaar
2309349303SJacques Pienaar using namespace mlir;
2409349303SJacques Pienaar
2509349303SJacques Pienaar class ValueShapeRangeTest : public testing::Test {
2609349303SJacques Pienaar protected:
SetUp()2709349303SJacques Pienaar void SetUp() override {
2809349303SJacques Pienaar const char *ir = R"MLIR(
29*63237cddSRiver Riddle func.func @map(%arg : tensor<1xi64>) {
30a54f4eaeSMogball %0 = arith.constant dense<[10]> : tensor<1xi64>
31a54f4eaeSMogball %1 = arith.addi %arg, %0 : tensor<1xi64>
3209349303SJacques Pienaar return
3309349303SJacques Pienaar }
3409349303SJacques Pienaar )MLIR";
3509349303SJacques Pienaar
3623aa5a74SRiver Riddle registry.insert<func::FuncDialect, arith::ArithmeticDialect>();
3709349303SJacques Pienaar ctx.appendDialectRegistry(registry);
38dfaadf6bSChristian Sigg module = parseSourceString<ModuleOp>(ir, &ctx);
3958ceae95SRiver Riddle mapFn = cast<func::FuncOp>(module->front());
4009349303SJacques Pienaar }
4109349303SJacques Pienaar
42a54f4eaeSMogball // Create ValueShapeRange on the arith.addi operation.
addiRange()4309349303SJacques Pienaar ValueShapeRange addiRange() {
44f8d5c73cSRiver Riddle auto &fnBody = mapFn.getBody();
4509349303SJacques Pienaar return std::next(fnBody.front().begin())->getOperands();
4609349303SJacques Pienaar }
4709349303SJacques Pienaar
4809349303SJacques Pienaar DialectRegistry registry;
4909349303SJacques Pienaar MLIRContext ctx;
508f66ab1cSSanjoy Das OwningOpRef<ModuleOp> module;
5158ceae95SRiver Riddle func::FuncOp mapFn;
5209349303SJacques Pienaar };
5309349303SJacques Pienaar
/// Without any custom mapping, only the constant operand (index 1) can be
/// interpreted as a shape; the function argument (index 0) cannot.
TEST_F(ValueShapeRangeTest, ShapesFromValues) {
  ValueShapeRange operands = addiRange();

  // Operand 0 is the function argument: no value-as-shape is available.
  EXPECT_FALSE(operands.getValueAsShape(0));

  // Operand 1 is the constant dense<[10]>: its value reads as the shape [10].
  ShapeAdaptor constantAsShape = operands.getValueAsShape(1);
  ASSERT_TRUE(constantAsShape);
  EXPECT_TRUE(constantAsShape.hasRank());
  EXPECT_EQ(constantAsShape.getRank(), 1);
  EXPECT_EQ(constantAsShape.getDimSize(0), 10);

  // The shape of operand 1 itself is that of its type, tensor<1xi64>.
  ShapeAdaptor operandShape = operands.getShape(1);
  EXPECT_EQ(operandShape.getRank(), 1);
  EXPECT_EQ(operandShape.getDimSize(0), 1);
}
6509349303SJacques Pienaar
/// A value-to-shape mapping supplies a shape for the function argument, while
/// the constant operand is still resolved from its constant value.
TEST_F(ValueShapeRangeTest, MapValuesToShapes) {
  ValueShapeRange operands = addiRange();
  ShapedTypeComponents fixed(SmallVector<int64_t>{30});
  const Value fnArg = mapFn.getArgument(0);

  // Keep the callback in a named local so it stays alive for every query
  // issued on `operands` below.
  auto valueToShape = [&](Value candidate) -> ShapeAdaptor {
    if (candidate != fnArg)
      return nullptr;
    return &fixed;
  };
  operands.setValueToShapeMapping(valueToShape);

  // Operand 0 (the function argument) now resolves through the mapping.
  ShapeAdaptor mappedShape = operands.getValueAsShape(0);
  ASSERT_TRUE(mappedShape);
  EXPECT_TRUE(mappedShape.hasRank());
  EXPECT_EQ(mappedShape.getRank(), 1);
  EXPECT_EQ(mappedShape.getDimSize(0), 30);

  // Operand 1 (the constant) is not covered by the mapping and still reads
  // as the shape [10].
  ShapeAdaptor constantAsShape = operands.getValueAsShape(1);
  ASSERT_TRUE(constantAsShape);
  EXPECT_TRUE(constantAsShape.hasRank());
  EXPECT_EQ(constantAsShape.getRank(), 1);
  EXPECT_EQ(constantAsShape.getDimSize(0), 10);
}
8509349303SJacques Pienaar
/// An operand-shape mapping overrides the shape reported for the function
/// argument; other operands keep the shape of their type, and an index past
/// the last operand yields no shape.
TEST_F(ValueShapeRangeTest, SettingShapes) {
  ShapedTypeComponents shape(SmallVector<int64_t>{10, 20});
  ValueShapeRange operands = addiRange();
  const Value fnArg = mapFn.getArgument(0);

  // Keep the callback in a named local so it stays alive for every query
  // issued on `operands` below.
  auto operandToShape = [&](Value candidate) -> ShapeAdaptor {
    if (candidate != fnArg)
      return nullptr;
    return &shape;
  };
  operands.setOperandShapeMapping(operandToShape);

  // Operand 0 reports the overridden 10x20 shape.
  ShapeAdaptor overridden = operands.getShape(0);
  ASSERT_TRUE(overridden);
  EXPECT_EQ(overridden.getRank(), 2);
  EXPECT_EQ(overridden.getDimSize(0), 10);
  EXPECT_EQ(overridden.getDimSize(1), 20);

  // Operand 1 is unmapped and reports the shape of tensor<1xi64>.
  ShapeAdaptor fromType = operands.getShape(1);
  ASSERT_TRUE(fromType);
  EXPECT_EQ(fromType.getRank(), 1);
  EXPECT_EQ(fromType.getDimSize(0), 1);

  // There is no operand 2.
  EXPECT_FALSE(operands.getShape(2));
}
105