1 //===- ComplexOps.cpp - MLIR Complex Operations ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/Complex/IR/Complex.h"
10 #include "mlir/IR/Builders.h"
11 #include "mlir/IR/Matchers.h"
12
13 using namespace mlir;
14 using namespace mlir::complex;
15
16 //===----------------------------------------------------------------------===//
17 // ConstantOp
18 //===----------------------------------------------------------------------===//
19
fold(ArrayRef<Attribute> operands)20 OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
21 assert(operands.empty() && "constant has no operands");
22 return getValue();
23 }
24
getAsmResultNames(function_ref<void (Value,StringRef)> setNameFn)25 void ConstantOp::getAsmResultNames(
26 function_ref<void(Value, StringRef)> setNameFn) {
27 setNameFn(getResult(), "cst");
28 }
29
isBuildableWith(Attribute value,Type type)30 bool ConstantOp::isBuildableWith(Attribute value, Type type) {
31 if (auto arrAttr = value.dyn_cast<ArrayAttr>()) {
32 auto complexTy = type.dyn_cast<ComplexType>();
33 if (!complexTy)
34 return false;
35 auto complexEltTy = complexTy.getElementType();
36 return arrAttr.size() == 2 && arrAttr[0].getType() == complexEltTy &&
37 arrAttr[1].getType() == complexEltTy;
38 }
39 return false;
40 }
41
verify()42 LogicalResult ConstantOp::verify() {
43 ArrayAttr arrayAttr = getValue();
44 if (arrayAttr.size() != 2) {
45 return emitOpError(
46 "requires 'value' to be a complex constant, represented as array of "
47 "two values");
48 }
49
50 auto complexEltTy = getType().getElementType();
51 if (complexEltTy != arrayAttr[0].getType() ||
52 complexEltTy != arrayAttr[1].getType()) {
53 return emitOpError()
54 << "requires attribute's element types (" << arrayAttr[0].getType()
55 << ", " << arrayAttr[1].getType()
56 << ") to match the element type of the op's return type ("
57 << complexEltTy << ")";
58 }
59 return success();
60 }
61
62 //===----------------------------------------------------------------------===//
63 // CreateOp
64 //===----------------------------------------------------------------------===//
65
fold(ArrayRef<Attribute> operands)66 OpFoldResult CreateOp::fold(ArrayRef<Attribute> operands) {
67 assert(operands.size() == 2 && "binary op takes two operands");
68 // Fold complex.create(complex.re(op), complex.im(op)).
69 if (auto reOp = getOperand(0).getDefiningOp<ReOp>()) {
70 if (auto imOp = getOperand(1).getDefiningOp<ImOp>()) {
71 if (reOp.getOperand() == imOp.getOperand()) {
72 return reOp.getOperand();
73 }
74 }
75 }
76 return {};
77 }
78
79 //===----------------------------------------------------------------------===//
80 // ImOp
81 //===----------------------------------------------------------------------===//
82
fold(ArrayRef<Attribute> operands)83 OpFoldResult ImOp::fold(ArrayRef<Attribute> operands) {
84 assert(operands.size() == 1 && "unary op takes 1 operand");
85 ArrayAttr arrayAttr = operands[0].dyn_cast_or_null<ArrayAttr>();
86 if (arrayAttr && arrayAttr.size() == 2)
87 return arrayAttr[1];
88 if (auto createOp = getOperand().getDefiningOp<CreateOp>())
89 return createOp.getOperand(1);
90 return {};
91 }
92
93 //===----------------------------------------------------------------------===//
94 // ReOp
95 //===----------------------------------------------------------------------===//
96
fold(ArrayRef<Attribute> operands)97 OpFoldResult ReOp::fold(ArrayRef<Attribute> operands) {
98 assert(operands.size() == 1 && "unary op takes 1 operand");
99 ArrayAttr arrayAttr = operands[0].dyn_cast_or_null<ArrayAttr>();
100 if (arrayAttr && arrayAttr.size() == 2)
101 return arrayAttr[0];
102 if (auto createOp = getOperand().getDefiningOp<CreateOp>())
103 return createOp.getOperand(0);
104 return {};
105 }
106
107 //===----------------------------------------------------------------------===//
108 // AddOp
109 //===----------------------------------------------------------------------===//
110
fold(ArrayRef<Attribute> operands)111 OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) {
112 assert(operands.size() == 2 && "binary op takes 2 operands");
113
114 // complex.add(complex.sub(a, b), b) -> a
115 if (auto sub = getLhs().getDefiningOp<SubOp>())
116 if (getRhs() == sub.getRhs())
117 return sub.getLhs();
118
119 // complex.add(b, complex.sub(a, b)) -> a
120 if (auto sub = getRhs().getDefiningOp<SubOp>())
121 if (getLhs() == sub.getRhs())
122 return sub.getLhs();
123
124 return {};
125 }
126
127 //===----------------------------------------------------------------------===//
128 // NegOp
129 //===----------------------------------------------------------------------===//
130
fold(ArrayRef<Attribute> operands)131 OpFoldResult NegOp::fold(ArrayRef<Attribute> operands) {
132 assert(operands.size() == 1 && "unary op takes 1 operand");
133
134 // complex.neg(complex.neg(a)) -> a
135 if (auto negOp = getOperand().getDefiningOp<NegOp>())
136 return negOp.getOperand();
137
138 return {};
139 }
140
141 //===----------------------------------------------------------------------===//
142 // LogOp
143 //===----------------------------------------------------------------------===//
144
fold(ArrayRef<Attribute> operands)145 OpFoldResult LogOp::fold(ArrayRef<Attribute> operands) {
146 assert(operands.size() == 1 && "unary op takes 1 operand");
147
148 // complex.log(complex.exp(a)) -> a
149 if (auto expOp = getOperand().getDefiningOp<ExpOp>())
150 return expOp.getOperand();
151
152 return {};
153 }
154
155 //===----------------------------------------------------------------------===//
156 // ExpOp
157 //===----------------------------------------------------------------------===//
158
fold(ArrayRef<Attribute> operands)159 OpFoldResult ExpOp::fold(ArrayRef<Attribute> operands) {
160 assert(operands.size() == 1 && "unary op takes 1 operand");
161
162 // complex.exp(complex.log(a)) -> a
163 if (auto logOp = getOperand().getDefiningOp<LogOp>())
164 return logOp.getOperand();
165
166 return {};
167 }
168
169 //===----------------------------------------------------------------------===//
170 // TableGen'd op method definitions
171 //===----------------------------------------------------------------------===//
172
173 #define GET_OP_CLASSES
174 #include "mlir/Dialect/Complex/IR/ComplexOps.cpp.inc"
175