1 //===- ComplexOps.cpp - MLIR Complex Operations ---------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Complex/IR/Complex.h" 10 #include "mlir/IR/Builders.h" 11 12 using namespace mlir; 13 using namespace mlir::complex; 14 15 //===----------------------------------------------------------------------===// 16 // ConstantOp 17 //===----------------------------------------------------------------------===// 18 19 OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) { 20 assert(operands.empty() && "constant has no operands"); 21 return getValue(); 22 } 23 24 void ConstantOp::getAsmResultNames( 25 function_ref<void(Value, StringRef)> setNameFn) { 26 setNameFn(getResult(), "cst"); 27 } 28 29 bool ConstantOp::isBuildableWith(Attribute value, Type type) { 30 if (auto arrAttr = value.dyn_cast<ArrayAttr>()) { 31 auto complexTy = type.dyn_cast<ComplexType>(); 32 if (!complexTy) 33 return false; 34 auto complexEltTy = complexTy.getElementType(); 35 return arrAttr.size() == 2 && arrAttr[0].getType() == complexEltTy && 36 arrAttr[1].getType() == complexEltTy; 37 } 38 return false; 39 } 40 41 LogicalResult ConstantOp::verify() { 42 ArrayAttr arrayAttr = getValue(); 43 if (arrayAttr.size() != 2) { 44 return emitOpError( 45 "requires 'value' to be a complex constant, represented as array of " 46 "two values"); 47 } 48 49 auto complexEltTy = getType().getElementType(); 50 if (complexEltTy != arrayAttr[0].getType() || 51 complexEltTy != arrayAttr[1].getType()) { 52 return emitOpError() 53 << "requires attribute's element types (" << arrayAttr[0].getType() 54 << ", " << arrayAttr[1].getType() 55 << ") to match the element type of the op's return type (" 56 << complexEltTy << ")"; 57 } 58 return success(); 59 } 60 61 //===----------------------------------------------------------------------===// 62 // CreateOp 63 //===----------------------------------------------------------------------===// 64 65 OpFoldResult CreateOp::fold(ArrayRef<Attribute> operands) { 66 assert(operands.size() == 2 && "binary op takes two operands"); 67 // Fold complex.create(complex.re(op), complex.im(op)). 68 if (auto reOp = getOperand(0).getDefiningOp<ReOp>()) { 69 if (auto imOp = getOperand(1).getDefiningOp<ImOp>()) { 70 if (reOp.getOperand() == imOp.getOperand()) { 71 return reOp.getOperand(); 72 } 73 } 74 } 75 return {}; 76 } 77 78 //===----------------------------------------------------------------------===// 79 // ImOp 80 //===----------------------------------------------------------------------===// 81 82 OpFoldResult ImOp::fold(ArrayRef<Attribute> operands) { 83 assert(operands.size() == 1 && "unary op takes 1 operand"); 84 ArrayAttr arrayAttr = operands[0].dyn_cast_or_null<ArrayAttr>(); 85 if (arrayAttr && arrayAttr.size() == 2) 86 return arrayAttr[1]; 87 if (auto createOp = getOperand().getDefiningOp<CreateOp>()) 88 return createOp.getOperand(1); 89 return {}; 90 } 91 92 //===----------------------------------------------------------------------===// 93 // ReOp 94 //===----------------------------------------------------------------------===// 95 96 OpFoldResult ReOp::fold(ArrayRef<Attribute> operands) { 97 assert(operands.size() == 1 && "unary op takes 1 operand"); 98 ArrayAttr arrayAttr = operands[0].dyn_cast_or_null<ArrayAttr>(); 99 if (arrayAttr && arrayAttr.size() == 2) 100 return arrayAttr[0]; 101 if (auto createOp = getOperand().getDefiningOp<CreateOp>()) 102 return createOp.getOperand(0); 103 return {}; 104 } 105 106 //===----------------------------------------------------------------------===// 107 // TableGen'd op method definitions 108 //===----------------------------------------------------------------------===// 109 110 #define GET_OP_CLASSES 111 #include "mlir/Dialect/Complex/IR/ComplexOps.cpp.inc" 112