1 //===- ComplexOps.cpp - MLIR Complex Operations ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Complex/IR/Complex.h"
10 #include "mlir/IR/Builders.h"
11 #include "mlir/IR/Matchers.h"
12 
13 using namespace mlir;
14 using namespace mlir::complex;
15 
16 //===----------------------------------------------------------------------===//
17 // ConstantOp
18 //===----------------------------------------------------------------------===//
19 
fold(ArrayRef<Attribute> operands)20 OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
21   assert(operands.empty() && "constant has no operands");
22   return getValue();
23 }
24 
getAsmResultNames(function_ref<void (Value,StringRef)> setNameFn)25 void ConstantOp::getAsmResultNames(
26     function_ref<void(Value, StringRef)> setNameFn) {
27   setNameFn(getResult(), "cst");
28 }
29 
isBuildableWith(Attribute value,Type type)30 bool ConstantOp::isBuildableWith(Attribute value, Type type) {
31   if (auto arrAttr = value.dyn_cast<ArrayAttr>()) {
32     auto complexTy = type.dyn_cast<ComplexType>();
33     if (!complexTy)
34       return false;
35     auto complexEltTy = complexTy.getElementType();
36     return arrAttr.size() == 2 && arrAttr[0].getType() == complexEltTy &&
37            arrAttr[1].getType() == complexEltTy;
38   }
39   return false;
40 }
41 
verify()42 LogicalResult ConstantOp::verify() {
43   ArrayAttr arrayAttr = getValue();
44   if (arrayAttr.size() != 2) {
45     return emitOpError(
46         "requires 'value' to be a complex constant, represented as array of "
47         "two values");
48   }
49 
50   auto complexEltTy = getType().getElementType();
51   if (complexEltTy != arrayAttr[0].getType() ||
52       complexEltTy != arrayAttr[1].getType()) {
53     return emitOpError()
54            << "requires attribute's element types (" << arrayAttr[0].getType()
55            << ", " << arrayAttr[1].getType()
56            << ") to match the element type of the op's return type ("
57            << complexEltTy << ")";
58   }
59   return success();
60 }
61 
62 //===----------------------------------------------------------------------===//
63 // CreateOp
64 //===----------------------------------------------------------------------===//
65 
fold(ArrayRef<Attribute> operands)66 OpFoldResult CreateOp::fold(ArrayRef<Attribute> operands) {
67   assert(operands.size() == 2 && "binary op takes two operands");
68   // Fold complex.create(complex.re(op), complex.im(op)).
69   if (auto reOp = getOperand(0).getDefiningOp<ReOp>()) {
70     if (auto imOp = getOperand(1).getDefiningOp<ImOp>()) {
71       if (reOp.getOperand() == imOp.getOperand()) {
72         return reOp.getOperand();
73       }
74     }
75   }
76   return {};
77 }
78 
79 //===----------------------------------------------------------------------===//
80 // ImOp
81 //===----------------------------------------------------------------------===//
82 
fold(ArrayRef<Attribute> operands)83 OpFoldResult ImOp::fold(ArrayRef<Attribute> operands) {
84   assert(operands.size() == 1 && "unary op takes 1 operand");
85   ArrayAttr arrayAttr = operands[0].dyn_cast_or_null<ArrayAttr>();
86   if (arrayAttr && arrayAttr.size() == 2)
87     return arrayAttr[1];
88   if (auto createOp = getOperand().getDefiningOp<CreateOp>())
89     return createOp.getOperand(1);
90   return {};
91 }
92 
93 //===----------------------------------------------------------------------===//
94 // ReOp
95 //===----------------------------------------------------------------------===//
96 
fold(ArrayRef<Attribute> operands)97 OpFoldResult ReOp::fold(ArrayRef<Attribute> operands) {
98   assert(operands.size() == 1 && "unary op takes 1 operand");
99   ArrayAttr arrayAttr = operands[0].dyn_cast_or_null<ArrayAttr>();
100   if (arrayAttr && arrayAttr.size() == 2)
101     return arrayAttr[0];
102   if (auto createOp = getOperand().getDefiningOp<CreateOp>())
103     return createOp.getOperand(0);
104   return {};
105 }
106 
107 //===----------------------------------------------------------------------===//
108 // AddOp
109 //===----------------------------------------------------------------------===//
110 
fold(ArrayRef<Attribute> operands)111 OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) {
112   assert(operands.size() == 2 && "binary op takes 2 operands");
113 
114   // complex.add(complex.sub(a, b), b) -> a
115   if (auto sub = getLhs().getDefiningOp<SubOp>())
116     if (getRhs() == sub.getRhs())
117       return sub.getLhs();
118 
119   // complex.add(b, complex.sub(a, b)) -> a
120   if (auto sub = getRhs().getDefiningOp<SubOp>())
121     if (getLhs() == sub.getRhs())
122       return sub.getLhs();
123 
124   return {};
125 }
126 
127 //===----------------------------------------------------------------------===//
128 // NegOp
129 //===----------------------------------------------------------------------===//
130 
fold(ArrayRef<Attribute> operands)131 OpFoldResult NegOp::fold(ArrayRef<Attribute> operands) {
132   assert(operands.size() == 1 && "unary op takes 1 operand");
133 
134   // complex.neg(complex.neg(a)) -> a
135   if (auto negOp = getOperand().getDefiningOp<NegOp>())
136     return negOp.getOperand();
137 
138   return {};
139 }
140 
141 //===----------------------------------------------------------------------===//
142 // LogOp
143 //===----------------------------------------------------------------------===//
144 
fold(ArrayRef<Attribute> operands)145 OpFoldResult LogOp::fold(ArrayRef<Attribute> operands) {
146   assert(operands.size() == 1 && "unary op takes 1 operand");
147 
148   // complex.log(complex.exp(a)) -> a
149   if (auto expOp = getOperand().getDefiningOp<ExpOp>())
150     return expOp.getOperand();
151 
152   return {};
153 }
154 
155 //===----------------------------------------------------------------------===//
156 // ExpOp
157 //===----------------------------------------------------------------------===//
158 
fold(ArrayRef<Attribute> operands)159 OpFoldResult ExpOp::fold(ArrayRef<Attribute> operands) {
160   assert(operands.size() == 1 && "unary op takes 1 operand");
161 
162   // complex.exp(complex.log(a)) -> a
163   if (auto logOp = getOperand().getDefiningOp<LogOp>())
164     return logOp.getOperand();
165 
166   return {};
167 }
168 
169 //===----------------------------------------------------------------------===//
170 // TableGen'd op method definitions
171 //===----------------------------------------------------------------------===//
172 
173 #define GET_OP_CLASSES
174 #include "mlir/Dialect/Complex/IR/ComplexOps.cpp.inc"
175