1d0cb0d30SAlexander Belyaev //===- ComplexOps.cpp - MLIR Complex Operations ---------------------------===//
2d0cb0d30SAlexander Belyaev //
3d0cb0d30SAlexander Belyaev // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4d0cb0d30SAlexander Belyaev // See https://llvm.org/LICENSE.txt for license information.
5d0cb0d30SAlexander Belyaev // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6d0cb0d30SAlexander Belyaev //
7d0cb0d30SAlexander Belyaev //===----------------------------------------------------------------------===//
8d0cb0d30SAlexander Belyaev 
9d0cb0d30SAlexander Belyaev #include "mlir/Dialect/Complex/IR/Complex.h"
10a28fe17dSAdrian Kuegel #include "mlir/IR/Builders.h"
11036a6996Slewuathe #include "mlir/IR/Matchers.h"
12d0cb0d30SAlexander Belyaev 
13d0cb0d30SAlexander Belyaev using namespace mlir;
14d0cb0d30SAlexander Belyaev using namespace mlir::complex;
15d0cb0d30SAlexander Belyaev 
16d0cb0d30SAlexander Belyaev //===----------------------------------------------------------------------===//
17480cd4cbSRiver Riddle // ConstantOp
18d0cb0d30SAlexander Belyaev //===----------------------------------------------------------------------===//
19d0cb0d30SAlexander Belyaev 
fold(ArrayRef<Attribute> operands)20480cd4cbSRiver Riddle OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
21480cd4cbSRiver Riddle   assert(operands.empty() && "constant has no operands");
22480cd4cbSRiver Riddle   return getValue();
23480cd4cbSRiver Riddle }
24480cd4cbSRiver Riddle 
getAsmResultNames(function_ref<void (Value,StringRef)> setNameFn)25480cd4cbSRiver Riddle void ConstantOp::getAsmResultNames(
26480cd4cbSRiver Riddle     function_ref<void(Value, StringRef)> setNameFn) {
27480cd4cbSRiver Riddle   setNameFn(getResult(), "cst");
28480cd4cbSRiver Riddle }
29480cd4cbSRiver Riddle 
isBuildableWith(Attribute value,Type type)30480cd4cbSRiver Riddle bool ConstantOp::isBuildableWith(Attribute value, Type type) {
31480cd4cbSRiver Riddle   if (auto arrAttr = value.dyn_cast<ArrayAttr>()) {
32480cd4cbSRiver Riddle     auto complexTy = type.dyn_cast<ComplexType>();
33480cd4cbSRiver Riddle     if (!complexTy)
34480cd4cbSRiver Riddle       return false;
35480cd4cbSRiver Riddle     auto complexEltTy = complexTy.getElementType();
36480cd4cbSRiver Riddle     return arrAttr.size() == 2 && arrAttr[0].getType() == complexEltTy &&
37480cd4cbSRiver Riddle            arrAttr[1].getType() == complexEltTy;
38480cd4cbSRiver Riddle   }
39480cd4cbSRiver Riddle   return false;
40480cd4cbSRiver Riddle }
41480cd4cbSRiver Riddle 
verify()421be88f5aSRiver Riddle LogicalResult ConstantOp::verify() {
431be88f5aSRiver Riddle   ArrayAttr arrayAttr = getValue();
44480cd4cbSRiver Riddle   if (arrayAttr.size() != 2) {
451be88f5aSRiver Riddle     return emitOpError(
46480cd4cbSRiver Riddle         "requires 'value' to be a complex constant, represented as array of "
47480cd4cbSRiver Riddle         "two values");
48480cd4cbSRiver Riddle   }
49480cd4cbSRiver Riddle 
501be88f5aSRiver Riddle   auto complexEltTy = getType().getElementType();
51480cd4cbSRiver Riddle   if (complexEltTy != arrayAttr[0].getType() ||
52480cd4cbSRiver Riddle       complexEltTy != arrayAttr[1].getType()) {
531be88f5aSRiver Riddle     return emitOpError()
54480cd4cbSRiver Riddle            << "requires attribute's element types (" << arrayAttr[0].getType()
55480cd4cbSRiver Riddle            << ", " << arrayAttr[1].getType()
56480cd4cbSRiver Riddle            << ") to match the element type of the op's return type ("
57480cd4cbSRiver Riddle            << complexEltTy << ")";
58480cd4cbSRiver Riddle   }
59480cd4cbSRiver Riddle   return success();
60480cd4cbSRiver Riddle }
61480cd4cbSRiver Riddle 
62480cd4cbSRiver Riddle //===----------------------------------------------------------------------===//
63480cd4cbSRiver Riddle // CreateOp
64480cd4cbSRiver Riddle //===----------------------------------------------------------------------===//
65fa765a09SAdrian Kuegel 
fold(ArrayRef<Attribute> operands)66dee46d08SAdrian Kuegel OpFoldResult CreateOp::fold(ArrayRef<Attribute> operands) {
67dee46d08SAdrian Kuegel   assert(operands.size() == 2 && "binary op takes two operands");
68dee46d08SAdrian Kuegel   // Fold complex.create(complex.re(op), complex.im(op)).
69dee46d08SAdrian Kuegel   if (auto reOp = getOperand(0).getDefiningOp<ReOp>()) {
70dee46d08SAdrian Kuegel     if (auto imOp = getOperand(1).getDefiningOp<ImOp>()) {
71dee46d08SAdrian Kuegel       if (reOp.getOperand() == imOp.getOperand()) {
72dee46d08SAdrian Kuegel         return reOp.getOperand();
73dee46d08SAdrian Kuegel       }
74dee46d08SAdrian Kuegel     }
75dee46d08SAdrian Kuegel   }
76fa765a09SAdrian Kuegel   return {};
77fa765a09SAdrian Kuegel }
78fa765a09SAdrian Kuegel 
79480cd4cbSRiver Riddle //===----------------------------------------------------------------------===//
80480cd4cbSRiver Riddle // ImOp
81480cd4cbSRiver Riddle //===----------------------------------------------------------------------===//
82480cd4cbSRiver Riddle 
fold(ArrayRef<Attribute> operands)83fa765a09SAdrian Kuegel OpFoldResult ImOp::fold(ArrayRef<Attribute> operands) {
84fa765a09SAdrian Kuegel   assert(operands.size() == 1 && "unary op takes 1 operand");
85fa765a09SAdrian Kuegel   ArrayAttr arrayAttr = operands[0].dyn_cast_or_null<ArrayAttr>();
86fa765a09SAdrian Kuegel   if (arrayAttr && arrayAttr.size() == 2)
87fa765a09SAdrian Kuegel     return arrayAttr[1];
88cb65419bSAdrian Kuegel   if (auto createOp = getOperand().getDefiningOp<CreateOp>())
89b99f892bSAdrian Kuegel     return createOp.getOperand(1);
90fa765a09SAdrian Kuegel   return {};
91fa765a09SAdrian Kuegel }
92dee46d08SAdrian Kuegel 
93480cd4cbSRiver Riddle //===----------------------------------------------------------------------===//
94480cd4cbSRiver Riddle // ReOp
95480cd4cbSRiver Riddle //===----------------------------------------------------------------------===//
96480cd4cbSRiver Riddle 
fold(ArrayRef<Attribute> operands)97dee46d08SAdrian Kuegel OpFoldResult ReOp::fold(ArrayRef<Attribute> operands) {
98dee46d08SAdrian Kuegel   assert(operands.size() == 1 && "unary op takes 1 operand");
99dee46d08SAdrian Kuegel   ArrayAttr arrayAttr = operands[0].dyn_cast_or_null<ArrayAttr>();
100dee46d08SAdrian Kuegel   if (arrayAttr && arrayAttr.size() == 2)
101dee46d08SAdrian Kuegel     return arrayAttr[0];
102dee46d08SAdrian Kuegel   if (auto createOp = getOperand().getDefiningOp<CreateOp>())
103dee46d08SAdrian Kuegel     return createOp.getOperand(0);
104dee46d08SAdrian Kuegel   return {};
105dee46d08SAdrian Kuegel }
106480cd4cbSRiver Riddle 
107480cd4cbSRiver Riddle //===----------------------------------------------------------------------===//
108036a6996Slewuathe // AddOp
109036a6996Slewuathe //===----------------------------------------------------------------------===//
110036a6996Slewuathe 
fold(ArrayRef<Attribute> operands)111036a6996Slewuathe OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) {
112036a6996Slewuathe   assert(operands.size() == 2 && "binary op takes 2 operands");
113036a6996Slewuathe 
114036a6996Slewuathe   // complex.add(complex.sub(a, b), b) -> a
115036a6996Slewuathe   if (auto sub = getLhs().getDefiningOp<SubOp>())
116036a6996Slewuathe     if (getRhs() == sub.getRhs())
117036a6996Slewuathe       return sub.getLhs();
118036a6996Slewuathe 
119036a6996Slewuathe   // complex.add(b, complex.sub(a, b)) -> a
120036a6996Slewuathe   if (auto sub = getRhs().getDefiningOp<SubOp>())
121036a6996Slewuathe     if (getLhs() == sub.getRhs())
122036a6996Slewuathe       return sub.getLhs();
123036a6996Slewuathe 
124036a6996Slewuathe   return {};
125036a6996Slewuathe }
126036a6996Slewuathe 
127036a6996Slewuathe //===----------------------------------------------------------------------===//
12801807095Slewuathe // NegOp
12901807095Slewuathe //===----------------------------------------------------------------------===//
13001807095Slewuathe 
fold(ArrayRef<Attribute> operands)13101807095Slewuathe OpFoldResult NegOp::fold(ArrayRef<Attribute> operands) {
13201807095Slewuathe   assert(operands.size() == 1 && "unary op takes 1 operand");
13301807095Slewuathe 
13401807095Slewuathe   // complex.neg(complex.neg(a)) -> a
13501807095Slewuathe   if (auto negOp = getOperand().getDefiningOp<NegOp>())
13601807095Slewuathe     return negOp.getOperand();
13701807095Slewuathe 
13801807095Slewuathe   return {};
13901807095Slewuathe }
14001807095Slewuathe 
14101807095Slewuathe //===----------------------------------------------------------------------===//
142*5148c685Slewuathe // LogOp
143*5148c685Slewuathe //===----------------------------------------------------------------------===//
144*5148c685Slewuathe 
fold(ArrayRef<Attribute> operands)145*5148c685Slewuathe OpFoldResult LogOp::fold(ArrayRef<Attribute> operands) {
146*5148c685Slewuathe   assert(operands.size() == 1 && "unary op takes 1 operand");
147*5148c685Slewuathe 
148*5148c685Slewuathe   // complex.log(complex.exp(a)) -> a
149*5148c685Slewuathe   if (auto expOp = getOperand().getDefiningOp<ExpOp>())
150*5148c685Slewuathe     return expOp.getOperand();
151*5148c685Slewuathe 
152*5148c685Slewuathe   return {};
153*5148c685Slewuathe }
154*5148c685Slewuathe 
155*5148c685Slewuathe //===----------------------------------------------------------------------===//
156*5148c685Slewuathe // ExpOp
157*5148c685Slewuathe //===----------------------------------------------------------------------===//
158*5148c685Slewuathe 
fold(ArrayRef<Attribute> operands)159*5148c685Slewuathe OpFoldResult ExpOp::fold(ArrayRef<Attribute> operands) {
160*5148c685Slewuathe   assert(operands.size() == 1 && "unary op takes 1 operand");
161*5148c685Slewuathe 
162*5148c685Slewuathe   // complex.exp(complex.log(a)) -> a
163*5148c685Slewuathe   if (auto logOp = getOperand().getDefiningOp<LogOp>())
164*5148c685Slewuathe     return logOp.getOperand();
165*5148c685Slewuathe 
166*5148c685Slewuathe   return {};
167*5148c685Slewuathe }
168*5148c685Slewuathe 
169*5148c685Slewuathe //===----------------------------------------------------------------------===//
170480cd4cbSRiver Riddle // TableGen'd op method definitions
171480cd4cbSRiver Riddle //===----------------------------------------------------------------------===//
172480cd4cbSRiver Riddle 
173480cd4cbSRiver Riddle #define GET_OP_CLASSES
174480cd4cbSRiver Riddle #include "mlir/Dialect/Complex/IR/ComplexOps.cpp.inc"
175