//===- ArithmeticOps.cpp - MLIR Arithmetic dialect ops implementation -----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/CommonFolders.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" using namespace mlir; using namespace mlir::arith; //===----------------------------------------------------------------------===// // Pattern helpers //===----------------------------------------------------------------------===// static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs) { return builder.getIntegerAttr(res.getType(), lhs.cast().getInt() + rhs.cast().getInt()); } static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs) { return builder.getIntegerAttr(res.getType(), lhs.cast().getInt() - rhs.cast().getInt()); } /// Invert an integer comparison predicate. static arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred) { switch (pred) { case arith::CmpIPredicate::eq: return arith::CmpIPredicate::ne; case arith::CmpIPredicate::ne: return arith::CmpIPredicate::eq; case arith::CmpIPredicate::slt: return arith::CmpIPredicate::sge; case arith::CmpIPredicate::sle: return arith::CmpIPredicate::sgt; case arith::CmpIPredicate::sgt: return arith::CmpIPredicate::sle; case arith::CmpIPredicate::sge: return arith::CmpIPredicate::slt; case arith::CmpIPredicate::ult: return arith::CmpIPredicate::uge; case arith::CmpIPredicate::ule: return arith::CmpIPredicate::ugt; case arith::CmpIPredicate::ugt: return arith::CmpIPredicate::ule; case arith::CmpIPredicate::uge: return arith::CmpIPredicate::ult; } llvm_unreachable("unknown cmpi predicate kind"); } static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) { return arith::CmpIPredicateAttr::get(pred.getContext(), invertPredicate(pred.getValue())); } //===----------------------------------------------------------------------===// // TableGen'd canonicalization patterns //===----------------------------------------------------------------------===// namespace { #include "ArithmeticCanonicalization.inc" } // end anonymous namespace //===----------------------------------------------------------------------===// // ConstantOp //===----------------------------------------------------------------------===// void arith::ConstantOp::getAsmResultNames( function_ref setNameFn) { auto type = getType(); if (auto intCst = getValue().dyn_cast()) { auto intType = type.dyn_cast(); // Sugar i1 constants with 'true' and 'false'. if (intType && intType.getWidth() == 1) return setNameFn(getResult(), (intCst.getInt() ? "true" : "false")); // Otherwise, build a compex name with the value and type. SmallString<32> specialNameBuffer; llvm::raw_svector_ostream specialName(specialNameBuffer); specialName << 'c' << intCst.getInt(); if (intType) specialName << '_' << type; setNameFn(getResult(), specialName.str()); } else { setNameFn(getResult(), "cst"); } } /// TODO: disallow arith.constant to return anything other than signless integer /// or float like. static LogicalResult verify(arith::ConstantOp op) { auto type = op.getType(); // The value's type must match the return type. if (op.getValue().getType() != type) { return op.emitOpError() << "value type " << op.getValue().getType() << " must match return type: " << type; } // Integer values must be signless. if (type.isa() && !type.cast().isSignless()) return op.emitOpError("integer return type must be signless"); // Any float or elements attribute are acceptable. if (!op.getValue().isa()) { return op.emitOpError( "value must be an integer, float, or elements attribute"); } return success(); } bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) { // The value's type must be the same as the provided type. if (value.getType() != type) return false; // Integer values must be signless. if (type.isa() && !type.cast().isSignless()) return false; // Integer, float, and element attributes are buildable. return value.isa(); } OpFoldResult arith::ConstantOp::fold(ArrayRef operands) { return getValue(); } void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result, int64_t value, unsigned width) { auto type = builder.getIntegerType(width); arith::ConstantOp::build(builder, result, type, builder.getIntegerAttr(type, value)); } void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result, int64_t value, Type type) { assert(type.isSignlessInteger() && "ConstantIntOp can only have signless integer type values"); arith::ConstantOp::build(builder, result, type, builder.getIntegerAttr(type, value)); } bool arith::ConstantIntOp::classof(Operation *op) { if (auto constOp = dyn_cast_or_null(op)) return constOp.getType().isSignlessInteger(); return false; } void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result, const APFloat &value, FloatType type) { arith::ConstantOp::build(builder, result, type, builder.getFloatAttr(type, value)); } bool arith::ConstantFloatOp::classof(Operation *op) { if (auto constOp = dyn_cast_or_null(op)) return constOp.getType().isa(); return false; } void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result, int64_t value) { arith::ConstantOp::build(builder, result, builder.getIndexType(), builder.getIndexAttr(value)); } bool arith::ConstantIndexOp::classof(Operation *op) { if (auto constOp = dyn_cast_or_null(op)) return constOp.getType().isIndex(); return false; } //===----------------------------------------------------------------------===// // AddIOp //===----------------------------------------------------------------------===// OpFoldResult arith::AddIOp::fold(ArrayRef operands) { // addi(x, 0) -> x if (matchPattern(getRhs(), m_Zero())) return getLhs(); return constFoldBinaryOp(operands, [](APInt a, APInt b) { return a + b; }); } void arith::AddIOp::getCanonicalizationPatterns( OwningRewritePatternList &patterns, MLIRContext *context) { patterns.insert( context); } //===----------------------------------------------------------------------===// // SubIOp //===----------------------------------------------------------------------===// OpFoldResult arith::SubIOp::fold(ArrayRef operands) { // subi(x,x) -> 0 if (getOperand(0) == getOperand(1)) return Builder(getContext()).getZeroAttr(getType()); // subi(x,0) -> x if (matchPattern(getRhs(), m_Zero())) return getLhs(); return constFoldBinaryOp(operands, [](APInt a, APInt b) { return a - b; }); } void arith::SubIOp::getCanonicalizationPatterns( OwningRewritePatternList &patterns, MLIRContext *context) { patterns.insert(context); } //===----------------------------------------------------------------------===// // MulIOp //===----------------------------------------------------------------------===// OpFoldResult arith::MulIOp::fold(ArrayRef operands) { // muli(x, 0) -> 0 if (matchPattern(getRhs(), m_Zero())) return getRhs(); // muli(x, 1) -> x if (matchPattern(getRhs(), m_One())) return getOperand(0); // TODO: Handle the overflow case. // default folder return constFoldBinaryOp(operands, [](APInt a, APInt b) { return a * b; }); } //===----------------------------------------------------------------------===// // DivUIOp //===----------------------------------------------------------------------===// OpFoldResult arith::DivUIOp::fold(ArrayRef operands) { // Don't fold if it would require a division by zero. bool div0 = false; auto result = constFoldBinaryOp(operands, [&](APInt a, APInt b) { if (div0 || !b) { div0 = true; return a; } return a.udiv(b); }); // Fold out division by one. Assumes all tensors of all ones are splats. if (auto rhs = operands[1].dyn_cast_or_null()) { if (rhs.getValue() == 1) return getLhs(); } else if (auto rhs = operands[1].dyn_cast_or_null()) { if (rhs.getSplatValue().getValue() == 1) return getLhs(); } return div0 ? Attribute() : result; } //===----------------------------------------------------------------------===// // DivSIOp //===----------------------------------------------------------------------===// OpFoldResult arith::DivSIOp::fold(ArrayRef operands) { // Don't fold if it would overflow or if it requires a division by zero. bool overflowOrDiv0 = false; auto result = constFoldBinaryOp(operands, [&](APInt a, APInt b) { if (overflowOrDiv0 || !b) { overflowOrDiv0 = true; return a; } return a.sdiv_ov(b, overflowOrDiv0); }); // Fold out division by one. Assumes all tensors of all ones are splats. if (auto rhs = operands[1].dyn_cast_or_null()) { if (rhs.getValue() == 1) return getLhs(); } else if (auto rhs = operands[1].dyn_cast_or_null()) { if (rhs.getSplatValue().getValue() == 1) return getLhs(); } return overflowOrDiv0 ? Attribute() : result; } //===----------------------------------------------------------------------===// // Ceil and floor division folding helpers //===----------------------------------------------------------------------===// static APInt signedCeilNonnegInputs(APInt a, APInt b, bool &overflow) { // Returns (a-1)/b + 1 APInt one(a.getBitWidth(), 1, true); // Signed value 1. APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow); return val.sadd_ov(one, overflow); } //===----------------------------------------------------------------------===// // CeilDivSIOp //===----------------------------------------------------------------------===// OpFoldResult arith::CeilDivSIOp::fold(ArrayRef operands) { // Don't fold if it would overflow or if it requires a division by zero. bool overflowOrDiv0 = false; auto result = constFoldBinaryOp(operands, [&](APInt a, APInt b) { if (overflowOrDiv0 || !b) { overflowOrDiv0 = true; return a; } unsigned bits = a.getBitWidth(); APInt zero = APInt::getZero(bits); if (a.sgt(zero) && b.sgt(zero)) { // Both positive, return ceil(a, b). return signedCeilNonnegInputs(a, b, overflowOrDiv0); } if (a.slt(zero) && b.slt(zero)) { // Both negative, return ceil(-a, -b). APInt posA = zero.ssub_ov(a, overflowOrDiv0); APInt posB = zero.ssub_ov(b, overflowOrDiv0); return signedCeilNonnegInputs(posA, posB, overflowOrDiv0); } if (a.slt(zero) && b.sgt(zero)) { // A is negative, b is positive, return - ( -a / b). APInt posA = zero.ssub_ov(a, overflowOrDiv0); APInt div = posA.sdiv_ov(b, overflowOrDiv0); return zero.ssub_ov(div, overflowOrDiv0); } // A is positive (or zero), b is negative, return - (a / -b). APInt posB = zero.ssub_ov(b, overflowOrDiv0); APInt div = a.sdiv_ov(posB, overflowOrDiv0); return zero.ssub_ov(div, overflowOrDiv0); }); // Fold out floor division by one. Assumes all tensors of all ones are // splats. if (auto rhs = operands[1].dyn_cast_or_null()) { if (rhs.getValue() == 1) return getLhs(); } else if (auto rhs = operands[1].dyn_cast_or_null()) { if (rhs.getSplatValue().getValue() == 1) return getLhs(); } return overflowOrDiv0 ? Attribute() : result; } //===----------------------------------------------------------------------===// // FloorDivSIOp //===----------------------------------------------------------------------===// OpFoldResult arith::FloorDivSIOp::fold(ArrayRef operands) { // Don't fold if it would overflow or if it requires a division by zero. bool overflowOrDiv0 = false; auto result = constFoldBinaryOp(operands, [&](APInt a, APInt b) { if (overflowOrDiv0 || !b) { overflowOrDiv0 = true; return a; } unsigned bits = a.getBitWidth(); APInt zero = APInt::getZero(bits); if (a.sge(zero) && b.sgt(zero)) { // Both positive (or a is zero), return a / b. return a.sdiv_ov(b, overflowOrDiv0); } if (a.sle(zero) && b.slt(zero)) { // Both negative (or a is zero), return -a / -b. APInt posA = zero.ssub_ov(a, overflowOrDiv0); APInt posB = zero.ssub_ov(b, overflowOrDiv0); return posA.sdiv_ov(posB, overflowOrDiv0); } if (a.slt(zero) && b.sgt(zero)) { // A is negative, b is positive, return - ceil(-a, b). APInt posA = zero.ssub_ov(a, overflowOrDiv0); APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0); return zero.ssub_ov(ceil, overflowOrDiv0); } // A is positive, b is negative, return - ceil(a, -b). APInt posB = zero.ssub_ov(b, overflowOrDiv0); APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0); return zero.ssub_ov(ceil, overflowOrDiv0); }); // Fold out floor division by one. Assumes all tensors of all ones are // splats. if (auto rhs = operands[1].dyn_cast_or_null()) { if (rhs.getValue() == 1) return getLhs(); } else if (auto rhs = operands[1].dyn_cast_or_null()) { if (rhs.getSplatValue().getValue() == 1) return getLhs(); } return overflowOrDiv0 ? Attribute() : result; } //===----------------------------------------------------------------------===// // RemUIOp //===----------------------------------------------------------------------===// OpFoldResult arith::RemUIOp::fold(ArrayRef operands) { auto rhs = operands.back().dyn_cast_or_null(); if (!rhs) return {}; auto rhsValue = rhs.getValue(); // x % 1 = 0 if (rhsValue.isOneValue()) return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0)); // Don't fold if it requires division by zero. if (rhsValue.isNullValue()) return {}; auto lhs = operands.front().dyn_cast_or_null(); if (!lhs) return {}; return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue)); } //===----------------------------------------------------------------------===// // RemSIOp //===----------------------------------------------------------------------===// OpFoldResult arith::RemSIOp::fold(ArrayRef operands) { auto rhs = operands.back().dyn_cast_or_null(); if (!rhs) return {}; auto rhsValue = rhs.getValue(); // x % 1 = 0 if (rhsValue.isOneValue()) return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0)); // Don't fold if it requires division by zero. if (rhsValue.isNullValue()) return {}; auto lhs = operands.front().dyn_cast_or_null(); if (!lhs) return {}; return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue)); } //===----------------------------------------------------------------------===// // AndIOp //===----------------------------------------------------------------------===// OpFoldResult arith::AndIOp::fold(ArrayRef operands) { /// and(x, 0) -> 0 if (matchPattern(getRhs(), m_Zero())) return getRhs(); /// and(x, allOnes) -> x APInt intValue; if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isAllOnes()) return getLhs(); /// and(x, x) -> x if (getLhs() == getRhs()) return getRhs(); return constFoldBinaryOp(operands, [](APInt a, APInt b) { return a & b; }); } //===----------------------------------------------------------------------===// // OrIOp //===----------------------------------------------------------------------===// OpFoldResult arith::OrIOp::fold(ArrayRef operands) { /// or(x, 0) -> x if (matchPattern(getRhs(), m_Zero())) return getLhs(); /// or(x, x) -> x if (getLhs() == getRhs()) return getRhs(); /// or(x, ) -> if (auto rhsAttr = operands[1].dyn_cast_or_null()) if (rhsAttr.getValue().isAllOnes()) return rhsAttr; return constFoldBinaryOp(operands, [](APInt a, APInt b) { return a | b; }); } //===----------------------------------------------------------------------===// // XOrIOp //===----------------------------------------------------------------------===// OpFoldResult arith::XOrIOp::fold(ArrayRef operands) { /// xor(x, 0) -> x if (matchPattern(getRhs(), m_Zero())) return getLhs(); /// xor(x, x) -> 0 if (getLhs() == getRhs()) return Builder(getContext()).getZeroAttr(getType()); return constFoldBinaryOp(operands, [](APInt a, APInt b) { return a ^ b; }); } void arith::XOrIOp::getCanonicalizationPatterns( OwningRewritePatternList &patterns, MLIRContext *context) { patterns.insert(context); } //===----------------------------------------------------------------------===// // AddFOp //===----------------------------------------------------------------------===// OpFoldResult arith::AddFOp::fold(ArrayRef operands) { return constFoldBinaryOp( operands, [](APFloat a, APFloat b) { return a + b; }); } //===----------------------------------------------------------------------===// // SubFOp //===----------------------------------------------------------------------===// OpFoldResult arith::SubFOp::fold(ArrayRef operands) { return constFoldBinaryOp( operands, [](APFloat a, APFloat b) { return a - b; }); } //===----------------------------------------------------------------------===// // MulFOp //===----------------------------------------------------------------------===// OpFoldResult arith::MulFOp::fold(ArrayRef operands) { return constFoldBinaryOp( operands, [](APFloat a, APFloat b) { return a * b; }); } //===----------------------------------------------------------------------===// // DivFOp //===----------------------------------------------------------------------===// OpFoldResult arith::DivFOp::fold(ArrayRef operands) { return constFoldBinaryOp( operands, [](APFloat a, APFloat b) { return a / b; }); } //===----------------------------------------------------------------------===// // Utility functions for verifying cast ops //===----------------------------------------------------------------------===// template using type_list = std::tuple *; /// Returns a non-null type only if the provided type is one of the allowed /// types or one of the allowed shaped types of the allowed types. Returns the /// element type if a valid shaped type is provided. template static Type getUnderlyingType(Type type, type_list, type_list) { if (type.isa() && !type.isa()) return {}; auto underlyingType = getElementTypeOrSelf(type); if (!underlyingType.isa()) return {}; return underlyingType; } /// Get allowed underlying types for vectors and tensors. template static Type getTypeIfLike(Type type) { return getUnderlyingType(type, type_list(), type_list()); } /// Get allowed underlying types for vectors, tensors, and memrefs. template static Type getTypeIfLikeOrMemRef(Type type) { return getUnderlyingType(type, type_list(), type_list()); } static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) { return inputs.size() == 1 && outputs.size() == 1 && succeeded(verifyCompatibleShapes(inputs.front(), outputs.front())); } //===----------------------------------------------------------------------===// // Verifiers for integer and floating point extension/truncation ops //===----------------------------------------------------------------------===// // Extend ops can only extend to a wider type. template static LogicalResult verifyExtOp(Op op) { Type srcType = getElementTypeOrSelf(op.getIn().getType()); Type dstType = getElementTypeOrSelf(op.getType()); if (srcType.cast().getWidth() >= dstType.cast().getWidth()) return op.emitError("result type ") << dstType << " must be wider than operand type " << srcType; return success(); } // Truncate ops can only truncate to a shorter type. template static LogicalResult verifyTruncateOp(Op op) { Type srcType = getElementTypeOrSelf(op.getIn().getType()); Type dstType = getElementTypeOrSelf(op.getType()); if (srcType.cast().getWidth() <= dstType.cast().getWidth()) return op.emitError("result type ") << dstType << " must be shorter than operand type " << srcType; return success(); } /// Validate a cast that changes the width of a type. template