1 //===- ComplexOps.cpp - MLIR Complex Operations ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Complex/IR/Complex.h"
10 #include "mlir/IR/Builders.h"
11 
12 using namespace mlir;
13 using namespace mlir::complex;
14 
15 //===----------------------------------------------------------------------===//
16 // ConstantOp
17 //===----------------------------------------------------------------------===//
18 
19 OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
20   assert(operands.empty() && "constant has no operands");
21   return getValue();
22 }
23 
24 void ConstantOp::getAsmResultNames(
25     function_ref<void(Value, StringRef)> setNameFn) {
26   setNameFn(getResult(), "cst");
27 }
28 
29 bool ConstantOp::isBuildableWith(Attribute value, Type type) {
30   if (auto arrAttr = value.dyn_cast<ArrayAttr>()) {
31     auto complexTy = type.dyn_cast<ComplexType>();
32     if (!complexTy)
33       return false;
34     auto complexEltTy = complexTy.getElementType();
35     return arrAttr.size() == 2 && arrAttr[0].getType() == complexEltTy &&
36            arrAttr[1].getType() == complexEltTy;
37   }
38   return false;
39 }
40 
41 LogicalResult ConstantOp::verify() {
42   ArrayAttr arrayAttr = getValue();
43   if (arrayAttr.size() != 2) {
44     return emitOpError(
45         "requires 'value' to be a complex constant, represented as array of "
46         "two values");
47   }
48 
49   auto complexEltTy = getType().getElementType();
50   if (complexEltTy != arrayAttr[0].getType() ||
51       complexEltTy != arrayAttr[1].getType()) {
52     return emitOpError()
53            << "requires attribute's element types (" << arrayAttr[0].getType()
54            << ", " << arrayAttr[1].getType()
55            << ") to match the element type of the op's return type ("
56            << complexEltTy << ")";
57   }
58   return success();
59 }
60 
61 //===----------------------------------------------------------------------===//
62 // CreateOp
63 //===----------------------------------------------------------------------===//
64 
65 OpFoldResult CreateOp::fold(ArrayRef<Attribute> operands) {
66   assert(operands.size() == 2 && "binary op takes two operands");
67   // Fold complex.create(complex.re(op), complex.im(op)).
68   if (auto reOp = getOperand(0).getDefiningOp<ReOp>()) {
69     if (auto imOp = getOperand(1).getDefiningOp<ImOp>()) {
70       if (reOp.getOperand() == imOp.getOperand()) {
71         return reOp.getOperand();
72       }
73     }
74   }
75   return {};
76 }
77 
78 //===----------------------------------------------------------------------===//
79 // ImOp
80 //===----------------------------------------------------------------------===//
81 
82 OpFoldResult ImOp::fold(ArrayRef<Attribute> operands) {
83   assert(operands.size() == 1 && "unary op takes 1 operand");
84   ArrayAttr arrayAttr = operands[0].dyn_cast_or_null<ArrayAttr>();
85   if (arrayAttr && arrayAttr.size() == 2)
86     return arrayAttr[1];
87   if (auto createOp = getOperand().getDefiningOp<CreateOp>())
88     return createOp.getOperand(1);
89   return {};
90 }
91 
92 //===----------------------------------------------------------------------===//
93 // ReOp
94 //===----------------------------------------------------------------------===//
95 
96 OpFoldResult ReOp::fold(ArrayRef<Attribute> operands) {
97   assert(operands.size() == 1 && "unary op takes 1 operand");
98   ArrayAttr arrayAttr = operands[0].dyn_cast_or_null<ArrayAttr>();
99   if (arrayAttr && arrayAttr.size() == 2)
100     return arrayAttr[0];
101   if (auto createOp = getOperand().getDefiningOp<CreateOp>())
102     return createOp.getOperand(0);
103   return {};
104 }
105 
106 //===----------------------------------------------------------------------===//
107 // TableGen'd op method definitions
108 //===----------------------------------------------------------------------===//
109 
110 #define GET_OP_CLASSES
111 #include "mlir/Dialect/Complex/IR/ComplexOps.cpp.inc"
112