1 //===- ComplexOps.cpp - MLIR Complex Operations ---------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Complex/IR/Complex.h" 10 #include "mlir/IR/Builders.h" 11 #include "mlir/IR/Matchers.h" 12 13 using namespace mlir; 14 using namespace mlir::complex; 15 16 //===----------------------------------------------------------------------===// 17 // ConstantOp 18 //===----------------------------------------------------------------------===// 19 20 OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) { 21 assert(operands.empty() && "constant has no operands"); 22 return getValue(); 23 } 24 25 void ConstantOp::getAsmResultNames( 26 function_ref<void(Value, StringRef)> setNameFn) { 27 setNameFn(getResult(), "cst"); 28 } 29 30 bool ConstantOp::isBuildableWith(Attribute value, Type type) { 31 if (auto arrAttr = value.dyn_cast<ArrayAttr>()) { 32 auto complexTy = type.dyn_cast<ComplexType>(); 33 if (!complexTy) 34 return false; 35 auto complexEltTy = complexTy.getElementType(); 36 return arrAttr.size() == 2 && arrAttr[0].getType() == complexEltTy && 37 arrAttr[1].getType() == complexEltTy; 38 } 39 return false; 40 } 41 42 LogicalResult ConstantOp::verify() { 43 ArrayAttr arrayAttr = getValue(); 44 if (arrayAttr.size() != 2) { 45 return emitOpError( 46 "requires 'value' to be a complex constant, represented as array of " 47 "two values"); 48 } 49 50 auto complexEltTy = getType().getElementType(); 51 if (complexEltTy != arrayAttr[0].getType() || 52 complexEltTy != arrayAttr[1].getType()) { 53 return emitOpError() 54 << "requires attribute's element types (" << arrayAttr[0].getType() 55 << ", " << arrayAttr[1].getType() 56 << ") to match the element type of the op's return type (" 57 << complexEltTy << ")"; 58 } 59 return success(); 60 } 61 62 //===----------------------------------------------------------------------===// 63 // CreateOp 64 //===----------------------------------------------------------------------===// 65 66 OpFoldResult CreateOp::fold(ArrayRef<Attribute> operands) { 67 assert(operands.size() == 2 && "binary op takes two operands"); 68 // Fold complex.create(complex.re(op), complex.im(op)). 69 if (auto reOp = getOperand(0).getDefiningOp<ReOp>()) { 70 if (auto imOp = getOperand(1).getDefiningOp<ImOp>()) { 71 if (reOp.getOperand() == imOp.getOperand()) { 72 return reOp.getOperand(); 73 } 74 } 75 } 76 return {}; 77 } 78 79 //===----------------------------------------------------------------------===// 80 // ImOp 81 //===----------------------------------------------------------------------===// 82 83 OpFoldResult ImOp::fold(ArrayRef<Attribute> operands) { 84 assert(operands.size() == 1 && "unary op takes 1 operand"); 85 ArrayAttr arrayAttr = operands[0].dyn_cast_or_null<ArrayAttr>(); 86 if (arrayAttr && arrayAttr.size() == 2) 87 return arrayAttr[1]; 88 if (auto createOp = getOperand().getDefiningOp<CreateOp>()) 89 return createOp.getOperand(1); 90 return {}; 91 } 92 93 //===----------------------------------------------------------------------===// 94 // ReOp 95 //===----------------------------------------------------------------------===// 96 97 OpFoldResult ReOp::fold(ArrayRef<Attribute> operands) { 98 assert(operands.size() == 1 && "unary op takes 1 operand"); 99 ArrayAttr arrayAttr = operands[0].dyn_cast_or_null<ArrayAttr>(); 100 if (arrayAttr && arrayAttr.size() == 2) 101 return arrayAttr[0]; 102 if (auto createOp = getOperand().getDefiningOp<CreateOp>()) 103 return createOp.getOperand(0); 104 return {}; 105 } 106 107 //===----------------------------------------------------------------------===// 108 // AddOp 109 //===----------------------------------------------------------------------===// 110 111 OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) { 112 assert(operands.size() == 2 && "binary op takes 2 operands"); 113 114 // complex.add(complex.sub(a, b), b) -> a 115 if (auto sub = getLhs().getDefiningOp<SubOp>()) 116 if (getRhs() == sub.getRhs()) 117 return sub.getLhs(); 118 119 // complex.add(b, complex.sub(a, b)) -> a 120 if (auto sub = getRhs().getDefiningOp<SubOp>()) 121 if (getLhs() == sub.getRhs()) 122 return sub.getLhs(); 123 124 return {}; 125 } 126 127 //===----------------------------------------------------------------------===// 128 // NegOp 129 //===----------------------------------------------------------------------===// 130 131 OpFoldResult NegOp::fold(ArrayRef<Attribute> operands) { 132 assert(operands.size() == 1 && "unary op takes 1 operand"); 133 134 // complex.neg(complex.neg(a)) -> a 135 if (auto negOp = getOperand().getDefiningOp<NegOp>()) 136 return negOp.getOperand(); 137 138 return {}; 139 } 140 141 //===----------------------------------------------------------------------===// 142 // LogOp 143 //===----------------------------------------------------------------------===// 144 145 OpFoldResult LogOp::fold(ArrayRef<Attribute> operands) { 146 assert(operands.size() == 1 && "unary op takes 1 operand"); 147 148 // complex.log(complex.exp(a)) -> a 149 if (auto expOp = getOperand().getDefiningOp<ExpOp>()) 150 return expOp.getOperand(); 151 152 return {}; 153 } 154 155 //===----------------------------------------------------------------------===// 156 // ExpOp 157 //===----------------------------------------------------------------------===// 158 159 OpFoldResult ExpOp::fold(ArrayRef<Attribute> operands) { 160 assert(operands.size() == 1 && "unary op takes 1 operand"); 161 162 // complex.exp(complex.log(a)) -> a 163 if (auto logOp = getOperand().getDefiningOp<LogOp>()) 164 return logOp.getOperand(); 165 166 return {}; 167 } 168 169 //===----------------------------------------------------------------------===// 170 // TableGen'd op method definitions 171 //===----------------------------------------------------------------------===// 172 173 #define GET_OP_CLASSES 174 #include "mlir/Dialect/Complex/IR/ComplexOps.cpp.inc" 175