1d0cb0d30SAlexander Belyaev //===- ComplexOps.cpp - MLIR Complex Operations ---------------------------===//
2d0cb0d30SAlexander Belyaev //
3d0cb0d30SAlexander Belyaev // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4d0cb0d30SAlexander Belyaev // See https://llvm.org/LICENSE.txt for license information.
5d0cb0d30SAlexander Belyaev // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6d0cb0d30SAlexander Belyaev //
7d0cb0d30SAlexander Belyaev //===----------------------------------------------------------------------===//
8d0cb0d30SAlexander Belyaev
9d0cb0d30SAlexander Belyaev #include "mlir/Dialect/Complex/IR/Complex.h"
10a28fe17dSAdrian Kuegel #include "mlir/IR/Builders.h"
11036a6996Slewuathe #include "mlir/IR/Matchers.h"
12d0cb0d30SAlexander Belyaev
13d0cb0d30SAlexander Belyaev using namespace mlir;
14d0cb0d30SAlexander Belyaev using namespace mlir::complex;
15d0cb0d30SAlexander Belyaev
16d0cb0d30SAlexander Belyaev //===----------------------------------------------------------------------===//
17480cd4cbSRiver Riddle // ConstantOp
18d0cb0d30SAlexander Belyaev //===----------------------------------------------------------------------===//
19d0cb0d30SAlexander Belyaev
fold(ArrayRef<Attribute> operands)20480cd4cbSRiver Riddle OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
21480cd4cbSRiver Riddle assert(operands.empty() && "constant has no operands");
22480cd4cbSRiver Riddle return getValue();
23480cd4cbSRiver Riddle }
24480cd4cbSRiver Riddle
getAsmResultNames(function_ref<void (Value,StringRef)> setNameFn)25480cd4cbSRiver Riddle void ConstantOp::getAsmResultNames(
26480cd4cbSRiver Riddle function_ref<void(Value, StringRef)> setNameFn) {
27480cd4cbSRiver Riddle setNameFn(getResult(), "cst");
28480cd4cbSRiver Riddle }
29480cd4cbSRiver Riddle
isBuildableWith(Attribute value,Type type)30480cd4cbSRiver Riddle bool ConstantOp::isBuildableWith(Attribute value, Type type) {
31480cd4cbSRiver Riddle if (auto arrAttr = value.dyn_cast<ArrayAttr>()) {
32480cd4cbSRiver Riddle auto complexTy = type.dyn_cast<ComplexType>();
33480cd4cbSRiver Riddle if (!complexTy)
34480cd4cbSRiver Riddle return false;
35480cd4cbSRiver Riddle auto complexEltTy = complexTy.getElementType();
36480cd4cbSRiver Riddle return arrAttr.size() == 2 && arrAttr[0].getType() == complexEltTy &&
37480cd4cbSRiver Riddle arrAttr[1].getType() == complexEltTy;
38480cd4cbSRiver Riddle }
39480cd4cbSRiver Riddle return false;
40480cd4cbSRiver Riddle }
41480cd4cbSRiver Riddle
verify()421be88f5aSRiver Riddle LogicalResult ConstantOp::verify() {
431be88f5aSRiver Riddle ArrayAttr arrayAttr = getValue();
44480cd4cbSRiver Riddle if (arrayAttr.size() != 2) {
451be88f5aSRiver Riddle return emitOpError(
46480cd4cbSRiver Riddle "requires 'value' to be a complex constant, represented as array of "
47480cd4cbSRiver Riddle "two values");
48480cd4cbSRiver Riddle }
49480cd4cbSRiver Riddle
501be88f5aSRiver Riddle auto complexEltTy = getType().getElementType();
51480cd4cbSRiver Riddle if (complexEltTy != arrayAttr[0].getType() ||
52480cd4cbSRiver Riddle complexEltTy != arrayAttr[1].getType()) {
531be88f5aSRiver Riddle return emitOpError()
54480cd4cbSRiver Riddle << "requires attribute's element types (" << arrayAttr[0].getType()
55480cd4cbSRiver Riddle << ", " << arrayAttr[1].getType()
56480cd4cbSRiver Riddle << ") to match the element type of the op's return type ("
57480cd4cbSRiver Riddle << complexEltTy << ")";
58480cd4cbSRiver Riddle }
59480cd4cbSRiver Riddle return success();
60480cd4cbSRiver Riddle }
61480cd4cbSRiver Riddle
62480cd4cbSRiver Riddle //===----------------------------------------------------------------------===//
63480cd4cbSRiver Riddle // CreateOp
64480cd4cbSRiver Riddle //===----------------------------------------------------------------------===//
65fa765a09SAdrian Kuegel
fold(ArrayRef<Attribute> operands)66dee46d08SAdrian Kuegel OpFoldResult CreateOp::fold(ArrayRef<Attribute> operands) {
67dee46d08SAdrian Kuegel assert(operands.size() == 2 && "binary op takes two operands");
68dee46d08SAdrian Kuegel // Fold complex.create(complex.re(op), complex.im(op)).
69dee46d08SAdrian Kuegel if (auto reOp = getOperand(0).getDefiningOp<ReOp>()) {
70dee46d08SAdrian Kuegel if (auto imOp = getOperand(1).getDefiningOp<ImOp>()) {
71dee46d08SAdrian Kuegel if (reOp.getOperand() == imOp.getOperand()) {
72dee46d08SAdrian Kuegel return reOp.getOperand();
73dee46d08SAdrian Kuegel }
74dee46d08SAdrian Kuegel }
75dee46d08SAdrian Kuegel }
76fa765a09SAdrian Kuegel return {};
77fa765a09SAdrian Kuegel }
78fa765a09SAdrian Kuegel
79480cd4cbSRiver Riddle //===----------------------------------------------------------------------===//
80480cd4cbSRiver Riddle // ImOp
81480cd4cbSRiver Riddle //===----------------------------------------------------------------------===//
82480cd4cbSRiver Riddle
fold(ArrayRef<Attribute> operands)83fa765a09SAdrian Kuegel OpFoldResult ImOp::fold(ArrayRef<Attribute> operands) {
84fa765a09SAdrian Kuegel assert(operands.size() == 1 && "unary op takes 1 operand");
85fa765a09SAdrian Kuegel ArrayAttr arrayAttr = operands[0].dyn_cast_or_null<ArrayAttr>();
86fa765a09SAdrian Kuegel if (arrayAttr && arrayAttr.size() == 2)
87fa765a09SAdrian Kuegel return arrayAttr[1];
88cb65419bSAdrian Kuegel if (auto createOp = getOperand().getDefiningOp<CreateOp>())
89b99f892bSAdrian Kuegel return createOp.getOperand(1);
90fa765a09SAdrian Kuegel return {};
91fa765a09SAdrian Kuegel }
92dee46d08SAdrian Kuegel
93480cd4cbSRiver Riddle //===----------------------------------------------------------------------===//
94480cd4cbSRiver Riddle // ReOp
95480cd4cbSRiver Riddle //===----------------------------------------------------------------------===//
96480cd4cbSRiver Riddle
fold(ArrayRef<Attribute> operands)97dee46d08SAdrian Kuegel OpFoldResult ReOp::fold(ArrayRef<Attribute> operands) {
98dee46d08SAdrian Kuegel assert(operands.size() == 1 && "unary op takes 1 operand");
99dee46d08SAdrian Kuegel ArrayAttr arrayAttr = operands[0].dyn_cast_or_null<ArrayAttr>();
100dee46d08SAdrian Kuegel if (arrayAttr && arrayAttr.size() == 2)
101dee46d08SAdrian Kuegel return arrayAttr[0];
102dee46d08SAdrian Kuegel if (auto createOp = getOperand().getDefiningOp<CreateOp>())
103dee46d08SAdrian Kuegel return createOp.getOperand(0);
104dee46d08SAdrian Kuegel return {};
105dee46d08SAdrian Kuegel }
106480cd4cbSRiver Riddle
107480cd4cbSRiver Riddle //===----------------------------------------------------------------------===//
108036a6996Slewuathe // AddOp
109036a6996Slewuathe //===----------------------------------------------------------------------===//
110036a6996Slewuathe
fold(ArrayRef<Attribute> operands)111036a6996Slewuathe OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) {
112036a6996Slewuathe assert(operands.size() == 2 && "binary op takes 2 operands");
113036a6996Slewuathe
114036a6996Slewuathe // complex.add(complex.sub(a, b), b) -> a
115036a6996Slewuathe if (auto sub = getLhs().getDefiningOp<SubOp>())
116036a6996Slewuathe if (getRhs() == sub.getRhs())
117036a6996Slewuathe return sub.getLhs();
118036a6996Slewuathe
119036a6996Slewuathe // complex.add(b, complex.sub(a, b)) -> a
120036a6996Slewuathe if (auto sub = getRhs().getDefiningOp<SubOp>())
121036a6996Slewuathe if (getLhs() == sub.getRhs())
122036a6996Slewuathe return sub.getLhs();
123036a6996Slewuathe
124036a6996Slewuathe return {};
125036a6996Slewuathe }
126036a6996Slewuathe
127036a6996Slewuathe //===----------------------------------------------------------------------===//
12801807095Slewuathe // NegOp
12901807095Slewuathe //===----------------------------------------------------------------------===//
13001807095Slewuathe
fold(ArrayRef<Attribute> operands)13101807095Slewuathe OpFoldResult NegOp::fold(ArrayRef<Attribute> operands) {
13201807095Slewuathe assert(operands.size() == 1 && "unary op takes 1 operand");
13301807095Slewuathe
13401807095Slewuathe // complex.neg(complex.neg(a)) -> a
13501807095Slewuathe if (auto negOp = getOperand().getDefiningOp<NegOp>())
13601807095Slewuathe return negOp.getOperand();
13701807095Slewuathe
13801807095Slewuathe return {};
13901807095Slewuathe }
14001807095Slewuathe
14101807095Slewuathe //===----------------------------------------------------------------------===//
142*5148c685Slewuathe // LogOp
143*5148c685Slewuathe //===----------------------------------------------------------------------===//
144*5148c685Slewuathe
fold(ArrayRef<Attribute> operands)145*5148c685Slewuathe OpFoldResult LogOp::fold(ArrayRef<Attribute> operands) {
146*5148c685Slewuathe assert(operands.size() == 1 && "unary op takes 1 operand");
147*5148c685Slewuathe
148*5148c685Slewuathe // complex.log(complex.exp(a)) -> a
149*5148c685Slewuathe if (auto expOp = getOperand().getDefiningOp<ExpOp>())
150*5148c685Slewuathe return expOp.getOperand();
151*5148c685Slewuathe
152*5148c685Slewuathe return {};
153*5148c685Slewuathe }
154*5148c685Slewuathe
155*5148c685Slewuathe //===----------------------------------------------------------------------===//
156*5148c685Slewuathe // ExpOp
157*5148c685Slewuathe //===----------------------------------------------------------------------===//
158*5148c685Slewuathe
fold(ArrayRef<Attribute> operands)159*5148c685Slewuathe OpFoldResult ExpOp::fold(ArrayRef<Attribute> operands) {
160*5148c685Slewuathe assert(operands.size() == 1 && "unary op takes 1 operand");
161*5148c685Slewuathe
162*5148c685Slewuathe // complex.exp(complex.log(a)) -> a
163*5148c685Slewuathe if (auto logOp = getOperand().getDefiningOp<LogOp>())
164*5148c685Slewuathe return logOp.getOperand();
165*5148c685Slewuathe
166*5148c685Slewuathe return {};
167*5148c685Slewuathe }
168*5148c685Slewuathe
169*5148c685Slewuathe //===----------------------------------------------------------------------===//
170480cd4cbSRiver Riddle // TableGen'd op method definitions
171480cd4cbSRiver Riddle //===----------------------------------------------------------------------===//
172480cd4cbSRiver Riddle
173480cd4cbSRiver Riddle #define GET_OP_CLASSES
174480cd4cbSRiver Riddle #include "mlir/Dialect/Complex/IR/ComplexOps.cpp.inc"
175